diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 259555e8fa4df..0000000000000 --- a/.flake8 +++ /dev/null @@ -1,28 +0,0 @@ -[flake8] -select = E -exclude = - ./build, - # Exclude third-party libraries - ./third_party/**, - ./python/paddle/utils/gast/**, -ignore = - # Whitespace before ‘,’, ‘;’, or ‘:’, it is not compatible with black - E203, - # Module level import not at top of file - E402, - # Line too long (82 > 79 characters) - E501, - # Do not compare types, use `isinstance()` - E721, - # Do not use bare except, specify exception instead - E722, - # Do not assign a lambda expression, use a def - E731, - # Do not use variables named ‘l’, ‘O’, or ‘I’ - E741 -per-file-ignores = - # These files need tabs for testing. - test/dygraph_to_static/test_legacy_error.py:E101 - - # Ignore compare with True in sot unittest - test/sot/test_dup_top.py:E712 diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 8a8c9c7fa1e50..8757059d30367 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,9 +1,13 @@ + -### PR types - -### PR changes - +### PR Category + + + +### PR Types + + ### Description diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 41e77280a9f95..3d1ac6a170243 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,7 +36,7 @@ repos: # Exclude some unit test files that require tabs. exclude: | (?x)^( - test/dygraph_to_static/test_legacy_error.py + test/dygraph_to_static/test_error.py )$ - repo: local hooks: @@ -56,13 +56,8 @@ repos: hooks: - id: black files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ -- repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 - hooks: - - id: flake8 - args: ["--config=.flake8"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.2.0 + rev: v0.3.0 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --no-cache] diff --git a/CMakeLists.txt b/CMakeLists.txt index d5e260f323a0c..8f8c8cd616ab4 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -63,9 +63,11 @@ option(WITH_ONNXRUNTIME "Compile PaddlePaddle with ONNXRUNTIME" OFF) option(WITH_CUSPARSELT "Compile PaddlePaddle with CUSPARSELT" OFF) option(WITH_SETUP_INSTALL "Compile PaddlePaddle with setup.py" OFF) option(WITH_SHARED_PHI "Compile PaddlePaddle with SHARED LIB of PHI" ON) -option(CINN_ONLY "Compile CINN only in Paddle" OFF) option(CINN_WITH_CUDNN "Compile CINN with CUDNN support" ON) - +option(WITH_PIP_CUDA_LIBRARIES + "Paddle uses the CUDA library provided by NVIDIA" OFF) +option(WITH_NIGHTLY_BUILD + "Compile nightly paddle whl package of the develop branch" OFF) find_package(Git REQUIRED) # config GIT_URL with github mirrors to speed up dependent repos clone @@ -97,11 +99,16 @@ endif() if(WITH_GPU AND NOT APPLE) #(Note risemeup1): The cudart dynamic library libcudart.so is used by set CUDA_USE_STATIC_CUDA_RUNTIME and CMAKE_CUDA_FLAGS - if(LINUX) + if(CMAKE_SYSTEM_NAME STREQUAL "Linux" AND CMAKE_SYSTEM_PROCESSOR STREQUAL + "x86_64") set(CUDA_USE_STATIC_CUDA_RUNTIME OFF CACHE BOOL "" FORCE) set(CMAKE_CUDA_FLAGS "--cudart shared") + if(WITH_PIP_CUDA_LIBRARIES) + #(Note risemeup1): Flag 'WITH_PIP_CUDA_LIBRARIES' will be used in dynamic_loader.cc to search for CUDA-related .so files through the Python libraries provided by NVIDIA. 
+ add_definitions(-DWITH_PIP_CUDA_LIBRARIES) + endif() endif() enable_language(CUDA) message(STATUS "CUDA compiler: ${CMAKE_CUDA_COMPILER}, version: " @@ -135,7 +142,10 @@ endif() if(WIN32) option(MSVC_STATIC_CRT "use static C Runtime library by default" ON) message("Build static library of PHI") - set(CMAKE_SUPPRESS_REGENERATION ON) + # (Note xuxinyi04): If CMAKE_SUPPRESS_REGENERATION is OFF, which is default, then CMake adds a + # special target on which all other targets depend that checks the build system and optionally + # re-runs CMake to regenerate the build system when the target specification source changes. + set(CMAKE_SUPPRESS_REGENERATION OFF) set(CMAKE_STATIC_LIBRARY_PREFIX lib) set(WITH_SHARED_PHI OFF @@ -233,6 +243,8 @@ if(WIN32) "${${flag_var}} /ignore:4049 /ignore:4217 /ignore:4006 /ignore:4221") if(MSVC_STATIC_CRT) set(${flag_var} "${${flag_var}} /NODEFAULTLIB:MSVCRT.LIB") + else() + set(${flag_var} "${${flag_var}} /NODEFAULTLIB:LIBCMT.LIB") endif() endforeach() @@ -618,18 +630,6 @@ if(WITH_CINN) include(cmake/cinn.cmake) add_definitions(-DPADDLE_WITH_CINN) - - if(CINN_ONLY) - add_definitions(-DCINN_WITH_ONLY) - if(WITH_PYTHON) - add_subdirectory(python) - endif() - add_subdirectory(test) - if(NOT WITH_GFLAGS) - add_subdirectory(paddle/utils) - endif() - return() - endif() endif() #------------- cinn cmake config end -------------- diff --git a/cmake/ccache.cmake b/cmake/ccache.cmake index 08b6720416fe2..55ec609110314 100644 --- a/cmake/ccache.cmake +++ b/cmake/ccache.cmake @@ -11,8 +11,9 @@ if(NOT WIN32) # show statistics summary of ccache message("ccache version\t\t\t " ${ccache_version} "\n" ${cache_directory}) - set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ${CCACHE_PATH}) - set_property(GLOBAL PROPERTY RULE_LAUNCH_LINK ${CCACHE_PATH}) + set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_PATH}) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PATH}) + set(CMAKE_CUDA_COMPILER_LAUNCHER ${CCACHE_PATH}) endif() elseif("${CMAKE_GENERATOR}" STREQUAL "Ninja") # (Note:zhouwei25) Only Ninja Generator can support sccache now diff --git a/cmake/cinn.cmake b/cmake/cinn.cmake index 0609b280aba3e..3b001ac0fe899 100644 --- a/cmake/cinn.cmake +++ b/cmake/cinn.cmake @@ -164,13 +164,13 @@ cinn_cc_library( isl ginac pybind + group_cluster + cinn_op_dialect ${jitify_deps}) add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB) add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ${core_deps}) -if(NOT CINN_ONLY) - target_link_libraries(cinnapi op_dialect pir phi) - add_dependencies(cinnapi op_dialect pir phi) -endif() +target_link_libraries(cinnapi op_dialect pir phi) +add_dependencies(cinnapi op_dialect pir phi) target_link_libraries(cinnapi ${PYTHON_LIBRARIES}) @@ -183,11 +183,6 @@ if(WITH_MKL) endif() endif() -if(CINN_ONLY) - target_link_libraries(cinnapi common) - add_dependencies(cinnapi common) -endif() - if(WITH_GPU) target_link_libraries( cinnapi @@ -227,15 +222,17 @@ function(gen_cinncore LINKTYPE) schedule_desc_proto absl isl - ginac) + ginac + pybind + group_cluster + cinn_op_dialect + ${jitify_deps}) add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB) add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ${core_deps}) - if(NOT CINN_ONLY) - target_link_libraries(${CINNCORE_TARGET} op_dialect pir phi) - add_dependencies(${CINNCORE_TARGET} op_dialect pir phi) - endif() + target_link_libraries(${CINNCORE_TARGET} op_dialect pir phi) + add_dependencies(${CINNCORE_TARGET} op_dialect pir phi) - add_dependencies(${CINNCORE_TARGET} pybind) + # 
add_dependencies(${CINNCORE_TARGET} pybind) target_link_libraries(${CINNCORE_TARGET} ${PYTHON_LIBRARIES}) if(WITH_MKL) @@ -247,11 +244,6 @@ function(gen_cinncore LINKTYPE) endif() endif() - if(CINN_ONLY) - target_link_libraries(${CINNCORE_TARGET} common) - add_dependencies(${CINNCORE_TARGET} common) - endif() - if(WITH_GPU) target_link_libraries( ${CINNCORE_TARGET} @@ -261,16 +253,16 @@ function(gen_cinncore LINKTYPE) ${CUBLAS} ${CUDNN} ${CURAND} - ${CUSOLVER} - ${jitify_deps}) + ${CUSOLVER}) + # ${jitify_deps}) if(NVTX_FOUND) target_link_libraries(${CINNCORE_TARGET} ${CUDA_NVTX_LIB}) endif() endif() if(WITH_CUTLASS) - target_link_libraries(cinnapi cutlass) - add_dependencies(cinnapi cutlass) + target_link_libraries(${CINNCORE_TARGET} cutlass) + add_dependencies(${CINNCORE_TARGET} cutlass) endif() endfunction() diff --git a/cmake/coveralls.cmake b/cmake/coveralls.cmake index e8263e48af3aa..58b34df69019a 100644 --- a/cmake/coveralls.cmake +++ b/cmake/coveralls.cmake @@ -60,8 +60,8 @@ endfunction() if(WITH_COVERAGE) if(WITH_INCREMENTAL_COVERAGE) - # if *.h changed, generate coverage report totaly. - # if pybind.cc changed, generate coverage report totaly. + # if *.h changed, generate coverage report totally. + # if pybind.cc changed, generate coverage report totally. # Because if pybind.cc add '-g -O0 -fprofile-arcs -ftest-coverage' only, some testcase will fail. if((NOT ("$ENV{PADDLE_GIT_DIFF_H_FILE}" STREQUAL "")) OR ("$ENV{PADDLE_GIT_DIFF_CC_FILE}" MATCHES "pybind.cc")) diff --git a/cmake/coverallsGcovJsons.cmake b/cmake/coverallsGcovJsons.cmake index c31b2457c1742..c2b48615cef1a 100644 --- a/cmake/coverallsGcovJsons.cmake +++ b/cmake/coverallsGcovJsons.cmake @@ -248,7 +248,7 @@ foreach(GCOV_FILE ${GCOV_FILES}) # Instead of trying to parse the source from the # gcov file, simply read the file contents from the source file. # (Parsing it from the gcov is hard because C-code uses ; in many places - # which also happens to be the same as the CMake list delimeter). + # which also happens to be the same as the CMake list delimiter). file(READ ${GCOV_SRC_PATH} GCOV_FILE_SOURCE) string(REPLACE "\\" "\\\\" GCOV_FILE_SOURCE "${GCOV_FILE_SOURCE}") diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake index 81a7228629d25..e0a2a7eb34739 100644 --- a/cmake/cuda.cmake +++ b/cmake/cuda.cmake @@ -294,7 +294,7 @@ select_nvcc_arch_flags(NVCC_FLAGS_EXTRA NVCC_ARCH_BIN) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}") message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}") -# Set C++14 support +# Set C++17 support set(CUDA_PROPAGATE_HOST_FLAGS OFF) # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # So, don't set these flags here. diff --git a/cmake/experiments/cuda_module_loading_lazy.cmake b/cmake/experiments/cuda_module_loading_lazy.cmake index 281560c48a0c7..75276379fd227 100644 --- a/cmake/experiments/cuda_module_loading_lazy.cmake +++ b/cmake/experiments/cuda_module_loading_lazy.cmake @@ -13,7 +13,7 @@ # limitations under the License. # this file contains experimental build options for lazy cuda module loading -# cuda moduel lazy loading is supported by CUDA 11.7+ +# cuda module lazy loading is supported by CUDA 11.7+ # this experiment option makes Paddle supports lazy loading before CUDA 11.7. 
if(LINUX) diff --git a/cmake/phi_header.cmake b/cmake/export_paddle_header.cmake similarity index 52% rename from cmake/phi_header.cmake rename to cmake/export_paddle_header.cmake index ac633b747bcef..726103fd679b4 100644 --- a/cmake/phi_header.cmake +++ b/cmake/export_paddle_header.cmake @@ -15,33 +15,57 @@ set(PADDLE_INFERENCE_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_install_dir") -function(phi_header_path_compat TARGET_PATH) - message(STATUS "phi header path compat processing: ${TARGET_PATH}") +function(header_path_compat TARGET_PATH) + message(STATUS "header path compat processing: ${TARGET_PATH}") file(GLOB HEADERS "${TARGET_PATH}/*" "*.h") foreach(header ${HEADERS}) if(${header} MATCHES ".*.h$") file(READ ${header} HEADER_CONTENT) string(REPLACE "paddle/fluid/platform/" "paddle/phi/" HEADER_CONTENT "${HEADER_CONTENT}") + string(REPLACE "paddle/pir/include/" "paddle/pir/" HEADER_CONTENT + "${HEADER_CONTENT}") + string(REPLACE "paddle/fluid/pir/drr/include/" "paddle/pir/drr/" + HEADER_CONTENT "${HEADER_CONTENT}") + string(REPLACE "paddle/fluid/pir/utils/" "paddle/pir/utils/" + HEADER_CONTENT "${HEADER_CONTENT}") file(WRITE ${header} "${HEADER_CONTENT}") - message(STATUS "phi header path compat processing complete: ${header}") + message(STATUS "header path compat processing complete: ${header}") endif() endforeach() endfunction() -phi_header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle) -phi_header_path_compat( - ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi) -phi_header_path_compat( +header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle) +header_path_compat(${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi) +header_path_compat( ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/api) -phi_header_path_compat( +header_path_compat( ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/api/ext) -phi_header_path_compat( +header_path_compat( ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/api/include) -phi_header_path_compat( +header_path_compat( ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/common) -phi_header_path_compat( +header_path_compat( ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core) +header_path_compat( + ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/core) +header_path_compat( + ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/core/parser) +header_path_compat( + ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/dialect/control_flow/ir +) +header_path_compat( + ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/dialect/shape/ir) +header_path_compat( + ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/dialect/shape/utils) +header_path_compat( + ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/drr) +header_path_compat( + ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/pass) +header_path_compat( + ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/pattern_rewrite) +header_path_compat( + ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/utils) # NOTE(liuyuanle): In inference lib, no need include paddle/utils/pybind.h, so we delete this. 
file(READ ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/extension.h diff --git a/cmake/external/cccl.cmake b/cmake/external/cccl.cmake index db09c01f92e74..18b9d010adde3 100755 --- a/cmake/external/cccl.cmake +++ b/cmake/external/cccl.cmake @@ -15,12 +15,18 @@ set(CCCL_INCLUDE_DIR ${CCCL_SOURCE_DIR}) message("CCCL_INCLUDE_DIR is ${CCCL_INCLUDE_DIR}") include_directories(${CCCL_INCLUDE_DIR}) +file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/cccl/util_device.cuh.patch + native_src) +set(CCCL_PATCH_COMMAND git checkout -- . && git checkout ${CCCL_TAG} && patch + -p1 -Nd ${CCCL_SOURCE_DIR} < ${native_src}) + ExternalProject_Add( extern_cccl ${EXTERNAL_PROJECT_LOG_ARGS} SOURCE_DIR ${CCCL_SOURCE_DIR} PREFIX ${CCCL_PREFIX_DIR} UPDATE_COMMAND "" + PATCH_COMMAND ${CCCL_PATCH_COMMAND} CONFIGURE_COMMAND "" BUILD_COMMAND "" INSTALL_COMMAND "" diff --git a/cmake/external/dirent.cmake b/cmake/external/dirent.cmake index 7bec37d5f1b7e..41d5de412c044 100644 --- a/cmake/external/dirent.cmake +++ b/cmake/external/dirent.cmake @@ -27,7 +27,9 @@ if((NOT DEFINED DIRENT_NAME) OR (NOT DEFINED DIRENT_URL)) set(DIRENT_URL "${GIT_URL}/tronkko/dirent/archive/refs/tags/1.23.2.tar.gz" CACHE STRING "" FORCE) - set(DIRENT_CACHE_FILENAME "1.23.2.tar.gz") + set(DIRENT_CACHE_FILENAME + "1.23.2.tar.gz" + CACHE STRING "" FORCE) endif() message(STATUS "DIRENT_NAME: ${DIRENT_NAME}, DIRENT_URL: ${DIRENT_URL}") diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 8638d4bdc84b5..f36a51d9c1cd3 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -25,7 +25,7 @@ if(WIN32) elseif(LINUX) if(WITH_ROCM) # For HIPCC Eigen::internal::device::numeric_limits is not EIGEN_DEVICE_FUNC - # which will cause compiler error of using __host__ funciont + # which will cause compiler error of using __host__ function # in __host__ __device__ file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Meta.h native_src) file(TO_NATIVE_PATH ${SOURCE_DIR}/Eigen/src/Core/util/Meta.h native_dst) @@ -39,7 +39,7 @@ elseif(LINUX) endif() endif() -if(CMAKE_COMPILER_IS_GNUCC) +if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/TensorRandom.h.patch tensor_random_header) # See: [Why calling some `git` commands before `patch`?] @@ -47,19 +47,11 @@ if(CMAKE_COMPILER_IS_GNUCC) git checkout -- . 
&& git checkout ${EIGEN_TAG} && patch -Nd ${SOURCE_DIR}/unsupported/Eigen/CXX11/src/Tensor < ${tensor_random_header}) - execute_process(COMMAND ${CMAKE_C_COMPILER} -dumpfullversion -dumpversion - OUTPUT_VARIABLE GCC_VERSION) - string(REGEX MATCHALL "[0-9]+" GCC_VERSION_COMPONENTS ${GCC_VERSION}) - list(GET GCC_VERSION_COMPONENTS 0 GCC_MAJOR) - list(GET GCC_VERSION_COMPONENTS 1 GCC_MINOR) - set(GCC_VERSION "${GCC_MAJOR}.${GCC_MINOR}") - if(GCC_VERSION GREATER_EQUAL 12.0) - file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Complex.h.patch - complex_header) - set(EIGEN_PATCH_COMMAND - ${EIGEN_PATCH_COMMAND} && patch -Nd - ${SOURCE_DIR}/Eigen/src/Core/arch/SSE/ < ${complex_header}) - endif() + file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Complex.h.patch + complex_header) + set(EIGEN_PATCH_COMMAND + ${EIGEN_PATCH_COMMAND} && patch -Nd + ${SOURCE_DIR}/Eigen/src/Core/arch/SSE/ < ${complex_header}) endif() set(EIGEN_INCLUDE_DIR ${SOURCE_DIR}) diff --git a/cmake/external/flashattn.cmake b/cmake/external/flashattn.cmake index c8461f57a575a..86364e0ed67d1 100644 --- a/cmake/external/flashattn.cmake +++ b/cmake/external/flashattn.cmake @@ -98,6 +98,7 @@ ExternalProject_Add( -DCMAKE_CXX_FLAGS=${FLASHATTN_CXX_FLAGS} -DCMAKE_CXX_FLAGS_RELEASE=${FLASHATTN_CXX_FLAGS_RELEASE} -DCMAKE_CXX_FLAGS_DEBUG=${FLASHATTN_CXX_FLAGS_DEBUG} + -DCMAKE_CUDA_COMPILER_LAUNCHER=${CMAKE_CUDA_COMPILER_LAUNCHER} -DCMAKE_INSTALL_PREFIX=${FLASHATTN_INSTALL_DIR} -DWITH_GPU=${WITH_GPU} -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} diff --git a/cmake/external/gloo.cmake b/cmake/external/gloo.cmake index 529f72b662e3e..dcaab7e2842eb 100755 --- a/cmake/external/gloo.cmake +++ b/cmake/external/gloo.cmake @@ -16,82 +16,57 @@ include(ExternalProject) set(GLOO_PROJECT "extern_gloo") set(GLOO_PREFIX_DIR ${THIRD_PARTY_PATH}/gloo) -set(GLOO_SOURCE_DIR ${THIRD_PARTY_PATH}/gloo/src/extern_gloo) +set(GLOO_SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/gloo) set(GLOO_INSTALL_DIR ${THIRD_PARTY_PATH}/install/gloo) set(GLOO_INCLUDE_DIR - "${GLOO_INSTALL_DIR}/include" + ${GLOO_INSTALL_DIR}/include CACHE PATH "gloo include directory." FORCE) set(GLOO_LIBRARY_DIR - "${GLOO_INSTALL_DIR}/lib" + ${GLOO_INSTALL_DIR}/lib CACHE PATH "gloo library directory." FORCE) + # As we add extra features for gloo, we use the non-official repo set(GLOO_TAG v0.0.3) set(GLOO_LIBRARIES - "${GLOO_INSTALL_DIR}/lib/libgloo.a" + ${GLOO_INSTALL_DIR}/lib/libgloo.a CACHE FILEPATH "gloo library." FORCE) -set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/gloo) -set(GLOO_PATCH_COMMAND "") -if(WITH_GPU) - if(${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0 AND ${CMAKE_CXX_COMPILER_VERSION} - VERSION_GREATER 12.0) - file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/gloo/device.cc.patch - native_dst) - set(GLOO_PATCH_COMMAND - git checkout -- . && git checkout ${GLOO_TAG} && patch -Nd - ${SOURCE_DIR}/gloo/transport/tcp < ${native_dst}) - endif() -endif() -if(CMAKE_COMPILER_IS_GNUCC) - execute_process(COMMAND ${CMAKE_C_COMPILER} -dumpfullversion -dumpversion - OUTPUT_VARIABLE GCC_VERSION) - string(REGEX MATCHALL "[0-9]+" GCC_VERSION_COMPONENTS ${GCC_VERSION}) - list(GET GCC_VERSION_COMPONENTS 0 GCC_MAJOR) - list(GET GCC_VERSION_COMPONENTS 1 GCC_MINOR) - set(GCC_VERSION "${GCC_MAJOR}.${GCC_MINOR}") - if(GCC_VERSION GREATER_EQUAL "12.0") - file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/gloo/device.cc.patch - native_dst) - file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/gloo/types.h.patch - types_header) - # See: [Why calling some `git` commands before `patch`?] 
- set(GLOO_PATCH_COMMAND - git checkout -- . && git checkout ${GLOO_TAG} && patch -Nd - ${SOURCE_DIR}/gloo/transport/tcp < ${native_dst} && patch -Nd - ${SOURCE_DIR}/gloo/ < ${types_header}) - endif() -endif() +# Setup gloo patch command +set(GLOO_PATCH_COMMAND git checkout -- . && git checkout ${GLOO_TAG}) +file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/gloo/device.cc.patch + native_dst) +file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/gloo/types.h.patch + types_header) file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/gloo/linux.cc.patch linux_cc_ethtool) -if(GLOO_PATCH_COMMAND STREQUAL "") - set(GLOO_PATCH_COMMAND git checkout -- . && git checkout ${GLOO_TAG} && patch - -Nd ${SOURCE_DIR}/gloo/common/ < ${linux_cc_ethtool}) -else() - set(GLOO_PATCH_COMMAND ${GLOO_PATCH_COMMAND} && patch -Nd - ${SOURCE_DIR}/gloo/common/ < ${linux_cc_ethtool}) -endif() -include_directories(${GLOO_INCLUDE_DIR}) +# cmake-format: off +list(APPEND GLOO_PATCH_COMMAND + && patch -Nd ${GLOO_SOURCE_DIR}/gloo/transport/tcp < ${native_dst} + && patch -Nd ${GLOO_SOURCE_DIR}/gloo/ < ${types_header} + && patch -Nd ${GLOO_SOURCE_DIR}/gloo/common/ < ${linux_cc_ethtool}) +# cmake-format: on + +set(GLOO_CMAKE_C_FLAGS "-O3 -fPIC") +set(GLOO_CMAKE_CXX_FLAGS "-O3 -fPIC") ExternalProject_Add( ${GLOO_PROJECT} ${EXTERNAL_PROJECT_LOG_ARGS} - SOURCE_DIR ${SOURCE_DIR} - PREFIX "${GLOO_PREFIX_DIR}" - UPDATE_COMMAND "" + SOURCE_DIR ${GLOO_SOURCE_DIR} + PREFIX ${GLOO_PREFIX_DIR} PATCH_COMMAND ${GLOO_PATCH_COMMAND} - CONFIGURE_COMMAND "" - BUILD_COMMAND - mkdir -p ${GLOO_SOURCE_DIR}/build && cd ${GLOO_SOURCE_DIR}/build && cmake - ${SOURCE_DIR} -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} && ${CMAKE_COMMAND} - --build . && mkdir -p ${GLOO_LIBRARY_DIR} ${GLOO_INCLUDE_DIR}/glo - INSTALL_COMMAND ${CMAKE_COMMAND} -E copy - ${GLOO_SOURCE_DIR}/build/gloo/libgloo.a ${GLOO_LIBRARY_DIR} - COMMAND ${CMAKE_COMMAND} -E copy_directory "${SOURCE_DIR}/gloo/" - "${GLOO_INCLUDE_DIR}/gloo" + CMAKE_ARGS -DCMAKE_BUILD_TYPE=Release + -DCMAKE_INSTALL_PREFIX=${GLOO_INSTALL_DIR} + -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} + -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} + -DCMAKE_C_FLAGS=${GLOO_CMAKE_C_FLAGS} + -DCMAKE_CXX_FLAGS=${GLOO_CMAKE_CXX_FLAGS} BUILD_BYPRODUCTS ${GLOO_LIBRARIES}) add_library(gloo STATIC IMPORTED GLOBAL) set_property(TARGET gloo PROPERTY IMPORTED_LOCATION ${GLOO_LIBRARIES}) add_dependencies(gloo ${GLOO_PROJECT}) + +include_directories(${GLOO_INCLUDE_DIR}) diff --git a/cmake/external/lapack.cmake b/cmake/external/lapack.cmake index 62da0987085d1..2865dabdaccce 100644 --- a/cmake/external/lapack.cmake +++ b/cmake/external/lapack.cmake @@ -48,19 +48,34 @@ elseif(WIN32) set(GFORTRAN_LIB "${LAPACK_LIB_DIR}/libgfortran-3.dll") set(BLAS_LIB "${LAPACK_LIB_DIR}/libblas.dll") set(LAPACK_LIB "${LAPACK_LIB_DIR}/liblapack.dll") -else() - set(LAPACK_FILE - "lapack_mac_v3.10.0.20210628.tar.gz" - CACHE STRING "" FORCE) - set(LAPACK_URL - "https://paddlepaddledeps.bj.bcebos.com/${LAPACK_FILE}" - CACHE STRING "" FORCE) - set(LAPACK_URL_MD5 427aecf8dee8523de3566ca8e47944d7) - set(GNU_RT_LIB_1 "${LAPACK_LIB_DIR}/libquadmath.0.dylib") - set(GNU_RT_LIB_2 "${LAPACK_LIB_DIR}/libgcc_s.1.dylib") - set(GFORTRAN_LIB "${LAPACK_LIB_DIR}/libgfortran.5.dylib") - set(BLAS_LIB "${LAPACK_LIB_DIR}/libblas.3.dylib") - set(LAPACK_LIB "${LAPACK_LIB_DIR}/liblapack.3.dylib") +else() # MacOS + if(APPLE AND WITH_ARM) + set(LAPACK_FILE + "lapack_mac_arm64_v0.3.26.tar.gz" + CACHE STRING "" FORCE) + set(LAPACK_URL + "https://paddlepaddledeps.bj.bcebos.com/${LAPACK_FILE}" + CACHE STRING "" 
FORCE) + set(LAPACK_URL_MD5 3f6412105ae2b7465e5ee90c8673e6d4) + set(GNU_RT_LIB_1 "${LAPACK_LIB_DIR}/libquadmath.0.dylib") + set(GNU_RT_LIB_2 "${LAPACK_LIB_DIR}/libgcc_s.1.dylib") + set(GFORTRAN_LIB "${LAPACK_LIB_DIR}/libgfortran.5.dylib") + set(BLAS_LIB "${LAPACK_LIB_DIR}/libblas.dylib") + set(LAPACK_LIB "${LAPACK_LIB_DIR}/liblapack.dylib") + else() + set(LAPACK_FILE + "lapack_mac_v3.10.0.20210628.tar.gz" + CACHE STRING "" FORCE) + set(LAPACK_URL + "https://paddlepaddledeps.bj.bcebos.com/${LAPACK_FILE}" + CACHE STRING "" FORCE) + set(LAPACK_URL_MD5 427aecf8dee8523de3566ca8e47944d7) + set(GNU_RT_LIB_1 "${LAPACK_LIB_DIR}/libquadmath.0.dylib") + set(GNU_RT_LIB_2 "${LAPACK_LIB_DIR}/libgcc_s.1.dylib") + set(GFORTRAN_LIB "${LAPACK_LIB_DIR}/libgfortran.5.dylib") + set(BLAS_LIB "${LAPACK_LIB_DIR}/libblas.3.dylib") + set(LAPACK_LIB "${LAPACK_LIB_DIR}/liblapack.3.dylib") + endif() endif() function(download_lapack) diff --git a/cmake/external/pslib.cmake b/cmake/external/pslib.cmake index d7de1aae86015..9800eab1e0992 100644 --- a/cmake/external/pslib.cmake +++ b/cmake/external/pslib.cmake @@ -69,7 +69,7 @@ ExternalProject_Add( -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${PSLIB_INSTALL_ROOT} -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE} - BUILD_BYPRODUCTS ${PSLIB_LIB}) + BUILD_BYPRODUCTS ${PSLIB_LIB} ${JVM_LIB}) add_library(pslib SHARED IMPORTED GLOBAL) set_property(TARGET pslib PROPERTY IMPORTED_LOCATION ${PSLIB_LIB}) diff --git a/cmake/external/python.cmake b/cmake/external/python.cmake index b8ab55f604186..488540b3af295 100644 --- a/cmake/external/python.cmake +++ b/cmake/external/python.cmake @@ -16,7 +16,7 @@ include(python_module) check_py_version(${PY_VERSION}) -# Find Python with mnimum PY_VERSION specified or will raise error! +# Find Python with minimum PY_VERSION specified or will raise error! 
find_package(PythonInterp ${PY_VERSION} REQUIRED) find_package(PythonLibs ${PY_VERSION} REQUIRED) diff --git a/cmake/external/rocksdb.cmake b/cmake/external/rocksdb.cmake index 5bf2a896c47d3..072658e54705a 100644 --- a/cmake/external/rocksdb.cmake +++ b/cmake/external/rocksdb.cmake @@ -39,7 +39,7 @@ set(ROCKSDB_FLAGS "-DNDEBUG -DROCKSDB_JEMALLOC -DJEMALLOC_NO_DEMANGLE -DROCKSDB_PLATFORM_POSIX -DROCKSDB_LIB_IO_POSIX -DOS_LINUX -DROCKSDB_FALLOCATE_PRESENT -DHAVE_PCLMUL -DZLIB -DROCKSDB_MALLOC_USABLE_SIZE -DROCKSDB_PTHREAD_ADAPTIVE_MUTEX -DROCKSDB_BACKTRACE -DROCKSDB_SUPPORT_THREAD_LOCAL -DROCKSDB_USE_RTTI -DROCKSDB_SCHED_GETCPU_PRESENT -DROCKSDB_RANGESYNC_PRESENT -DROCKSDB_AUXV_GETAUXVAL_PRESENT" ) set(ROCKSDB_CMAKE_CXX_FLAGS - "${ROCKSDB_COMMON_FLAGS} -DROCKSDB_LIBAIO_PRESENT ${ROCKSDB_FLAGS} -fPIC -I${JEMALLOC_INCLUDE_DIR} -Wl,--no-as-needed -lz -ldl" + "${ROCKSDB_COMMON_FLAGS} -DROCKSDB_LIBAIO_PRESENT ${ROCKSDB_FLAGS} -fPIC -I${JEMALLOC_INCLUDE_DIR}" ) if(NOT WITH_ARM) set(ROCKSDB_FLAGS "${ROCKSDB_FLAGS} -DHAVE_SSE42") @@ -47,12 +47,14 @@ if(NOT WITH_ARM) "${ROCKSDB_CMAKE_CXX_FLAGS} -msse -msse4.2 -mpclmul") endif() set(ROCKSDB_CMAKE_C_FLAGS - "${ROCKSDB_COMMON_FLAGS} ${ROCKSDB_FLAGS} -DROCKSDB_LIBAIO_PRESENT -fPIC -I${JEMALLOC_INCLUDE_DIR}" + "${ROCKSDB_COMMON_FLAGS} ${ROCKSDB_FLAGS} -DROCKSDB_LIBAIO_PRESENT -fPIC -I${JEMALLOC_INCLUDE_DIR}" ) include_directories(${ROCKSDB_INCLUDE_DIR}) -set(CMAKE_CXX_LINK_EXECUTABLE - "${CMAKE_CXX_LINK_EXECUTABLE} -pthread -Wl,--no-as-needed -ldl -lrt -lz") +set(CMAKE_CXX_LINK_EXECUTABLE "${CMAKE_CXX_LINK_EXECUTABLE} -pthread") + +set(ROCKSDB_CMAKE_SHARED_LINKER_FLAGS "-ldl -lrt -lz") + if(WITH_ARM) file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/rocksdb/libaio.h.patch native_src) @@ -75,10 +77,12 @@ ExternalProject_Add( -DWITH_TESTS=OFF -DWITH_JEMALLOC=ON -DWITH_BENCHMARK_TOOLS=OFF + -DFAIL_ON_WARNINGS=OFF # For Clang compatibility -DJeMalloc_LIBRARIES=${JEMALLOC_LIBRARIES} -DJeMalloc_INCLUDE_DIRS=${JEMALLOC_INCLUDE_DIR} -DCMAKE_CXX_FLAGS=${ROCKSDB_CMAKE_CXX_FLAGS} - -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS} + -DCMAKE_C_FLAGS=${ROCKSDB_CMAKE_C_FLAGS} + -DCMAKE_SHARED_LINKER_FLAGS=${ROCKSDB_CMAKE_SHARED_LINKER_FLAGS} INSTALL_COMMAND mkdir -p ${ROCKSDB_INSTALL_DIR}/lib/ && cp ${ROCKSDB_PREFIX_DIR}/src/extern_rocksdb-build/librocksdb.a diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index e39923d703da9..5b8dd6e0ffe59 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -29,7 +29,7 @@ if(NOT DEFINED XPU_BASE_DATE) set(XPU_BASE_DATE "20240104") endif() if(NOT DEFINED XPU_XHPC_BASE_DATE) - set(XPU_XHPC_BASE_DATE "20240226") + set(XPU_XHPC_BASE_DATE "20240328") endif() set(XPU_XCCL_BASE_VERSION "1.1.8.1") if(NOT DEFINED XPU_XFT_BASE_VERSION) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index ee60dd1485818..5a40695202525 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -4,7 +4,7 @@ include(CheckCCompilerFlag) include(CheckCXXSymbolExists) include(CheckTypeSize) -function(CheckCompilerCXX14Flag) +function(check_compiler_cxx14_flag) if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.4) message(FATAL_ERROR "Unsupported GCC version. 
GCC >= 5.4 required.") @@ -14,8 +14,7 @@ function(CheckCompilerCXX14Flag) "Found GCC ${CMAKE_CXX_COMPILER_VERSION} which is too high, recommended to use GCC 8.2" ) endif() - elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID - STREQUAL "Clang") + elseif(CMAKE_CXX_COMPILER_ID MATCHES "AppleClang|Clang") # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # Apple Clang is a different compiler than upstream Clang which has different version numbers. # https://gist.github.com/yamaya/2924292 @@ -33,7 +32,8 @@ function(CheckCompilerCXX14Flag) endif() endfunction() -checkcompilercxx14flag() +check_compiler_cxx14_flag() + if(NOT WIN32) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") else() @@ -158,6 +158,27 @@ if(NOT WIN32) -Wimplicit-fallthrough=0 # Warning in tinyformat.h ${fsanitize}) + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0) + set(COMMON_FLAGS ${COMMON_FLAGS} -Wno-error=deprecated-copy) + endif() + endif() + + if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(COMMON_FLAGS + ${COMMON_FLAGS} + -Wno-error=unknown-warning-option # For some unknown warning options in lower version clang + -Wno-error=unused-private-field + -Wno-error=unused-const-variable + -Wno-error=deprecated-copy-with-user-provided-copy # For three/five/zeros rule, clang + -Wno-error=deprecated-copy # Same above + -Wno-error=inconsistent-missing-override # For lots of warnings when not using override for virtual functions, clang + -Wno-error=bitwise-instead-of-logical # Warning in "unsupported/Eigen/CXX11/Tensor" + -Wno-error=overloaded-virtual # For some inconsistent virtual function signature, clang + -Wno-error=defaulted-function-deleted # header file from GLOO, clang + ) + endif() + if(WITH_IPU) set(COMMON_FLAGS ${COMMON_FLAGS} -Wno-sign-compare # Warnings in Popart -Wno-non-virtual-dtor # Warnings in Popart diff --git a/cmake/generic.cmake b/cmake/generic.cmake index c18e25fa84a64..d618c9667de83 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -613,7 +613,7 @@ function(paddle_test_build TARGET_NAME) if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) target_link_libraries(${TARGET_NAME} ${PYTHON_LIBRARIES}) endif() - if(WITH_CINN AND NOT CINN_ONLY) + if(WITH_CINN) target_link_libraries(${TARGET_NAME} $ cinn_transforms) add_dependencies(${TARGET_NAME} cinnapi) diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index f4a8286985094..3005da8aea125 100755 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -354,12 +354,54 @@ copy( SRCS ${PADDLE_SOURCE_DIR}/paddle/extension.h DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/) -# the include path of phi needs to be changed to adapt to inference api path +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/pir/include/core/parser/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/core/parser/) +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/pir/include/core/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/core/) +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/pir/include/dialect/control_flow/ir/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/dialect/control_flow/ir/ +) +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/pir/include/dialect/shape/ir/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/dialect/shape/ir/ +) +copy( + inference_lib_dist + SRCS 
${PADDLE_SOURCE_DIR}/paddle/pir/include/dialect/shape/utils/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/dialect/shape/utils/ +) +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/pir/include/pass/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/pass/) +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/pir/include/pattern_rewrite/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/pattern_rewrite/ +) +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/drr/include/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/drr/) +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/utils/general_functions.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/pir/utils/) + +# the include path of paddle needs to be changed to adapt to inference api path add_custom_command( TARGET inference_lib_dist POST_BUILD - COMMAND ${CMAKE_COMMAND} -P "${PADDLE_SOURCE_DIR}/cmake/phi_header.cmake" - COMMENT "Change phi header include path to adapt to inference api path") + COMMAND ${CMAKE_COMMAND} -P + "${PADDLE_SOURCE_DIR}/cmake/export_paddle_header.cmake" + COMMENT "Change paddle header include path to adapt to inference api path") # CAPI inference library for only inference set(PADDLE_INFERENCE_C_INSTALL_DIR diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 7b1987f1c3cf2..1713a2ea71626 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -102,42 +102,42 @@ function(register_cu_kernel TARGET) endforeach() endfunction() -# Just for those mkldnn kernels locating at "fluid/operators/mkldnn/", such as 'layer_norm_mkldnn_op.cc'. +# Just for those onednn kernels locating at "fluid/operators/onednn/", such as 'layer_norm_onednn_op.cc'. # Add other file modes if need in the future. 
-function(register_mkldnn_kernel TARGET) +function(register_onednn_kernel TARGET) set(options "") set(oneValueArgs "") set(multiValueArgs SRCS DEPS) - cmake_parse_arguments(register_mkldnn_kernel "${options}" "${oneValueArgs}" + cmake_parse_arguments(register_onednn_kernel "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - set(mkldnn_cc_srcs) + set(onednn_cc_srcs) set(op_common_deps operator op_registry phi layer common_infer_shape_functions) - foreach(mkldnn_src ${register_mkldnn_kernel_SRCS}) - if(${mkldnn_src} MATCHES ".*_mkldnn_op.cc$") - list(APPEND mkldnn_cc_srcs mkldnn/${mkldnn_src}) + foreach(onednn_src ${register_onednn_kernel_SRCS}) + if(${onednn_src} MATCHES ".*_onednn_op.cc$") + list(APPEND onednn_cc_srcs onednn/${onednn_src}) endif() endforeach() - list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) - if(${mkldnn_cc_srcs_len} EQUAL 0) + list(LENGTH onednn_cc_srcs onednn_cc_srcs_len) + if(${onednn_cc_srcs_len} EQUAL 0) message( FATAL_ERROR - "The MKLDNN kernel file of ${TARGET} should contains at least one *.*_mkldnn_op.cc file" + "The MKLDNN kernel file of ${TARGET} should contains at least one *.*_onednn_op.cc file" ) endif() if(WITH_MKLDNN) cc_library( ${TARGET} - SRCS ${mkldnn_cc_srcs} + SRCS ${onednn_cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) endif() set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs") - foreach(mkldnn_src ${mkldnn_cc_srcs}) + foreach(onednn_src ${onednn_cc_srcs}) set(op_name "") - find_register(${mkldnn_src} "REGISTER_OP_KERNEL" op_name) + find_register(${onednn_src} "REGISTER_OP_KERNEL" op_name) if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, MKLDNN);\n") endif() @@ -161,7 +161,7 @@ function(op_library TARGET) set(miopen_cu_srcs) set(CUDNN_FILE) set(MIOPEN_FILE) - set(mkldnn_cc_srcs) + set(onednn_cc_srcs) set(MKLDNN_FILE) set(op_common_deps operator op_registry phi layer common_infer_shape_functions) @@ -238,9 +238,9 @@ function(op_library TARGET) endif() endif() if(WITH_MKLDNN) - string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}") - if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/mkldnn/${MKLDNN_FILE}.cc) - list(APPEND mkldnn_cc_srcs mkldnn/${MKLDNN_FILE}.cc) + string(REPLACE "_op" "_onednn_op" MKLDNN_FILE "${TARGET}") + if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/onednn/${MKLDNN_FILE}.cc) + list(APPEND onednn_cc_srcs onednn/${MKLDNN_FILE}.cc) endif() endif() if(WITH_XPU) @@ -275,8 +275,8 @@ function(op_library TARGET) list(APPEND cudnn_cu_cc_srcs ${src}) elseif(WITH_GPU AND ${src} MATCHES ".*\\.cu.cc$") list(APPEND cu_cc_srcs ${src}) - elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$") - list(APPEND mkldnn_cc_srcs ${src}) + elseif(WITH_MKLDNN AND ${src} MATCHES ".*_onednn_op.cc$") + list(APPEND onednn_cc_srcs ${src}) elseif(WITH_XPU AND ${src} MATCHES ".*_op_xpu.cc$") list(APPEND xpu_cc_srcs ${src}) elseif(WITH_XPU_KP AND ${src} MATCHES ".*\\.xpu$") @@ -349,7 +349,7 @@ function(op_library TARGET) if(WITH_UNITY_BUILD AND op_library_UNITY) # Combine the cc and cu source files. 
compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${cu_cc_srcs} - ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs}) + ${cudnn_cu_cc_srcs} ${onednn_cc_srcs}) compose_unity_target_sources(${UNITY_TARGET} cu ${cudnn_cu_srcs} ${cu_srcs}) if(TARGET ${UNITY_TARGET}) @@ -369,7 +369,7 @@ function(op_library TARGET) nv_library( ${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${cudnn_cu_srcs} - ${mkldnn_cc_srcs} ${cu_srcs} + ${onednn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) endif() elseif(WITH_ROCM) @@ -389,19 +389,19 @@ function(op_library TARGET) hip_library( ${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} - ${mkldnn_cc_srcs} ${hip_srcs} + ${onednn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) elseif(WITH_XPU_KP AND ${xpu_kp_cc_srcs_len} GREATER 0) xpu_library( ${TARGET} - SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${xpu_kp_cc_srcs} + SRCS ${cc_srcs} ${onednn_cc_srcs} ${xpu_cc_srcs} ${xpu_kp_cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) else() # Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`. if(WITH_UNITY_BUILD AND op_library_UNITY) # Combine the cc source files. compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} - ${mkldnn_cc_srcs} ${xpu_cc_srcs}) + ${onednn_cc_srcs} ${xpu_cc_srcs}) if(TARGET ${UNITY_TARGET}) # If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`. target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources}) @@ -417,7 +417,7 @@ function(op_library TARGET) else() cc_library( ${TARGET} - SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} + SRCS ${cc_srcs} ${onednn_cc_srcs} ${xpu_cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps}) endif() endif() @@ -426,7 +426,7 @@ function(op_library TARGET) list(LENGTH hip_srcs hip_srcs_len) list(LENGTH cu_cc_srcs cu_cc_srcs_len) list(LENGTH hip_cc_srcs hip_cc_srcs_len) - list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len) + list(LENGTH onednn_cc_srcs onednn_cc_srcs_len) list(LENGTH xpu_cc_srcs xpu_cc_srcs_len) list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len) @@ -463,7 +463,7 @@ function(op_library TARGET) find_register(${cc_src} "REGISTER_OPERATOR" op_name) if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n") - # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn + # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in onednn set(TARGET ${op_name}) set(pybind_flag 1) endif() @@ -474,7 +474,7 @@ function(op_library TARGET) find_register(${cc_src} "REGISTER_ACTIVATION_OP" op_name) if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n") - # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn + # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in onednn set(TARGET ${op_name}) set(pybind_flag 1) endif() @@ -483,7 +483,7 @@ function(op_library TARGET) find_register(${cc_src} "REGISTER_OP_WITHOUT_GRADIENT" op_name) if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_ITSELF(${op_name});\n") - # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in mkldnn + # hack: for example, the target in conv_transpose_op.cc is conv2d_transpose, used in onednn set(TARGET ${op_name}) set(pybind_flag 1) endif() @@ -494,10 +494,10 @@ function(op_library TARGET) if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CPU);\n") # why change TARGET here? 
- # when building padle with on_infer, the REGISTER_OPERATOR(*_grad) will be removed before compiling (see details in remove_grad_op_and_kernel.py) + # when building paddle with on_infer, the REGISTER_OPERATOR(*_grad) will be removed before compiling (see details in remove_grad_op_and_kernel.py) # in elementwise_op.cc, it will find REGISTER_OPERATOR(grad_add) and set TARGET to grad_add - # and, in the following "mkldnn" part, it will add USE_OP_DEVICE_KERNEL(grad_add, MKLDNN) to pybind.h - # however, grad_add has no mkldnn kernel. + # and, in the following "onednn" part, it will add USE_OP_DEVICE_KERNEL(grad_add, MKLDNN) to pybind.h + # however, grad_add has no onednn kernel. set(TARGET ${op_name}) set(pybind_flag 1) endif() @@ -520,16 +520,16 @@ function(op_library TARGET) endif() endforeach() - # pybind USE_OP_DEVICE_KERNEL for operators/mkldnn/* - list(APPEND mkldnn_srcs ${mkldnn_cc_srcs}) - foreach(mkldnn_src ${mkldnn_srcs}) + # pybind USE_OP_DEVICE_KERNEL for operators/onednn/* + list(APPEND onednn_srcs ${onednn_cc_srcs}) + foreach(onednn_src ${onednn_srcs}) set(op_name "") # Add PHI Kernel Registry Message - find_phi_register(${mkldnn_src} ${pybind_file} "PD_REGISTER_KERNEL") - find_phi_register(${mkldnn_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL") - find_phi_register(${mkldnn_src} ${pybind_file} + find_phi_register(${onednn_src} ${pybind_file} "PD_REGISTER_KERNEL") + find_phi_register(${onednn_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL") + find_phi_register(${onednn_src} ${pybind_file} "PD_REGISTER_KERNEL_FOR_ALL_DTYPE") - find_register(${mkldnn_src} "REGISTER_OP_CUDA_KERNEL" op_name) + find_register(${onednn_src} "REGISTER_OP_CUDA_KERNEL" op_name) if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, CUDA);\n") set(pybind_flag 1) @@ -610,14 +610,14 @@ function(op_library TARGET) endif() # pybind USE_OP_DEVICE_KERNEL for MKLDNN - if(WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0) + if(WITH_MKLDNN AND ${onednn_cc_srcs_len} GREATER 0) # Append first implemented MKLDNN activation operator - if(${MKLDNN_FILE} STREQUAL "activation_mkldnn_op") + if(${MKLDNN_FILE} STREQUAL "activation_onednn_op") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(softplus, MKLDNN);\n") else() - foreach(mkldnn_src ${mkldnn_cc_srcs}) + foreach(onednn_src ${onednn_cc_srcs}) set(op_name "") - find_register(${mkldnn_src} "REGISTER_OP_KERNEL" op_name) + find_register(${onednn_src} "REGISTER_OP_KERNEL" op_name) if(NOT ${op_name} EQUAL "") file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${op_name}, MKLDNN);\n") @@ -666,7 +666,7 @@ function(register_operators) GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc") - string(REPLACE "_mkldnn" "" OPS "${OPS}") + string(REPLACE "_onednn" "" OPS "${OPS}") string(REPLACE "_xpu" "" OPS "${OPS}") string(REPLACE ".cc" "" OPS "${OPS}") list(REMOVE_DUPLICATES OPS) diff --git a/cmake/simd.cmake b/cmake/simd.cmake index 3d730657062a0..676a25118303c 100644 --- a/cmake/simd.cmake +++ b/cmake/simd.cmake @@ -1,12 +1,10 @@ # This file is use to check all support level of AVX on your machine -# so that PaddlePaddle can unleash the vectorization power of muticore. +# so that PaddlePaddle can unleash the vectorization power of multicore. 
include(CheckCXXSourceRuns) include(CheckCXXSourceCompiles) -if(CMAKE_COMPILER_IS_GNUCC - OR CMAKE_COMPILER_IS_GNUCXX - OR CMAKE_CXX_COMPILER_ID MATCHES "Clang") +if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") set(MMX_FLAG "-mmmx") set(SSE2_FLAG "-msse2") set(SSE3_FLAG "-msse3") diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 2d8020adcf7d0..9839f32f83c2b 100755 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -15,6 +15,11 @@ include(ExternalProject) # Create a target named "third_party", which can compile external dependencies on all platform(windows/linux/mac) +# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24 +if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.24.0") + cmake_policy(SET CMP0135 NEW) +endif() + set(THIRD_PARTY_PATH "${CMAKE_BINARY_DIR}/third_party" CACHE STRING @@ -315,22 +320,6 @@ if(WITH_CINN) include(cmake/cinn/external/jitify.cmake) endif() -# cinn_only includes third-party libraries separately -if(CINN_ONLY) - include(external/gtest) - include(external/protobuf) - if(WITH_PYTHON) - include(external/pybind11) - endif() - if(WITH_MKL) - include(external/mklml) - endif() - if(WITH_MKLDNN) - include(external/mkldnn) - endif() - return() -endif() - include(external/eigen) # download eigen3 include(external/threadpool) # download threadpool include(external/dlpack) # download dlpack diff --git a/cmake/version.cmake b/cmake/version.cmake index e6707665a3851..28f022e0afa0e 100644 --- a/cmake/version.cmake +++ b/cmake/version.cmake @@ -1,5 +1,17 @@ # Get the latest git tag. set(PADDLE_VERSION $ENV{PADDLE_VERSION}) +if(WITH_NIGHTLY_BUILD) + execute_process( + COMMAND ${GIT_EXECUTABLE} show -s --format=%ci HEAD + WORKING_DIRECTORY ${CMAKE_SOURCE_DIR} + OUTPUT_VARIABLE GIT_COMMIT_TIME + OUTPUT_STRIP_TRAILING_WHITESPACE) + string(REGEX REPLACE " (.*)$" "" DATE_ONLY "${GIT_COMMIT_TIME}") + string(REPLACE "-" "" DATE_ONLY "${DATE_ONLY}") + # Print the last commit date + message(STATUS "Last commit date: ${DATE_ONLY}") + set(PADDLE_VERSION "${PADDLE_VERSION}.dev${DATE_ONLY}") +endif() set(tmp_version "HEAD") set(TAG_VERSION_REGEX "[0-9]+\\.[0-9]+\\.[0-9]+(\\.(a|b|rc)\\.[0-9]+)?") set(COMMIT_VERSION_REGEX "[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+[0-9a-f]+") @@ -65,6 +77,7 @@ string(REPLACE "." ";" PADDLE_VER_LIST ${PADDLE_VER_LIST}) list(GET PADDLE_VER_LIST 0 PADDLE_MAJOR_VER) list(GET PADDLE_VER_LIST 1 PADDLE_MINOR_VER) list(GET PADDLE_VER_LIST 2 PADDLE_PATCH_VER) + math(EXPR PADDLE_VERSION_INTEGER "${PADDLE_MAJOR_VER} * 1000000 + ${PADDLE_MINOR_VER} * 1000 + ${PADDLE_PATCH_VER}") diff --git a/paddle/cinn/README.md b/paddle/cinn/README.md index 204feab7f2798..3d3517ccf7745 100644 --- a/paddle/cinn/README.md +++ b/paddle/cinn/README.md @@ -51,13 +51,7 @@ cd build Build paddle with cinn: ``` -cmake .. -DCINN_ONLY=OFF -DWITH_CINN=ON -DWITH_GPU=ON -``` - -Build cinn only: - -``` -cmake .. -DCINN_ONLY=ON -DWITH_CINN=ON -DWITH_GPU=ON +cmake .. 
-DWITH_CINN=ON -DWITH_GPU=ON ``` And then diff --git a/paddle/cinn/adt/CMakeLists.txt b/paddle/cinn/adt/CMakeLists.txt index 682e3931176b2..acbbb0f9a965f 100644 --- a/paddle/cinn/adt/CMakeLists.txt +++ b/paddle/cinn/adt/CMakeLists.txt @@ -1,44 +1,41 @@ -if(NOT CINN_ONLY) - add_subdirectory(print_utils) +add_subdirectory(print_utils) - core_gather_headers() +core_gather_headers() - gather_srcs( - cinnapi_src - SRCS - adapter_tensor.cc - anchor_sd_equation_context.cc - equation_function.cc - equation_solver.cc - equation_value.cc - generate_map_expr.cc - get_sub_reshape_dim_ranges.cc - igroup.cc - index_expr_infer_context.cc - kgroup.cc - m_ir.cc - naive_bidirection_equation_generator.cc - naive_op_equation_context.cc - partition_op_stmts.cc - schedule_descriptor.cc - schedule_dim.cc - schedule_mesh.cc - dim_expr.cc - simplify_value.cc - write_broadcast_disabled_bidirection_equation_generator.cc) +gather_srcs( + cinnapi_src + SRCS + adapter_tensor.cc + anchor_sd_equation_context.cc + equation_function.cc + equation_solver.cc + equation_value.cc + generate_map_expr.cc + get_sub_reshape_dim_ranges.cc + igroup.cc + index_expr_infer_context.cc + kgroup.cc + m_ir.cc + naive_bidirection_equation_generator.cc + naive_op_equation_context.cc + partition_op_stmts.cc + schedule_descriptor.cc + schedule_dim.cc + schedule_mesh.cc + dim_expr.cc + simplify_value.cc + write_broadcast_disabled_bidirection_equation_generator.cc) - cinn_cc_test(equation_value_match_trait_test SRCS - equation_value_match_trait_test.cc DEPS gtest glog) +cinn_cc_test(equation_value_match_trait_test SRCS + equation_value_match_trait_test.cc DEPS gtest glog) - cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog) +cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog) - cinn_cc_test( - inline_translator_test - SRCS - inline_translator_test.cc - DEPS - gtest - glog - absl) - -endif() +cinn_cc_test( + inline_translator_test + SRCS + inline_translator_test.cc + DEPS + gtest + glog + absl) diff --git a/paddle/cinn/adt/adapter_dynamic_tensor.h b/paddle/cinn/adt/adapter_dynamic_tensor.h index d3610f654f218..fdecc71cfb71a 100644 --- a/paddle/cinn/adt/adapter_dynamic_tensor.h +++ b/paddle/cinn/adt/adapter_dynamic_tensor.h @@ -18,13 +18,13 @@ #include "paddle/cinn/adt/adt.h" #include "paddle/cinn/adt/dim_expr.h" #include "paddle/cinn/adt/symbolic_dim.h" -#include "paddle/cinn/hlir/framework/pir/group.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h" namespace cinn::adt::adapter { struct DynamicTensor final { ::pir::Value node_data; - const hlir::framework::pir::Group* group; + const hlir::framework::pir::OpLoweringGroup* group; bool operator==(const DynamicTensor& other) const { return this->node_data == other.node_data; diff --git a/paddle/cinn/adt/adt.h b/paddle/cinn/adt/adt.h index 5af2a25cdd597..2ab5837d24a04 100644 --- a/paddle/cinn/adt/adt.h +++ b/paddle/cinn/adt/adt.h @@ -283,7 +283,7 @@ struct Ok final { bool operator!=(const Ok&) const { return false; } }; -#define ADT_TODO() LOG(FATAL) << "TODO" +#define ADT_TODO() PADDLE_THROW(phi::errors::Fatal("TODO")) inline std::size_t hash_combine(std::size_t lhs, std::size_t rhs) { return lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); diff --git a/paddle/cinn/adt/equation_solver.cc b/paddle/cinn/adt/equation_solver.cc index 90675fb3db161..b0eff3dc8355c 100644 --- a/paddle/cinn/adt/equation_solver.cc +++ b/paddle/cinn/adt/equation_solver.cc @@ -273,7 +273,8 @@ void CheckEquationsSolvable( [&](const auto& opt_old_value, const auto& simplified_value) { LOG(ERROR) << 
"old_value: " << ToTxtString(opt_old_value); LOG(ERROR) << "simplified_value: " << ToTxtString(simplified_value); - LOG(FATAL) << "CheckEquationsSolvable Failed"; + PADDLE_THROW( + phi::errors::InvalidArgument("CheckEquationsSolvable Failed")); return tValueInferSuccess{false}; }); }; diff --git a/paddle/cinn/adt/generate_map_expr.cc b/paddle/cinn/adt/generate_map_expr.cc index 339d68a3cbe59..ab5ffc28c17fe 100644 --- a/paddle/cinn/adt/generate_map_expr.cc +++ b/paddle/cinn/adt/generate_map_expr.cc @@ -109,8 +109,9 @@ bool HasDynamicShape(const ::pir::Value& tensor) { return false; } -List MakeOpStmtInputList(const ::pir::Operation* op, - const hlir::framework::pir::Group* group) { +List MakeOpStmtInputList( + const ::pir::Operation* op, + const hlir::framework::pir::OpLoweringGroup* group) { List ret{}; VisitEachInputTensor(op, [&](const ::pir::Value& tensor) { @@ -131,8 +132,9 @@ void VisitEachOutputTensor(const ::pir::Operation* op, const DoEachT& DoEach) { } } -List MakeOpStmtOutputList(const ::pir::Operation* op, - const hlir::framework::pir::Group* group) { +List MakeOpStmtOutputList( + const ::pir::Operation* op, + const hlir::framework::pir::OpLoweringGroup* group) { List ret{}; VisitEachOutputTensor(op, [&](const ::pir::Value& tensor) { @@ -147,9 +149,10 @@ List MakeOpStmtOutputList(const ::pir::Operation* op, } template -void VisitEachOpStmt(const std::shared_ptr& group, - const DoEachT& DoEach) { - for (const auto* op : group->CollectOps()) { +void VisitEachOpStmt( + const std::shared_ptr& group, + const DoEachT& DoEach) { + for (const auto* op : group->ops()) { DoEach(OpStmt{MakeOp(op), MakeOpStmtInputList(op, group.get()), MakeOpStmtOutputList(op, group.get())}); @@ -187,7 +190,7 @@ void CollectRewrittenOpStmts(const OpStmt& op_stmt, List* ret) { } List MakeOpStmts( - const std::shared_ptr& group) { + const std::shared_ptr& group) { List ret{}; VisitEachOpStmt(group, [&](const auto& op_stmt) { @@ -223,7 +226,7 @@ std::shared_ptr MakeIGroup(const AnchorGroup& igroup_spec) { } std::vector> GenerateIGroups( - const std::shared_ptr& group) { + const std::shared_ptr& group) { std::vector> ret{}; List op_stmts = MakeOpStmts(group); @@ -237,7 +240,7 @@ std::vector> GenerateIGroups( } std::shared_ptr GenerateKGroups( - const std::shared_ptr& group, + const std::shared_ptr& group, const std::vector>& igroups) { CHECK_EQ(igroups.size(), 1); return std::make_shared(group, igroups); @@ -352,7 +355,7 @@ Tensor GetAnchorTensor(const std::shared_ptr& igroup) { } template -void VisitInputTensor(const hlir::framework::pir::Group& group, +void VisitInputTensor(const hlir::framework::pir::OpLoweringGroup& group, const DoEachT& DoEach) { for (const ::pir::Value& node_data : group.GetInputOpValues()) { DoEach(node_data); @@ -360,7 +363,7 @@ void VisitInputTensor(const hlir::framework::pir::Group& group, } template -void VisitOutputTensor(const hlir::framework::pir::Group& group, +void VisitOutputTensor(const hlir::framework::pir::OpLoweringGroup& group, const DoEachT& DoEach) { for (const ::pir::Value& node_data : group.GetOutputOpValues()) { DoEach(node_data); @@ -444,7 +447,7 @@ MapExpr GenerateMapExpr(const std::shared_ptr& kgroup) { } // namespace MapExpr GenerateMapExpr( - const std::shared_ptr& group) { + const std::shared_ptr& group) { const auto& igroups = GenerateIGroups(group); const auto& kgroup = GenerateKGroups(group, igroups); @@ -453,13 +456,14 @@ MapExpr GenerateMapExpr( } void TryGenerateMapExprFromGroup( - const std::shared_ptr& fusion_group) { + const std::shared_ptr& + 
fusion_group) { if (!FLAGS_cinn_enable_map_expr) { return; } const auto& map_expr = GenerateMapExpr(fusion_group); VLOG(4) << "Generate MapExpr: \n" - << ToTxtString(map_expr, fusion_group->group_id); + << ToTxtString(map_expr, fusion_group->group_id()); fusion_group->set_map_expr_ctx(std::make_shared(map_expr)); } diff --git a/paddle/cinn/adt/generate_map_expr.h b/paddle/cinn/adt/generate_map_expr.h index 00dabaffbf899..a71fc031ae542 100644 --- a/paddle/cinn/adt/generate_map_expr.h +++ b/paddle/cinn/adt/generate_map_expr.h @@ -20,17 +20,16 @@ namespace cinn::hlir::framework::pir { -struct Group; -using GroupList = std::vector>; +struct OpLoweringGroup; } // namespace cinn::hlir::framework::pir namespace cinn::adt { MapExpr GenerateMapExpr( - const std::shared_ptr& group); + const std::shared_ptr& group); void TryGenerateMapExprFromGroup( - const std::shared_ptr& fusion_group); + const std::shared_ptr& fusion_group); } // namespace cinn::adt diff --git a/paddle/cinn/adt/get_sub_reshape_dim_ranges.cc b/paddle/cinn/adt/get_sub_reshape_dim_ranges.cc index f7f84a6e15e3a..8dc63e319e690 100644 --- a/paddle/cinn/adt/get_sub_reshape_dim_ranges.cc +++ b/paddle/cinn/adt/get_sub_reshape_dim_ranges.cc @@ -82,7 +82,7 @@ GetSubReshapeDimRanges(const List& lhs_dims, } else if (LhsAcc() > RhsAcc()) { rhs_end++; } else { - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } } CHECK(lhs_end == lhs_dims->size() && rhs_end == rhs_dims->size()); diff --git a/paddle/cinn/adt/igroup.cc b/paddle/cinn/adt/igroup.cc index 333721815d348..328d194c11ba2 100644 --- a/paddle/cinn/adt/igroup.cc +++ b/paddle/cinn/adt/igroup.cc @@ -102,10 +102,10 @@ List IGroup::GetIndexIterators(const Index& index) const { } else if (arg_pos.Has()) { // do nothing } else { - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } } - LOG(FATAL) << "Can not find anchor iterators"; + PADDLE_THROW(phi::errors::Fatal("Can not find anchor iterators")); } } // namespace cinn::adt diff --git a/paddle/cinn/adt/kgroup.h b/paddle/cinn/adt/kgroup.h index 0c536ddb1c654..e69f1dedd5b05 100644 --- a/paddle/cinn/adt/kgroup.h +++ b/paddle/cinn/adt/kgroup.h @@ -21,7 +21,7 @@ namespace cinn::hlir::framework::pir { -struct Group; +struct OpLoweringGroup; } // namespace cinn::hlir::framework::pir @@ -39,11 +39,11 @@ using cinn::adt::LoopDescriptors; class KGroup final { public: explicit KGroup( - const std::shared_ptr& cinn_group, + const std::shared_ptr& cinn_group, const std::vector>& igroups) : cinn_group_(cinn_group), igroups_(igroups) {} - std::shared_ptr cinn_group() const { + std::shared_ptr cinn_group() const { return CHECK_NOTNULL(cinn_group_.lock()); } @@ -58,7 +58,7 @@ class KGroup final { const std::shared_ptr& igroup) const; private: - std::weak_ptr cinn_group_; + std::weak_ptr cinn_group_; // NOTE: Use single igroup temporarily. 
Actually KGroup contains // multiple IGroups std::vector> igroups_; diff --git a/paddle/cinn/adt/m_ir.cc b/paddle/cinn/adt/m_ir.cc index 003b6880c813a..5e4ffabd71548 100644 --- a/paddle/cinn/adt/m_ir.cc +++ b/paddle/cinn/adt/m_ir.cc @@ -38,12 +38,12 @@ void CollectTensorIndexIterators(const TensorIndexExpr& tensor_index_expr, void CollectTensorIndexIteratorsImpl(const Undefined& tensor_index_expr, std::unordered_set* ret) { - LOG(FATAL) << "Not Implemented"; + PADDLE_THROW(phi::errors::Unimplemented("Not Implemented")); } void CollectTensorIndexIteratorsImpl(const Ok& ok, std::unordered_set* ret) { - LOG(FATAL) << "Not Implemented"; + PADDLE_THROW(phi::errors::Unimplemented("Not Implemented")); } void CollectTensorIndexIteratorsImpl(const Iterator& iterator, @@ -134,7 +134,7 @@ LoopIterators GetAnchorTensorLoopIterators( namespace { Tensor GetTensorImpl(const OpStmt& op_stmt, const Undefined& undefined) { - LOG(FATAL) << "position not found"; + PADDLE_THROW(phi::errors::Fatal("position not found")); } Tensor GetTensorImpl(const OpStmt& op_stmt, const tIn& pos) { diff --git a/paddle/cinn/adt/naive_op_equation_context.cc b/paddle/cinn/adt/naive_op_equation_context.cc index a65ba537a68bc..bc1dc11c7c3f9 100644 --- a/paddle/cinn/adt/naive_op_equation_context.cc +++ b/paddle/cinn/adt/naive_op_equation_context.cc @@ -240,7 +240,7 @@ std::optional GetArgDimSizeImpl( const Undefined&, const GetArgStaticDimT& GetInDim, const GetArgStaticDimT& GetOutDim) { - LOG(FATAL) << "position not found"; + PADDLE_THROW(phi::errors::Fatal("position not found")); } std::optional GetArgDimSize(const OpArgDimPos& arg_dim_pos, diff --git a/paddle/cinn/adt/print_utils/CMakeLists.txt b/paddle/cinn/adt/print_utils/CMakeLists.txt index 4f121de131477..0359ba721490a 100644 --- a/paddle/cinn/adt/print_utils/CMakeLists.txt +++ b/paddle/cinn/adt/print_utils/CMakeLists.txt @@ -1,15 +1,12 @@ -if(NOT CINN_ONLY) - core_gather_headers() +core_gather_headers() - gather_srcs( - cinnapi_src - SRCS - print_dim_expr.cc - print_equations.cc - print_map_expr.cc - print_schedule_descriptor.cc - print_schedule_dim.cc - print_schedule_mesh.cc - print_value.cc) - -endif() +gather_srcs( + cinnapi_src + SRCS + print_dim_expr.cc + print_equations.cc + print_map_expr.cc + print_schedule_descriptor.cc + print_schedule_dim.cc + print_schedule_mesh.cc + print_value.cc) diff --git a/paddle/cinn/adt/print_utils/print_map_expr.cc b/paddle/cinn/adt/print_utils/print_map_expr.cc index 5d57bd457aaa4..1548771f13962 100644 --- a/paddle/cinn/adt/print_utils/print_map_expr.cc +++ b/paddle/cinn/adt/print_utils/print_map_expr.cc @@ -71,7 +71,7 @@ std::string ToTxtStringImpl(const adapter::DynamicTensor& tensor) { } std::string ToTxtStringImpl(const TempStorage& tensor) { - LOG(FATAL) << "Not supported yet"; + PADDLE_THROW(phi::errors::Unimplemented("Not supported yet")); } } // namespace diff --git a/paddle/cinn/adt/schedule_dim.cc b/paddle/cinn/adt/schedule_dim.cc index 4205bebef1aeb..6cc9ee0e66fff 100644 --- a/paddle/cinn/adt/schedule_dim.cc +++ b/paddle/cinn/adt/schedule_dim.cc @@ -188,7 +188,7 @@ List GetReduceAxis(const List& loop_sizes) { } else if (sched_dim.Has>()) { // do nothing } else { - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } } return reduce_axis; @@ -203,7 +203,7 @@ List GetInjectiveAxis(const List& loop_sizes) { } else if (sched_dim.Has>()) { injective_axis->emplace_back(i); } else { - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } } return injective_axis; diff --git 
a/paddle/cinn/adt/schedule_mesh.cc b/paddle/cinn/adt/schedule_mesh.cc index 29665b918ed08..6fe319e09e992 100644 --- a/paddle/cinn/adt/schedule_mesh.cc +++ b/paddle/cinn/adt/schedule_mesh.cc @@ -370,7 +370,8 @@ std::tuple> CreateOptimizedScheduleMesh( return policy->Optimize(loop_sizes); } } - LOG(FATAL) << "Dead code, no valid schedule mesh policy found"; + PADDLE_THROW( + phi::errors::Fatal("Dead code, no valid schedule mesh policy found")); } ScheduleMesh MeshReshape(const ScheduleMesh& sched_mesh, diff --git a/paddle/cinn/adt/simplify_value.cc b/paddle/cinn/adt/simplify_value.cc index ccd42e891525e..07420e7e64743 100644 --- a/paddle/cinn/adt/simplify_value.cc +++ b/paddle/cinn/adt/simplify_value.cc @@ -21,7 +21,7 @@ #include "paddle/cinn/adt/index_expr_infer_context.h" #include "paddle/cinn/adt/match.h" #include "paddle/cinn/adt/simplify_value.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" namespace cinn::adt { @@ -67,7 +67,7 @@ struct SimplifyRedundantBroadcastedIterator { const auto& simplified_bd = DimExpr{symbol::SimplifyDimExpr(bd)}; return BroadcastedIterator{inner_iterator, simplified_bd}; } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } }; @@ -368,7 +368,7 @@ struct SymbolicDim_SimplifyDotUndot { return IndexDotValue>{ SimplifyValue(list_get_item_values, ctx), dot_dims}; } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } }; @@ -415,7 +415,7 @@ struct SymbolicDim_SimplifyDotUndot_DimExpr { return IndexDotValue>{ SimplifyValue(list_get_item_values, ctx), dot_dims}; } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } }; diff --git a/paddle/cinn/adt/tree.h b/paddle/cinn/adt/tree.h index 9dfc4d66d31c4..0e93e45672053 100644 --- a/paddle/cinn/adt/tree.h +++ b/paddle/cinn/adt/tree.h @@ -15,9 +15,9 @@ #pragma once #include - #include "paddle/cinn/adt/adt.h" #include "paddle/cinn/adt/tags.h" +#include "paddle/common/enforce.h" namespace cinn::adt { @@ -144,7 +144,7 @@ List MergeTwoInnerTreeImpl( List{new_lhs, new_rhs}); return List{ret}; } else { - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } } diff --git a/paddle/cinn/api/op_topo_pattern.h b/paddle/cinn/api/op_topo_pattern.h new file mode 100644 index 0000000000000..34f17fbfde9e0 --- /dev/null +++ b/paddle/cinn/api/op_topo_pattern.h @@ -0,0 +1,77 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace cinn::api { + +template +struct ErrorPattern {}; + +// ElementWise/Broadcast/Injective Ops without reduction ancestors. +template +struct InjectiveSourcePattern {}; + +// Reduce op +template +struct SingleReductionOpPattern {}; + +// ElementWise/Broadcast ops which have shardable dimentions and reduction +// ancestors. 
+template +struct PartialShardablePattern {}; + +// Reduce base pattern +template +struct ReductionPattern { + using Nothing = std::monostate; + std::variant, PartialShardablePattern> + input; + SingleReductionOpPattern reduce_op_pattern; + + bool HasFusedInput() const { + return !std::holds_alternative(this->input); + } +}; + +// Stmt := IS | R | PS +// ops in StmtPattern will be lowered into a inlined cuda code. +template +using StmtPattern = std::variant, + ReductionPattern, + PartialShardablePattern>; + +// Stmts := [Stmt] +template +using StmtPatternVec = std::vector>; +// fuse rules: +// 1. IS * IS -> IS +// 2. PS * PS -> PS +// 3. IS * PS -> PS +// 4. IS * R -> R +// 5. PS * R -> R +// lifting rules: +// 1. R -> Stmts +// 2. PS -> Stmts +// 3. Stmts * Stmts -> Stmts +// OpTopoPattern := Error | Stmts + +template +using OpTopoPattern = std::variant, StmtPatternVec>; + +} // namespace cinn::api diff --git a/paddle/cinn/ast_gen_ius/ast_gen.cc b/paddle/cinn/ast_gen_ius/ast_gen.cc index 009158d3f9cce..45923624945d0 100644 --- a/paddle/cinn/ast_gen_ius/ast_gen.cc +++ b/paddle/cinn/ast_gen_ius/ast_gen.cc @@ -22,6 +22,7 @@ #include "paddle/cinn/optim/replace_var_with_expr.h" PD_DECLARE_bool(cinn_new_group_scheduler); +PD_DECLARE_bool(group_schedule_tiling_first); PD_DECLARE_bool(cinn_bucket_compile); namespace cinn { @@ -93,9 +94,14 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { std::vector iter_values; // reduce body and reduce init schedule block should have different objects // for same axis so we re-create objects + VLOG(4) << "FLAGS_group_schedule_tiling_first = " + << FLAGS_group_schedule_tiling_first; std::vector axis_vars = cinn::common::GenDefaultAxis(axis_len); + const std::vector& reduce_axis = tensor->reduce_axis; + VLOG(4) << "ast gen: tensor init_body is " << init_body; for (int i = 0; i < shape.size(); ++i) { - if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && + FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { optim::ReplaceVarWithExpr(&init_body, axis[i], Expr(0)); continue; } @@ -105,21 +111,25 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { /*is_reduce = */ false)); optim::ReplaceVarWithExpr(&init_body, axis[i], block_vars.back()); axis_vars[i]->is_reduce_axis = false; - if (shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) { iter_values.push_back(Expr(0)); } else { iter_values.push_back(axis_vars[i]); } } + VLOG(4) << "iter_value.size() and block_vars.size() is " + << iter_values.size() << " " << block_vars.size(); init_body = ir::ScheduleBlockRealize::Make( iter_values, ir::ScheduleBlock::Make( block_vars, {}, {}, reduce_init_name, init_body)); // For the remaining reduce axis, make reduce body - const std::vector& reduce_axis = tensor->reduce_axis; ir::Expr reduce_body = ConvertReduceBody(tensor->body(), tensor, axis_exprs); + + VLOG(4) << "ast gen: reduce body is " << reduce_body; + // create schedule block itervars, i0,i1... 
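For readers new to the OpTopoPattern header introduced above: the IS/R/PS kinds are ordinary variant alternatives, so each fuse rule (IS * IS -> IS, IS * R -> R, and so on) ends up as a visitor over two StmtPattern values. Below is a minimal sketch of that dispatch shape only; the types FakeOp, IS, PS, Rd and the Fuse helper are hypothetical stand-ins, not the patterns defined in this patch.

    #include <type_traits>
    #include <variant>
    #include <vector>

    namespace demo {

    struct FakeOp {};  // hypothetical stand-in for the op type parameter T

    template <typename T> struct IS { std::vector<T> ops; };  // InjectiveSource
    template <typename T> struct PS { std::vector<T> ops; };  // PartialShardable
    template <typename T> struct Rd { std::vector<T> ops; };  // Reduction
    template <typename T> using Stmt = std::variant<IS<T>, Rd<T>, PS<T>>;

    // One fuse step over the listed rules:
    //   IS*IS -> IS, PS*PS -> PS, IS*PS -> PS, IS*R -> R, PS*R -> R.
    // This toy simply keeps the "strongest" kind of the two operands and
    // concatenates their ops; the real pass additionally restricts which
    // operand combinations are legal.
    template <typename T>
    Stmt<T> Fuse(const Stmt<T>& upstream, const Stmt<T>& downstream) {
      return std::visit(
          [](const auto& lhs, const auto& rhs) -> Stmt<T> {
            using L = std::decay_t<decltype(lhs)>;
            using R = std::decay_t<decltype(rhs)>;
            std::vector<T> ops = lhs.ops;
            ops.insert(ops.end(), rhs.ops.begin(), rhs.ops.end());
            if constexpr (std::is_same_v<L, Rd<T>> || std::is_same_v<R, Rd<T>>) {
              return Rd<T>{ops};  // anything fused into a reduction stays a reduction
            } else if constexpr (std::is_same_v<L, PS<T>> ||
                                 std::is_same_v<R, PS<T>>) {
              return PS<T>{ops};  // IS fused into PS becomes PS
            } else {
              return IS<T>{ops};  // IS * IS -> IS
            }
          },
          upstream, downstream);
    }

    }  // namespace demo

Fusing an IS value into an Rd value this way yields another Rd, mirroring rule 4 in the comment block; lifting a fused Stmt back into a StmtPatternVec is then just pushing it onto the vector.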
std::vector reduce_block_vars; std::vector reduce_iter_values; @@ -127,7 +137,8 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // for same axis so we re-create objects std::vector reduce_axis_vars = cinn::common::GenDefaultAxis(axis_len); for (int i = 0; i < shape.size(); ++i) { - if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && + FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { optim::ReplaceVarWithExpr(&reduce_body, axis[i], Expr(0)); continue; } @@ -136,12 +147,13 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { cinn::UniqName("i" + std::to_string(i)), /*is_reduce = */ false)); reduce_axis_vars[i]->is_reduce_axis = false; - if (shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) { reduce_iter_values.push_back(Expr(0)); } else { reduce_iter_values.push_back(axis_vars[i]); } } + VLOG(4) << "ast gen: reduce body is after replace 0" << reduce_body; for (int i = 0; i < reduce_axis.size(); ++i) { int count = shape.size() + i; reduce_block_vars.push_back( @@ -155,14 +167,40 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { } int non_zero_axis_size = 0; - for (int i = 0; i < axis.size(); ++i) { - if (FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { - continue; + if (FLAGS_group_schedule_tiling_first) { + std::vector non_reduce_axis_vars = [&]() { + std::vector res; + for (int i = 0; i < shape.size(); ++i) { + res.push_back(axis[i]); + } + return res; + }(); + for (int i = 0; i < non_reduce_axis_vars.size(); ++i) { + optim::ReplaceVarWithExpr( + &reduce_body, non_reduce_axis_vars[i], reduce_block_vars[i]); + ++non_zero_axis_size; } - optim::ReplaceVarWithExpr( - &reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]); - ++non_zero_axis_size; + } else { + for (int i = 0; i < axis.size(); ++i) { + if (!FLAGS_group_schedule_tiling_first && + FLAGS_cinn_new_group_scheduler && shape[i] == Expr(1)) { + continue; + } + optim::ReplaceVarWithExpr( + &reduce_body, axis[i], reduce_block_vars[non_zero_axis_size]); + ++non_zero_axis_size; + } + } + + VLOG(4) << "to replace : " << non_zero_axis_size << " " + << reduce_block_vars.size(); + for (auto i = 0; i < reduce_block_vars.size(); i++) { + VLOG(4) << "reduce_block_vars[" << i << "] = " << reduce_block_vars[i]; + } + for (auto i = 0; i < reduce_axis.size(); i++) { + VLOG(4) << "reduce_axis[" << i << "] = " << reduce_axis[i]; } + VLOG(4) << "before replace body: " << reduce_body; for (int i = non_zero_axis_size; i < reduce_block_vars.size(); ++i) { optim::ReplaceVarWithExpr(&reduce_body, reduce_axis[i - non_zero_axis_size], @@ -185,7 +223,8 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { // Put the two parts together ir::Expr body = ir::Block::Make({init_body, reduce_body}); for (int i = static_cast(axis_len) - 1; i >= 0; --i) { - if (!FLAGS_cinn_bucket_compile && shape[i] == Expr(1)) { + if ((!FLAGS_group_schedule_tiling_first || !FLAGS_cinn_bucket_compile) && + shape[i] == Expr(1)) { continue; } ir::Var loop_var = axis[i]; @@ -210,7 +249,7 @@ ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) { Expr(0), shape[i], cinn::UniqName("i" + std::to_string(i)), false)); optim::ReplaceVarWithExpr(&body, axis[i], block_vars[i]); axis_vars[i]->is_reduce_axis = false; - if (shape[i] == Expr(1)) { + if (!FLAGS_group_schedule_tiling_first && shape[i] == Expr(1)) { iter_values.push_back(Expr(0)); } else { 
iter_values.push_back(axis_vars[i]); diff --git a/paddle/cinn/auto_schedule/database/database.cc b/paddle/cinn/auto_schedule/database/database.cc index 24d071a7df4e1..2036b44a83fef 100644 --- a/paddle/cinn/auto_schedule/database/database.cc +++ b/paddle/cinn/auto_schedule/database/database.cc @@ -54,7 +54,7 @@ std::unique_ptr Database::Make(const DatabaseConfig& config) { config.capacity_per_task, config.record_file_path, true); } - LOG(FATAL) << "Unimplemented database type."; + PADDLE_THROW(phi::errors::Unimplemented("Unimplemented database type.")); return nullptr; } diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h index 90963e831075c..15422b1803e31 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/reduction_factoring.h @@ -36,7 +36,8 @@ class ReductionFactoring : public AutoGenRule { } // In the future, we will no longer use this interface. void Apply(int index) override { - LOG(FATAL) << "This is a deprecated interface, please do not use it."; + PADDLE_THROW(phi::errors::InvalidArgument( + "This is a deprecated interface, please do not use it.")); return; } diff --git a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc index 67d4c4ae3a0f7..994027dba0ee4 100644 --- a/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc +++ b/paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc @@ -145,7 +145,7 @@ void MemoryCopy(const float* src, float* dst, int numel, std::string type) { dst[i] = src[i]; } } else { - LOG(FATAL) << "Unknown memory copy type"; + PADDLE_THROW(phi::errors::InvalidArgument("Unknown memory copy type")); } } diff --git a/paddle/cinn/auto_schedule/search_space/block_sampler.cc b/paddle/cinn/auto_schedule/search_space/block_sampler.cc index 26b00d3a89fb3..93de31e6a5e36 100644 --- a/paddle/cinn/auto_schedule/search_space/block_sampler.cc +++ b/paddle/cinn/auto_schedule/search_space/block_sampler.cc @@ -40,7 +40,9 @@ std::unique_ptr BlockSampler::Make( all_blocks, default_remove_policy, rand_seed, weights); } - LOG(FATAL) << "Unimplemented strategy:" << strategy; + std::stringstream ss; + ss << "Unimplemented strategy:" << strategy; + PADDLE_THROW(phi::errors::Unimplemented(ss.str())); return nullptr; } diff --git a/paddle/cinn/auto_schedule/search_space/rule_sampler.cc b/paddle/cinn/auto_schedule/search_space/rule_sampler.cc index 500ae91deb89b..3c0868d0748e5 100644 --- a/paddle/cinn/auto_schedule/search_space/rule_sampler.cc +++ b/paddle/cinn/auto_schedule/search_space/rule_sampler.cc @@ -35,7 +35,9 @@ std::unique_ptr RuleSampler::Make( potential_rules, default_remove_policy, rand_seed, weights); } - LOG(FATAL) << "Unimplemented strategy:" << strategy; + std::stringstream ss; + ss << "Unimplemented strategy:" << strategy; + PADDLE_THROW(phi::errors::Unimplemented(ss.str())); return nullptr; } diff --git a/paddle/cinn/auto_schedule/search_space/search_space.cc b/paddle/cinn/auto_schedule/search_space/search_space.cc index eb672a78a6521..650e1d572f831 100644 --- a/paddle/cinn/auto_schedule/search_space/search_space.cc +++ b/paddle/cinn/auto_schedule/search_space/search_space.cc @@ -261,7 +261,8 @@ std::vector SearchSpace::GenerateSketches( } else if (strategy == "random_prune") { sketches = InitSketchWithRandomPrunedStrategy(); } else { - LOG(FATAL) << 
"Unimplemented init sketch strategy"; + PADDLE_THROW( + phi::errors::Unimplemented("Unimplemented init sketch strategy")); } // the more rules are applied, the greater the possibility of good results, diff --git a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc index 9d41301df614c..94fedc9f021e0 100644 --- a/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc +++ b/paddle/cinn/auto_schedule/search_strategy/mutate_rule/mutate_rule.cc @@ -23,7 +23,9 @@ std::unique_ptr MutateRule::Make(const std::string& name) { if (name == "mutate_tile_size") { return std::make_unique(); } else { - LOG(FATAL) << "MutateRule " << name << " is not supported."; + std::stringstream ss; + ss << "MutateRule " << name << " is not supported."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return nullptr; } diff --git a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc index eed2ad3d66970..a8961e45b980d 100644 --- a/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc +++ b/paddle/cinn/auto_schedule/task_scheduler/task_scheduler.cc @@ -34,7 +34,9 @@ std::unique_ptr TaskScheduler::Make( return std::make_unique(tasks, config); } - LOG(FATAL) << "Unimplemented strategy:" << strategy; + std::stringstream ss; + ss << "Unimplemented strategy:" << strategy; + PADDLE_THROW(phi::errors::Unimplemented(ss.str())); return nullptr; } diff --git a/paddle/cinn/backends/CMakeLists.txt b/paddle/cinn/backends/CMakeLists.txt index 3242ef2577b48..c746886a43d9b 100755 --- a/paddle/cinn/backends/CMakeLists.txt +++ b/paddle/cinn/backends/CMakeLists.txt @@ -59,14 +59,10 @@ if(WITH_CUDA) cinn_nv_test(test_codegen_debug SRCS codegen_debug_test.cc DEPS cinncore) if(WITH_TESTING) - if(CINN_ONLY) - cinn_nv_test(generated1_cuda SRCS generated1.cu DEPS cinncore) - else() - nv_test( - generated1_cuda - SRCS generated1.cu - DEPS cinncore) - endif() + nv_test( + generated1_cuda + SRCS generated1.cu + DEPS cinncore) add_run_test_dependency(generated1_cuda test_codegen_cuda_generate) endif() diff --git a/paddle/cinn/backends/codegen_c.cc b/paddle/cinn/backends/codegen_c.cc index ca80bcdddd0c0..85443b02c0a8c 100644 --- a/paddle/cinn/backends/codegen_c.cc +++ b/paddle/cinn/backends/codegen_c.cc @@ -76,7 +76,7 @@ std::string CodeGenC::Compile(const ir::Module &module, Compile(func); } } else { - LOG(FATAL) << "Not supported OutputKind"; + PADDLE_THROW(phi::errors::Unimplemented("Not supported OutputKind")); } return str_; } @@ -434,30 +434,37 @@ void CodeGenC::Visit(const ir::_Module_ *op) { CINN_NOT_IMPLEMENTED } void CodeGenC::Visit(const ir::_Var_ *op) { str_ += op->name; } void CodeGenC::Visit(const ir::Load *op) { - Expr dense_strided_ramp = detail::StridedRampBase(op->index(), 1); + ir::Expr offset = [&] { + if (load_to_offset_.count(op) == 0) { + load_to_offset_[op] = op->index(); + } + return load_to_offset_.at(op); + }(); + + Expr dense_strided_ramp = detail::StridedRampBase(offset, 1); if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address. 
CHECK(op->type().is_vector()); - PrintStackVecType(op->type().ElementOf(), op->index().type().lanes()); + PrintStackVecType(op->type().ElementOf(), offset.type().lanes()); str_ += "::"; str_ += "Load("; str_ += op->tensor.As()->name; str_ += ","; IrPrinter::Visit(dense_strided_ramp); str_ += ")"; - } else if (op->index().type().is_vector()) { + } else if (offset.type().is_vector()) { // gather CHECK(op->type().is_vector()); - PrintStackVecType(op->type().ElementOf(), op->index().type().lanes()); + PrintStackVecType(op->type().ElementOf(), offset.type().lanes()); str_ += "::Load("; str_ += op->tensor.As()->name; str_ += ","; - IrPrinter::Visit(op->index()); + IrPrinter::Visit(offset); str_ += ")"; } else if (op->is_addr_tensor()) { auto *tensor = op->tensor.As(); str_ += tensor->name; str_ += "["; - IrPrinter::Visit(op->index()); + IrPrinter::Visit(offset); str_ += "]"; } else { IrPrinter::Visit(op); @@ -466,12 +473,17 @@ void CodeGenC::Visit(const ir::Load *op) { void CodeGenC::Visit(const ir::Store *op) { CHECK(op->is_addr_tensor()); - + ir::Expr offset = [&] { + if (store_to_offset_.count(op) == 0) { + store_to_offset_[op] = op->index(); + } + return store_to_offset_.at(op); + }(); auto *tensor = op->tensor.As(); CHECK(tensor); str_ += tensor->name; str_ += "["; - IrPrinter::Visit(op->index()); + IrPrinter::Visit(offset); str_ += "]"; str_ += " = "; IrPrinter::Visit(op->value); @@ -526,8 +538,9 @@ void CodeGenC::Visit(const ir::Let *op) { } void CodeGenC::Visit(const ir::Reduce *op) { - LOG(FATAL) << "Reduce IR is just for internal representation, should not be " - "used for CodeGen."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Reduce IR is just for internal representation, should not be " + "used for CodeGen.")); } void CodeGenC::Visit(const ir::Ramp *op) { @@ -731,7 +744,8 @@ void CodeGenC::PrintRuntimeType(const cinn_type_t &type) { } else if (type == cinn_float64_t()) { str_ += "cinn_float64_t()"; } else { - LOG(FATAL) << "Unknown type is not supported to print"; + PADDLE_THROW( + phi::errors::InvalidArgument("Unknown type is not supported to print")); } } @@ -806,7 +820,9 @@ void CodeGenC::Visit(const ir::intrinsics::PodValueToX *op) { } else if (to_type == type_of()) { str_ += runtime::intrinsic::pod_value_to_buffer_p; } else { - LOG(FATAL) << "Not supported type: " << to_type; + std::stringstream ss; + ss << "Not supported type: " << to_type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } str_ += "("; diff --git a/paddle/cinn/backends/codegen_c.h b/paddle/cinn/backends/codegen_c.h index c50c85741ce56..2904bef80beea 100644 --- a/paddle/cinn/backends/codegen_c.h +++ b/paddle/cinn/backends/codegen_c.h @@ -118,6 +118,8 @@ class CodeGenC : public ir::IrPrinter { Target target_; std::stringstream ss_; bool inline_builtin_codes_{true}; + std::unordered_map store_to_offset_; + std::unordered_map load_to_offset_; }; namespace detail { diff --git a/paddle/cinn/backends/codegen_c_test.cc b/paddle/cinn/backends/codegen_c_test.cc index 91f80c190f0f8..61adad6ade461 100644 --- a/paddle/cinn/backends/codegen_c_test.cc +++ b/paddle/cinn/backends/codegen_c_test.cc @@ -61,9 +61,9 @@ TEST(CodeGenC, module) { LOG(INFO) << "C.body: " << C->get_compute_op()->body.front(); Target target; - target.arch = Target::Arch ::X86; - target.bits = Target::Bit ::k32; - target.os = Target::OS ::Linux; + target.arch = Target::Arch::X86; + target.bits = Target::Bit::k32; + target.os = Target::OS::Linux; Module::Builder builder("module1", target); ast_gen_ius::TensorGroup tensor_group({A, B, C}); diff 
--git a/paddle/cinn/backends/codegen_cuda_dev.cc b/paddle/cinn/backends/codegen_cuda_dev.cc index eb70ebe8fff8e..9c19c6faffb73 100644 --- a/paddle/cinn/backends/codegen_cuda_dev.cc +++ b/paddle/cinn/backends/codegen_cuda_dev.cc @@ -21,10 +21,12 @@ #include #include +#include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/ir_util.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_verify.h" #include "paddle/cinn/optim/ir_simplify.h" +#include "paddle/common/errors.h" namespace cinn { namespace backends { @@ -124,6 +126,7 @@ std::vector FilterDeallocTempBuffers(const std::vector &frees) { bool has_symbolic_constant = false; const ir::_Buffer_ *buffer = op->destination.As(); for (Expr shape : buffer->shape) { + shape = common::AutoSimplify(shape); ir::ir_utils::CollectIRNodes(shape, [&](const Expr *x) { if (x->as_var()) { CHECK(x->as_var()->is_symbolic_constant) @@ -290,7 +293,7 @@ std::string CodeGenCUDA_Dev::Compile(const ir::Module &module, Compile(func); } } else { - LOG(FATAL) << "Not supported OutputKind"; + PADDLE_THROW(phi::errors::InvalidArgument("Not supported OutputKind")); } if (for_nvrtc_) { @@ -370,8 +373,10 @@ void CodeGenCUDA_Dev::PrintTempBufferCreation(const ir::Buffer &buffer) { print_gpu_memory(""); } } else { - LOG(FATAL) << "CUDA device codegen not support memory " << buffer->name - << ", type " << buffer->memory_type; + std::stringstream ss; + ss << "CUDA device codegen not support memory " << buffer->name << ", type " + << buffer->memory_type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -505,5 +510,36 @@ void CodeGenCUDA_Dev::Visit(const ir::Store *op) { } } +ir::Expr CalculateSharedMemory(const ir::Buffer &buffer) { + Expr buffer_size(1); + for (int i = 0; i < buffer->shape.size(); i++) { + buffer_size = buffer_size * buffer->shape[i]; + } + int type_bytes = buffer->dtype.bytes(); + return buffer_size * Expr(type_bytes); +} + +ir::Expr CalculateSharedMemory(const ir::Expr &func_expr) { + auto func = func_expr.as_lowered_func(); + PADDLE_ENFORCE_NOT_NULL( + func, ::common::errors::InvalidType("expr is not a lowered_func")); + auto alloc_temp_buffers = func->PrepareAllocTempBufferExprs(); + ir::Expr shm_size{0}; + for (const auto &alloc : alloc_temp_buffers) { + PADDLE_ENFORCE_NOT_NULL( + alloc.As(), + ::common::errors::InvalidType("expr is not a Alloc node")); + PADDLE_ENFORCE_NOT_NULL( + alloc.As()->destination.as_buffer(), + ::common::errors::InvalidType("expr is not a Buffer node")); + + auto buffer = alloc.As()->destination.as_buffer_ref(); + if (buffer->memory_type == ir::MemoryType::GPUShared) { + shm_size = shm_size + CalculateSharedMemory(buffer); + } + } + return common::AutoSimplify(shm_size); +} + } // namespace backends } // namespace cinn diff --git a/paddle/cinn/backends/codegen_cuda_dev.h b/paddle/cinn/backends/codegen_cuda_dev.h index d1ebfd930f92f..d0995fccc0e06 100644 --- a/paddle/cinn/backends/codegen_cuda_dev.h +++ b/paddle/cinn/backends/codegen_cuda_dev.h @@ -127,5 +127,7 @@ class CodeGenCUDA_Dev : public CodeGenC { std::vector dynamic_alloc_buffers_; }; +ir::Expr CalculateSharedMemory(const ir::Expr& func_expr); + } // namespace backends } // namespace cinn diff --git a/paddle/cinn/backends/codegen_cuda_util.cc b/paddle/cinn/backends/codegen_cuda_util.cc index 6adc049e9d349..729dcca7be745 100644 --- a/paddle/cinn/backends/codegen_cuda_util.cc +++ b/paddle/cinn/backends/codegen_cuda_util.cc @@ -78,6 +78,7 @@ detail::CollectBucketStrategyHostFunctionVisitor::GenDeviceKernelName( void 
detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc( ir::Expr func, ir::Expr predicate) { + VLOG(4) << "Process Lowered Func" << func; ir::_LoweredFunc_ *func_node = func.as_lowered_func(); CHECK(func_node); if (!func_node->cuda_axis_info.valid()) { @@ -90,12 +91,7 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc( ir::Var kernel_ptr(GenDeviceKernelName(func_node->name, predicate), type_of()); - // shared_mem_bytes Can be calculated after codegen_cuda_dev buffer creation - // however, this make CodeGenCUDA_Dev before spliting the host and device - // module Maybe we could reorder the process. - CodeGenCUDA_Dev codegen_dev(cinn::common::DefaultNVGPUTarget()); - codegen_dev.Compile(ir::LoweredFunc(func.as_lowered_func_ref())); - Expr shared_mem_bytes = codegen_dev.GetDynSharedMemOffset(); + Expr shared_mem_bytes = CalculateSharedMemory(func); VLOG(6) << "Add a call node for func_node->name " << func_node->name << "\n" << "grid_dim: (" << func_node->cuda_axis_info.grid_dim(0) << ", " diff --git a/paddle/cinn/backends/cuda_util.h b/paddle/cinn/backends/cuda_util.h index 5175ba8e819c6..26d4110b0a10c 100644 --- a/paddle/cinn/backends/cuda_util.h +++ b/paddle/cinn/backends/cuda_util.h @@ -26,63 +26,76 @@ #include #include "paddle/cinn/runtime/cinn_runtime.h" - -#define CUDA_DRIVER_CALL(func) \ - { \ - auto status = func; \ - if (status != CUDA_SUCCESS) { \ - const char* msg; \ - cuGetErrorString(status, &msg); \ - LOG(FATAL) << "CUDA Driver Error: " #func " failed with error: " << msg; \ - } \ +#include "paddle/common/enforce.h" + +#define CUDA_DRIVER_CALL(func) \ + { \ + auto status = func; \ + if (status != CUDA_SUCCESS) { \ + const char* msg; \ + cuGetErrorString(status, &msg); \ + std::stringstream ss; \ + ss << "CUDA Driver Error: " #func " failed with error: " << msg; \ + PADDLE_THROW(phi::errors::Fatal(ss.str())); \ + } \ } -#define CUDA_CALL(func) \ - { \ - auto status = func; \ - if (status != cudaSuccess) { \ - LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \ - } \ +#define CUDA_CALL(func) \ + { \ + auto status = func; \ + if (status != cudaSuccess) { \ + std::stringstream ss; \ + ss << "CUDA Error : " << cudaGetErrorString(status); \ + PADDLE_THROW(phi::errors::Fatal(ss.str())); \ + } \ } -#define CURAND_CALL(func) \ - { \ - auto status = func; \ - if (status != CURAND_STATUS_SUCCESS) { \ - LOG(FATAL) << "CURAND Error : " << status; \ - } \ +#define CURAND_CALL(func) \ + { \ + auto status = func; \ + if (status != CURAND_STATUS_SUCCESS) { \ + std::stringstream ss; \ + ss << "CURAND Error : " << status; \ + PADDLE_THROW(phi::errors::Fatal(ss.str())); \ + } \ } #define CUSOLVER_CALL(func) \ { \ auto status = func; \ if (status != CUSOLVER_STATUS_SUCCESS) { \ - LOG(FATAL) << "CUSOLVER Error: " << status; \ + std::stringstream ss; \ + ss << "CUSOLVER Error: " << status; \ + PADDLE_THROW(phi::errors::Fatal(ss.str())); \ } \ } -#define CUBLAS_CALL(func) \ - { \ - auto status = func; \ - if (status != CUBLAS_STATUS_SUCCESS) { \ - LOG(FATAL) << "CUBLAS Error!"; \ - } \ +#define CUBLAS_CALL(func) \ + { \ + auto status = func; \ + if (status != CUBLAS_STATUS_SUCCESS) { \ + PADDLE_THROW(phi::errors::Fatal("CUBLAS Error!")); \ + } \ } -#define CUDNN_CALL(func) \ - { \ - auto status = func; \ - if (status != CUDNN_STATUS_SUCCESS) { \ - LOG(FATAL) << "CUDNN Error : " << cudnnGetErrorString(status); \ - } \ +#define CUDNN_CALL(func) \ + { \ + auto status = func; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + std::stringstream ss; \ + ss << 
"CUDNN Error : " << cudnnGetErrorString(status); \ + PADDLE_THROW(phi::errors::Fatal(ss.str())); \ + } \ } -#define NVRTC_CALL(func) \ - { \ - auto status = func; \ - if (status != NVRTC_SUCCESS) { \ - LOG(FATAL) << "NVRTC Error : " << nvrtcGetErrorString(status); \ - } \ +#define NVRTC_CALL(func) \ + { \ + auto status = func; \ + if (status != NVRTC_SUCCESS) { \ + std::stringstream ss; \ + ss << "NVRTC Error : " << nvrtcGetErrorString(status); \ + PADDLE_THROW(phi::errors::Fatal(ss.str())); \ + } \ } namespace cinn { diff --git a/paddle/cinn/backends/ir_schedule_test.cc b/paddle/cinn/backends/ir_schedule_test.cc index e3196e90bfe65..29eae201bbb78 100644 --- a/paddle/cinn/backends/ir_schedule_test.cc +++ b/paddle/cinn/backends/ir_schedule_test.cc @@ -84,7 +84,7 @@ void test_split_and_fuse1(void* _args, int32_t num_args) float* B = ((float*)(_B->memory)); for (int32_t i_j_fused_i_j_fused_0_fused = 0; i_j_fused_i_j_fused_0_fused < 256; i_j_fused_i_j_fused_0_fused += 1) { for (int32_t i_j_fused_i_j_fused_0_fused_0 = 0; i_j_fused_i_j_fused_0_fused_0 < 4; i_j_fused_i_j_fused_0_fused_0 += 1) { - B[(((i_j_fused_i_j_fused_0_fused / 8) * 32) + (((4 * i_j_fused_i_j_fused_0_fused) + i_j_fused_i_j_fused_0_fused_0) & 31))] = A[(((i_j_fused_i_j_fused_0_fused / 8) * 32) + (((4 * i_j_fused_i_j_fused_0_fused) + i_j_fused_i_j_fused_0_fused_0) & 31))]; + B[((((4 * i_j_fused_i_j_fused_0_fused) + i_j_fused_i_j_fused_0_fused_0) & 31) + ((i_j_fused_i_j_fused_0_fused / 8) * 32))] = A[((((4 * i_j_fused_i_j_fused_0_fused) + i_j_fused_i_j_fused_0_fused_0) & 31) + ((i_j_fused_i_j_fused_0_fused / 8) * 32))]; }; }; cinn_buffer_free((void*)(0), _B); @@ -196,7 +196,7 @@ void TestSplitThrow() { auto source_code = codegen.Compile(module, CodeGenC::OutputKind::CImpl); } TEST(IrSchedule, split_throw) { - ASSERT_THROW(TestSplitThrow(), utils::enforce::EnforceNotMet); + ASSERT_THROW(TestSplitThrow(), ::common::enforce::EnforceNotMet); } TEST(IrSchedule, reorder1) { @@ -608,7 +608,7 @@ void test_vectorize(void* _args, int32_t num_args) float* B = ((float*)(_B->memory)); for (int32_t i = 0; i < 32; i += 1) { for (int32_t j = 0; j < 2; j += 1) { - B[StackVec<16,int32_t>::Ramp(((32 * i) + (16 * j)), 1, 16)] = StackedVec::Load(A,((32 * i) + (16 * j))); + B[StackVec<16,int32_t>::Ramp(((16 * j) + (i * 32)), 1, 16)] = StackedVec::Load(A,((16 * j) + (i * 32))); }; }; cinn_buffer_free((void*)(0), _B); @@ -1094,7 +1094,7 @@ void test_compute_at3(void* _args, int32_t num_args) }; }; for (int32_t i_j_fused_0 = 0; i_j_fused_0 < 128; i_j_fused_0 += 1) { - C[((128 * i_j_fused) + i_j_fused_0)] = B[((128 * i_j_fused) + i_j_fused_0)]; + C[(i_j_fused_0 + (128 * i_j_fused))] = B[(i_j_fused_0 + (128 * i_j_fused))]; }; }; cinn_buffer_free((void*)(0), _B); @@ -1286,8 +1286,8 @@ void test_compute_at6(const float* __restrict__ A, float* __restrict__ C) float* B = _B_temp_buffer; for (int32_t i_j_fused = 0; i_j_fused < 32; i_j_fused += 1) { for (int32_t i_j_fused_0 = 0; i_j_fused_0 < 128; i_j_fused_0 += 1) { - B[((128 * i_j_fused) + i_j_fused_0)] = A[((128 * i_j_fused) + i_j_fused_0)]; - C[((128 * i_j_fused) + i_j_fused_0)] = B[((128 * i_j_fused) + i_j_fused_0)]; + B[(i_j_fused_0 + (128 * i_j_fused))] = A[(i_j_fused_0 + (128 * i_j_fused))]; + C[(i_j_fused_0 + (128 * i_j_fused))] = B[(i_j_fused_0 + (128 * i_j_fused))]; }; }; } diff --git a/paddle/cinn/backends/llvm/codegen_llvm.cc b/paddle/cinn/backends/llvm/codegen_llvm.cc index 6147940075d8a..e24b5220919cb 100644 --- a/paddle/cinn/backends/llvm/codegen_llvm.cc +++ 
b/paddle/cinn/backends/llvm/codegen_llvm.cc @@ -264,7 +264,7 @@ llvm::Value *CodeGenLLVM::Visit(const ir::FloatImm *op) { } else if (op->type().is_float16()) { return llvm::ConstantFP::get(b_->getHalfTy(), op->value); } else { - LOG(FATAL) << "illegal float type."; + PADDLE_THROW(phi::errors::InvalidArgument("illegal float type.")); } return nullptr; } @@ -1379,7 +1379,7 @@ void CodeGenLLVM::InitTarget(const Target &target) { } else if (target.bits == Target::Bit::k64) { naive_vec_alignment_ = 512; } else { - LOG(FATAL) << "get unknown bits"; + PADDLE_THROW(phi::errors::InvalidArgument("get unknown bits")); } break; case Target::Arch::ARM: @@ -1389,7 +1389,7 @@ void CodeGenLLVM::InitTarget(const Target &target) { naive_vec_alignment_ = 128; break; case Target::Arch::Unk: - LOG(FATAL) << "unknown Arch found"; + PADDLE_THROW(phi::errors::InvalidArgument("unknown Arch found")); break; } } @@ -1669,7 +1669,9 @@ llvm::Value *CodeGenLLVM::Visit(const ir::intrinsics::PodValueToX *op) { } else if (to_type == type_of()) { callee = m_->getFunction(runtime::intrinsic::pod_value_to_buffer_p); } else { - LOG(FATAL) << "Not supported type: " << to_type; + std::stringstream ss; + ss << "Not supported type: " << to_type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } CHECK(callee); diff --git a/paddle/cinn/backends/nvrtc/nvrtc_util.cc b/paddle/cinn/backends/nvrtc/nvrtc_util.cc index 7af601f4ead23..4a68b9a82f61d 100644 --- a/paddle/cinn/backends/nvrtc/nvrtc_util.cc +++ b/paddle/cinn/backends/nvrtc/nvrtc_util.cc @@ -75,10 +75,12 @@ std::vector Compiler::FindCUDAIncludePaths() { return {cuda_include_path}; } #endif - LOG(FATAL) << "Cannot find cuda include path." - << "CUDA_PATH is not set or CUDA is not installed in the default " - "installation path." - << "In other than linux, it is necessary to set CUDA_PATH."; + std::stringstream ss; + ss << "Cannot find cuda include path." + << "CUDA_PATH is not set or CUDA is not installed in the default " + "installation path." 
+ << "In other than linux, it is necessary to set CUDA_PATH."; + PADDLE_THROW(phi::errors::Fatal(ss.str())); return {cuda_include_path}; } diff --git a/paddle/cinn/common/CMakeLists.txt b/paddle/cinn/common/CMakeLists.txt index e9c4523edd323..95227b6f414a4 100644 --- a/paddle/cinn/common/CMakeLists.txt +++ b/paddle/cinn/common/CMakeLists.txt @@ -23,8 +23,7 @@ gather_srcs( nvgpu_dev_info.cc integer_set.cc dim_expr_converter.cc - broadcast_tree.cc - dim_expr_util.cc) + broadcast_tree.cc) cinn_cc_test(test_equation_graph_topo_walker SRCS equation_graph_topo_walker_test.cc DEPS gtest glog) @@ -48,9 +47,7 @@ if(WITH_CUDA) cinn_nv_test(test_fp16_bf16_cuda SRCS float16_bfloat16_cuda_test.cu DEPS gtest glog) endif() -if(NOT CINN_ONLY) - cinn_cc_test(dim_expr_util_test SRCS dim_expr_util_test.cc DEPS cinncore) - cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS - cinncore) - cinn_cc_test(broadcast_tree_test SRCS broadcast_tree_test.cc DEPS cinncore) -endif() + +cinn_cc_test(dim_expr_converter_test SRCS dim_expr_converter_test.cc DEPS + cinncore) +cinn_cc_test(broadcast_tree_test SRCS broadcast_tree_test.cc DEPS cinncore) diff --git a/paddle/cinn/common/broadcast_tree.cc b/paddle/cinn/common/broadcast_tree.cc index 1a1bdbd550c75..4b14b41af3ae4 100644 --- a/paddle/cinn/common/broadcast_tree.cc +++ b/paddle/cinn/common/broadcast_tree.cc @@ -17,8 +17,7 @@ #include #include -#include "paddle/cinn/common/dim_expr_util.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" namespace cinn::common { @@ -116,71 +115,6 @@ void ForEachBroadcastDimExpr(const BroadcastLeaf& leaves, } } -std::optional> GetFirstCstrBroadcastable( - const BroadcastLeaf& leaves) { - std::optional> ret; - ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool { - const auto& operands = broadcast.operands; - std::optional lhs_symbol; - std::optional rhs_symbol; - size_t i = 0; - for (; i < operands->size(); ++i) { - if (operands->at(i).template isa()) { - lhs_symbol = operands->at(i); - break; - } - } - for (i++; i < operands->size(); ++i) { - if (operands->at(i).template isa()) { - rhs_symbol = operands->at(i); - break; - } - } - if (lhs_symbol.has_value() && rhs_symbol.has_value()) { - CHECK(lhs_symbol != rhs_symbol) - << lhs_symbol.value() << " != " << rhs_symbol.value(); - ret = symbol::Broadcastable{lhs_symbol.value(), - rhs_symbol.value()}; - return true; - } - return false; - }); - if (ret.has_value()) return ret.value(); - ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool { - const auto& operands = broadcast.operands; - std::optional lhs_symbol; - std::optional rhs; - for (const auto& operand : *operands) { - if (operand.template isa()) { - lhs_symbol = operand; - break; - } - } - for (const auto& operand : *operands) { - if (operand != lhs_symbol) { - rhs = operand; - break; - } - } - if (lhs_symbol.has_value() && rhs.has_value()) { - ret = symbol::Broadcastable{lhs_symbol.value(), - rhs.value()}; - return true; - } - return false; - }); - if (ret.has_value()) return ret.value(); - ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool { - const auto& operands = broadcast.operands; - CHECK_GE(operands->size(), 2); - CHECK(operands->at(0) != operands->at(1)); - ret = symbol::Broadcastable{operands->at(0), - operands->at(1)}; - return true; - }); - return ret; -} - using Pattern2Placement = std::unordered_map; Pattern2Placement ConstructCstrLhsEqRhsReplacement( @@ -209,7 +143,7 
@@ symbol::DimExpr GetCstrLhsEqRhsDimExpr( const auto& pattern2replacement = ConstructCstrLhsEqRhsReplacement(broadcastable_condition); return symbol::SimplifyDimExpr( - SubstituteDimExpr(dim_expr, pattern2replacement)); + symbol::SubstituteDimExpr(dim_expr, pattern2replacement)); } symbol::DimExpr GetCstrLhsEqOneDimExpr( @@ -218,7 +152,7 @@ symbol::DimExpr GetCstrLhsEqOneDimExpr( const auto& pattern2replacement = ConstructCstrLhsEqOneReplacement(broadcastable_condition); return symbol::SimplifyDimExpr( - SubstituteDimExpr(dim_expr, pattern2replacement)); + symbol::SubstituteDimExpr(dim_expr, pattern2replacement)); } symbol::DimExpr GetCstrRhsEqOneDimExpr( @@ -227,7 +161,7 @@ symbol::DimExpr GetCstrRhsEqOneDimExpr( const auto& pattern2replacement = ConstructCstrRhsEqOneReplacement(broadcastable_condition); return symbol::SimplifyDimExpr( - SubstituteDimExpr(dim_expr, pattern2replacement)); + symbol::SubstituteDimExpr(dim_expr, pattern2replacement)); } typedef symbol::DimExpr (*ConvertDimExprT)( @@ -292,6 +226,71 @@ BroadcastBranch ConstructBroadcastBranch( } // namespace +std::optional> GetFirstCstrBroadcastable( + const BroadcastLeaf& leaves) { + std::optional> ret; + ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool { + const auto& operands = broadcast.operands; + std::optional lhs_symbol; + std::optional rhs_symbol; + size_t i = 0; + for (; i < operands->size(); ++i) { + if (operands->at(i).template isa()) { + lhs_symbol = operands->at(i); + break; + } + } + for (i++; i < operands->size(); ++i) { + if (operands->at(i).template isa()) { + rhs_symbol = operands->at(i); + break; + } + } + if (lhs_symbol.has_value() && rhs_symbol.has_value()) { + CHECK(lhs_symbol != rhs_symbol) + << lhs_symbol.value() << " != " << rhs_symbol.value(); + ret = symbol::Broadcastable{lhs_symbol.value(), + rhs_symbol.value()}; + return true; + } + return false; + }); + if (ret.has_value()) return ret.value(); + ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool { + const auto& operands = broadcast.operands; + std::optional lhs_symbol; + std::optional rhs; + for (const auto& operand : *operands) { + if (operand.template isa()) { + lhs_symbol = operand; + break; + } + } + for (const auto& operand : *operands) { + if (operand != lhs_symbol) { + rhs = operand; + break; + } + } + if (lhs_symbol.has_value() && rhs.has_value()) { + ret = symbol::Broadcastable{lhs_symbol.value(), + rhs.value()}; + return true; + } + return false; + }); + if (ret.has_value()) return ret.value(); + ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool { + const auto& operands = broadcast.operands; + CHECK_GE(operands->size(), 2); + CHECK(operands->at(0) != operands->at(1)); + ret = symbol::Broadcastable{operands->at(0), + operands->at(1)}; + return true; + }); + return ret; +} + BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves) { std::optional> broadcastable_condition = GetFirstCstrBroadcastable(leaves); diff --git a/paddle/cinn/common/broadcast_tree.h b/paddle/cinn/common/broadcast_tree.h index 6a7dfc5d1617c..5b8c051299af8 100644 --- a/paddle/cinn/common/broadcast_tree.h +++ b/paddle/cinn/common/broadcast_tree.h @@ -33,4 +33,7 @@ BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves); std::string ToTxtString(const BroadcastTree&); +std::optional> GetFirstCstrBroadcastable( + const BroadcastLeaf& leaves); + } // namespace cinn::common diff --git a/paddle/cinn/common/cas.cc b/paddle/cinn/common/cas.cc index f2e93286a04a7..fac9e08befee9 100644 --- 
a/paddle/cinn/common/cas.cc +++ b/paddle/cinn/common/cas.cc @@ -854,7 +854,7 @@ void CasSimplifyMutator::UnfoldBound(Expr* lower_bound, AddBaseAndSimplify(lower_bound, var); AddBaseAndSimplify(upper_bound, var); } else { - LOG(FATAL) << "can't get the bound"; + PADDLE_THROW(phi::errors::InvalidArgument("can't get the bound")); } } diff --git a/paddle/cinn/common/common.h b/paddle/cinn/common/common.h index 34623d904515b..e5bb5d29cf181 100644 --- a/paddle/cinn/common/common.h +++ b/paddle/cinn/common/common.h @@ -24,6 +24,8 @@ #include "paddle/cinn/common/shared.h" #include "paddle/cinn/common/target.h" #include "paddle/cinn/common/type.h" +#include "paddle/cinn/utils/error.h" +#include "paddle/common/enforce.h" namespace cinn { diff --git a/paddle/cinn/common/dim_expr_converter.cc b/paddle/cinn/common/dim_expr_converter.cc index c0cb71f408ddc..06c8968d98876 100644 --- a/paddle/cinn/common/dim_expr_converter.cc +++ b/paddle/cinn/common/dim_expr_converter.cc @@ -68,7 +68,17 @@ struct DimExprToIrExprVisitor { } ir::Expr product = ConvertToIrExpr(operands->at(0)); for (std::size_t i = 1; i < operands->size(); ++i) { - product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i))); + // Convert Reciprocal(S0) to (1 / S0) will result in precision + // error. For example, (S0 * S1 / S2) != (S0 * S1 * (1 / S2)). So we + // should use Div instead of Reciprocal here. + if (operands->at(i).isa>()) { + product = ir::Div::Make( + product, + ConvertToIrExpr( + operands->at(i).dyn_cast>()->data)); + } else { + product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i))); + } } return product; } @@ -94,8 +104,8 @@ struct DimExprToIrExprVisitor { } ir::Expr operator()(const Broadcast& dim_expr) { - LOG(FATAL) - << "no support for converting from Broadcast to ir::Expr"; + PADDLE_THROW(phi::errors::Fatal( + "no support for converting from Broadcast to ir::Expr")); } }; diff --git a/paddle/cinn/common/dim_expr_util.cc b/paddle/cinn/common/dim_expr_util.cc deleted file mode 100644 index 0d0a9090429a0..0000000000000 --- a/paddle/cinn/common/dim_expr_util.cc +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/common/dim_expr_util.h" - -namespace cinn::common { -using namespace symbol; // NOLINT - -namespace { - -class SubstituteDimExprHelper final { - public: - explicit SubstituteDimExprHelper( - const std::unordered_map& - pattern_to_replacement) - : pattern_to_replacement_(pattern_to_replacement) {} - - std::optional Substitute(const DimExpr& dim_expr) { - auto iter = pattern_to_replacement_.find(dim_expr); - if (iter != pattern_to_replacement_.end()) return iter->second; - return std::visit([&](const auto& impl) { return SubstituteImpl(impl); }, - dim_expr.variant()); - } - - private: - std::optional SubstituteImpl(const std::int64_t& value) { - // `Substitute` has handled the case that `value` is matched. 
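The Div-versus-Reciprocal comment in the dim_expr_converter.cc hunk above is easiest to see with concrete integers: dimension expressions are integral, so evaluating the reciprocal first truncates to zero. A tiny standalone check using plain ints (not the DimExpr API):

    #include <cassert>

    int main() {
      const int S0 = 4, S1 = 3, S2 = 6;
      const int as_div = S0 * S1 / S2;               // (4 * 3) / 6 == 2
      const int as_reciprocal = S0 * S1 * (1 / S2);  // 1 / 6 truncates to 0, so 0
      assert(as_div == 2);
      assert(as_reciprocal == 0);
      return 0;
    }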
- return std::nullopt; - } - std::optional SubstituteImpl(const std::string& value) { - // `Substitute` has handled the case that `value` is matched. - return std::nullopt; - } - - std::optional SubstituteImpl(const Negative& dim_expr) { - return SubstituteUnary(dim_expr); - } - std::optional SubstituteImpl(const Reciprocal& dim_expr) { - return SubstituteUnary(dim_expr); - } - - template - std::optional SubstituteUnary(const T& dim_expr) { - const auto& operand = dim_expr->data; - const auto& substituted_operand = Substitute(operand); - if (!substituted_operand.has_value()) return std::nullopt; - return T{substituted_operand.value()}; - } - - std::optional SubstituteImpl(const Add& dim_expr) { - return SubstituteVariadic(dim_expr); - } - - std::optional SubstituteImpl(const Mul& dim_expr) { - return SubstituteVariadic(dim_expr); - } - - std::optional SubstituteImpl(const Max& dim_expr) { - return SubstituteVariadic(dim_expr); - } - - std::optional SubstituteImpl(const Min& dim_expr) { - return SubstituteVariadic(dim_expr); - } - - std::optional SubstituteImpl(const Broadcast& dim_expr) { - return SubstituteVariadic(dim_expr); - } - - template - std::optional SubstituteVariadic(const T& dim_expr) { - const auto& operands = *(dim_expr.operands); - List substituted_operands{}; - size_t replace_cnt = 0; - for (const auto& operand : operands) { - const auto& substituted_operand = Substitute(operand); - replace_cnt += substituted_operand.has_value(); - substituted_operands->push_back(substituted_operand.has_value() - ? substituted_operand.value() - : operand); - } - if (replace_cnt == 0) return std::nullopt; - return T{substituted_operands}; - } - - std::unordered_map pattern_to_replacement_; -}; - -} // namespace - -symbol::DimExpr SubstituteDimExpr( - const symbol::DimExpr& dim_expr, - const std::unordered_map& - pattern_to_replacement) { - const auto& opt_replaced = - SubstituteDimExprHelper(pattern_to_replacement).Substitute(dim_expr); - return opt_replaced.has_value() ? opt_replaced.value() : dim_expr; -} - -} // namespace cinn::common diff --git a/paddle/cinn/common/dim_expr_util_test.cc b/paddle/cinn/common/dim_expr_util_test.cc deleted file mode 100644 index 82b300fc5bfe2..0000000000000 --- a/paddle/cinn/common/dim_expr_util_test.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
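With cinn::common::SubstituteDimExpr and its test removed here, call sites go through symbol::SubstituteDimExpr from the pir dim_expr_util.h header, as the broadcast_tree.cc hunks above already show. A sketch of how the round-trip check from the dim_expr_util_test.cc deleted just below could be expressed against that entry point; the DimExpr operator overloads and hashing are assumed to behave as in the deleted test:

    #include <unordered_map>
    #include "gtest/gtest.h"
    #include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h"

    TEST(DimExprUtil, SubstituteRoundTrip) {
      using symbol::DimExpr;
      // (S0 - S1) * 2 / S0, the same example expression as the deleted test.
      DimExpr expr = (DimExpr("S0") - DimExpr("S1")) * DimExpr(2) / DimExpr("S0");
      std::unordered_map<DimExpr, DimExpr> to_full{
          {DimExpr("S0"), DimExpr("symbol0")}, {DimExpr("S1"), DimExpr("symbol1")}};
      std::unordered_map<DimExpr, DimExpr> to_naive{
          {DimExpr("symbol0"), DimExpr("S0")}, {DimExpr("symbol1"), DimExpr("S1")}};
      const auto& mid = symbol::SubstituteDimExpr(expr, to_full);
      const auto& back = symbol::SubstituteDimExpr(mid, to_naive);
      ASSERT_EQ(back, expr);
    }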
- -#include "paddle/cinn/common/dim_expr_util.h" - -#include "gtest/gtest.h" - -namespace cinn::common { -using namespace symbol; // NOLINT - -namespace { -DimExpr CreateExampleDimExpr() { - DimExpr sym0 = DimExpr("S0"); - DimExpr sym1 = DimExpr("S1"); - DimExpr constant = DimExpr(2); - return (sym0 - sym1) * constant / sym0; -} -} // namespace - -TEST(DimExprUtil, Substitute) { - DimExpr dim_expr = CreateExampleDimExpr(); - std::unordered_map naive_to_full_name{ - {DimExpr("S0"), DimExpr("symbol0")}, {DimExpr("S1"), DimExpr("symbol1")}}; - std::unordered_map full_name_to_naive{ - {DimExpr("symbol0"), DimExpr("S0")}, {DimExpr("symbol1"), DimExpr("S1")}}; - - const auto& mid_expr = SubstituteDimExpr(dim_expr, naive_to_full_name); - const auto& ret_expr = SubstituteDimExpr(mid_expr, full_name_to_naive); - ASSERT_EQ(ret_expr, dim_expr); -} - -} // namespace cinn::common diff --git a/paddle/cinn/common/float16_bfloat16_cuda_test.cu b/paddle/cinn/common/float16_bfloat16_cuda_test.cu index e8d9c7f534cc1..fd6c39cc51f8f 100644 --- a/paddle/cinn/common/float16_bfloat16_cuda_test.cu +++ b/paddle/cinn/common/float16_bfloat16_cuda_test.cu @@ -17,19 +17,21 @@ #include #include - #include "paddle/cinn/common/bfloat16.h" #include "paddle/cinn/common/float16.h" +#include "paddle/common/enforce.h" namespace cinn { namespace common { -#define CUDA_CALL(func) \ - { \ - auto status = func; \ - if (status != cudaSuccess) { \ - LOG(FATAL) << "CUDA Error : " << cudaGetErrorString(status); \ - } \ +#define CUDA_CALL(func) \ + { \ + auto status = func; \ + if (status != cudaSuccess) { \ + std::stringstream ss; \ + ss << "CUDA Error : " << cudaGetErrorString(status); \ + PADDLE_THROW(phi::errors::Fatal(ss.str())); \ + } \ } class CudaMem { diff --git a/paddle/cinn/common/graph_utils.cc b/paddle/cinn/common/graph_utils.cc index 446c124124b9a..b1110e8ca8aa0 100755 --- a/paddle/cinn/common/graph_utils.cc +++ b/paddle/cinn/common/graph_utils.cc @@ -32,7 +32,7 @@ namespace { void DFSSortUtil(const GraphNode *node, std::vector *order) {} std::vector DFSSort(const std::vector &nodes) { - LOG(FATAL) << "not implemented"; + PADDLE_THROW(phi::errors::Unimplemented("Not Implemented")); return {}; } diff --git a/paddle/cinn/common/integer_set.cc b/paddle/cinn/common/integer_set.cc index f6d6446b9bb24..5a1bbc6c625a9 100644 --- a/paddle/cinn/common/integer_set.cc +++ b/paddle/cinn/common/integer_set.cc @@ -44,6 +44,9 @@ cas_intervals_t CollectVarIntervalsOfExprs(const std::vector& exprs, if (var->upper_bound.defined()) { upper_bound = var->upper_bound; } + if (var->is_symbolic_constant) { + lower_bound = ir::Expr(1); + } var_intervals.insert( {var->name, CasInterval(lower_bound, upper_bound)}); } @@ -118,25 +121,20 @@ std::optional SymbolicExprAnalyzer::ProveGE(const ir::Expr& lhs, if (lhs == rhs) { return true; } - if (lhs == SymbolicExprLimit::positive_inf || - rhs == SymbolicExprLimit::negative_inf) { - return true; - } if (rhs == SymbolicExprLimit::positive_inf || lhs == SymbolicExprLimit::negative_inf) { return false; } - ir::Expr diff = AutoSimplify(ir::Sub::Make(lhs, rhs), var_intervals_); - VLOG(6) << "diff of " << ir::Sub::Make(lhs, rhs) << " = " << diff; - if (diff.is_constant() && diff.get_constant() >= 0) { + if (lhs == SymbolicExprLimit::positive_inf || + rhs == SymbolicExprLimit::negative_inf) { return true; } + ir::Expr diff = AutoSimplify(ir::Sub::Make(lhs, rhs), var_intervals_); + VLOG(6) << "diff of " << ir::Sub::Make(lhs, rhs) << " = " << diff; if (diff.is_constant() && diff.get_constant() < 0) { return false; 
} - ir::Expr diff_lower_bound = LowerBound(diff); - VLOG(6) << "lower bound of " << diff << " = " << diff_lower_bound; - if (diff_lower_bound.is_constant() && diff_lower_bound.get_constant() >= 0) { + if (diff.is_constant() && diff.get_constant() >= 0) { return true; } ir::Expr diff_upper_bound = UpperBound(diff); @@ -144,6 +142,11 @@ std::optional SymbolicExprAnalyzer::ProveGE(const ir::Expr& lhs, if (diff_upper_bound.is_constant() && diff_upper_bound.get_constant() < 0) { return false; } + ir::Expr diff_lower_bound = LowerBound(diff); + VLOG(6) << "lower bound of " << diff << " = " << diff_lower_bound; + if (diff_lower_bound.is_constant() && diff_lower_bound.get_constant() >= 0) { + return true; + } return std::nullopt; } @@ -157,25 +160,20 @@ std::optional SymbolicExprAnalyzer::ProveGT(const ir::Expr& lhs, if (lhs == rhs) { return false; } - if (lhs == SymbolicExprLimit::positive_inf || - rhs == SymbolicExprLimit::negative_inf) { - return true; - } if (rhs == SymbolicExprLimit::positive_inf || lhs == SymbolicExprLimit::negative_inf) { return false; } - ir::Expr diff = AutoSimplify(ir::Sub::Make(lhs, rhs), var_intervals_); - VLOG(6) << "diff of " << ir::Sub::Make(lhs, rhs) << " = " << diff; - if (diff.is_constant() && diff.get_constant() > 0) { + if (lhs == SymbolicExprLimit::positive_inf || + rhs == SymbolicExprLimit::negative_inf) { return true; } + ir::Expr diff = AutoSimplify(ir::Sub::Make(lhs, rhs), var_intervals_); + VLOG(6) << "diff of " << ir::Sub::Make(lhs, rhs) << " = " << diff; if (diff.is_constant() && diff.get_constant() <= 0) { return false; } - ir::Expr diff_lower_bound = LowerBound(diff); - VLOG(6) << "lower bound of " << diff << " = " << diff_lower_bound; - if (diff_lower_bound.is_constant() && diff_lower_bound.get_constant() > 0) { + if (diff.is_constant() && diff.get_constant() > 0) { return true; } ir::Expr diff_upper_bound = UpperBound(diff); @@ -183,6 +181,12 @@ std::optional SymbolicExprAnalyzer::ProveGT(const ir::Expr& lhs, if (diff_upper_bound.is_constant() && diff_upper_bound.get_constant() <= 0) { return false; } + ir::Expr diff_lower_bound = LowerBound(diff); + VLOG(6) << "lower bound of " << diff << " = " << diff_lower_bound; + if (diff_lower_bound.is_constant() && diff_lower_bound.get_constant() > 0) { + return true; + } + return std::nullopt; } @@ -288,7 +292,7 @@ std::optional SymbolicExprAnalyzer::ProveDivisible( case cinn::ir::IrNodeTy::Minus: return ProveDivisible(lhs.As()->v(), rhs); default: - LOG(FATAL) << "Not supported yet!"; + PADDLE_THROW(phi::errors::InvalidArgument("Not supported yet!")); break; } } diff --git a/paddle/cinn/common/macros.h b/paddle/cinn/common/macros.h index dbae22549331c..52d91c922ad6f 100644 --- a/paddle/cinn/common/macros.h +++ b/paddle/cinn/common/macros.h @@ -23,7 +23,8 @@ void operator=(const TypeName&) = delete #ifndef CINN_NOT_IMPLEMENTED -#define CINN_NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented"; +#define CINN_NOT_IMPLEMENTED \ + PADDLE_THROW(phi::errors::Unimplemented("Not Implemented")); #endif #define CINN_RESULT_SHOULD_USE __attribute__((warn_unused_result)) diff --git a/paddle/cinn/common/target.cc b/paddle/cinn/common/target.cc index fc01a56db481d..c24c89c29ae1a 100644 --- a/paddle/cinn/common/target.cc +++ b/paddle/cinn/common/target.cc @@ -24,6 +24,7 @@ #include "paddle/cinn/backends/cuda_util.h" #include "paddle/cinn/common/target.h" #include "paddle/cinn/runtime/cinn_runtime.h" +#include "paddle/common/enforce.h" namespace cinn { namespace common { @@ -51,7 +52,7 @@ int Target::runtime_arch() const { case 
Arch::ARM: return cinn_arm_device; default: - LOG(FATAL) << "Not supported arch"; + PADDLE_THROW(phi::errors::InvalidArgument("Not supported arch")); } return -1; } @@ -106,7 +107,7 @@ int Target::get_target_bits() const { case Bit::Unk: return 0; default: - LOG(FATAL) << "Not supported Bit"; + PADDLE_THROW(phi::errors::InvalidArgument("Not supported Bit")); } return -1; } diff --git a/paddle/cinn/common/type.cc b/paddle/cinn/common/type.cc index 67ee1b25a09e9..41cfd9e638f90 100644 --- a/paddle/cinn/common/type.cc +++ b/paddle/cinn/common/type.cc @@ -18,7 +18,7 @@ #include #include #include - +#include "paddle/common/enforce.h" namespace cinn { namespace common { @@ -600,7 +600,9 @@ std::string Type2Str(const Type &type) { return "unk"; default: - LOG(FATAL) << "Not support type [" << type << "] ! Please Check.\n"; + std::stringstream ss; + ss << "Not support type [" << type << "] ! Please Check.\n"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return "unk"; } diff --git a/paddle/cinn/frontend/CMakeLists.txt b/paddle/cinn/frontend/CMakeLists.txt index e04ae9e9851c0..f84e4f0cfdc85 100755 --- a/paddle/cinn/frontend/CMakeLists.txt +++ b/paddle/cinn/frontend/CMakeLists.txt @@ -62,6 +62,7 @@ add_subdirectory(paddle) add_subdirectory(decomposer) add_subdirectory(op_mappers) add_subdirectory(pass) +add_subdirectory(group_cluster) cinn_cc_test(test_op_mapper_registry SRCS op_mapper_registry_test.cc DEPS cinncore) diff --git a/paddle/cinn/frontend/computation.cc b/paddle/cinn/frontend/computation.cc index 90c889c599690..ee7d2ce6b3a82 100644 --- a/paddle/cinn/frontend/computation.cc +++ b/paddle/cinn/frontend/computation.cc @@ -251,9 +251,11 @@ hlir::framework::Tensor CinnComputation::GetTensor(const std::string &tname) { } auto it = context_->varmap_paddle2program.find(tname); if (it == context_->varmap_paddle2program.end()) { - LOG(FATAL) << "No variable called [" << tname - << "] found in computation\nThe existing vars: " - << utils::Join(context_->scope->var_names(), ", "); + std::stringstream ss; + ss << "No variable called [" << tname + << "] found in computation\nThe existing vars: " + << utils::Join(context_->scope->var_names(), ", "); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return context_->scope->GetTensor(it->second); } diff --git a/paddle/cinn/frontend/decomposer/batch_norm.cc b/paddle/cinn/frontend/decomposer/batch_norm.cc index b2d59053e43de..5e40fddac7a01 100644 --- a/paddle/cinn/frontend/decomposer/batch_norm.cc +++ b/paddle/cinn/frontend/decomposer/batch_norm.cc @@ -42,7 +42,9 @@ struct BatchNormHelper { reduce_dim = {0, 1, 2}; element_count = x_shape[0] * x_shape[1] * x_shape[2]; } else { - LOG(FATAL) << data_layout << " setting is not support!"; + std::stringstream ss; + ss << data_layout << " setting is not support!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } num_instructions = builder->size(); diff --git a/paddle/cinn/frontend/decomposer/broadcast.cc b/paddle/cinn/frontend/decomposer/broadcast.cc index ece85caccc7da..1067ec51981b8 100644 --- a/paddle/cinn/frontend/decomposer/broadcast.cc +++ b/paddle/cinn/frontend/decomposer/broadcast.cc @@ -14,6 +14,7 @@ #include "paddle/cinn/frontend/decomposer_registry.h" #include "paddle/cinn/frontend/syntax.h" +#include "paddle/common/enforce.h" namespace cinn { namespace frontend { @@ -51,10 +52,18 @@ void GetReduceDimsForY(const std::vector& dy_shape, void elementwise_add(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 2UL) - << " 2 input tensors 
for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 1UL) - << "1 output tensor for " << instr->op_type; + PADDLE_ENFORCE_EQ(instr->inputs.size(), + 2UL, + phi::errors::InvalidArgument( + "The size of inputs in elementwise_add is incorrect. " + "Expected size is 2, but receive %d. ", + instr->inputs.size())); + PADDLE_ENFORCE_EQ(instr->outputs.size(), + 1UL, + phi::errors::InvalidArgument( + "The size of outputs in elementwise_add is incorrect. " + "Expected size is 1, but receive %d. ", + instr->outputs.size())); auto x = instr->inputs[0]; auto y = instr->inputs[1]; auto output = instr->outputs[0]; @@ -120,17 +129,28 @@ void elementwise_add(const Instruction& instr, void elementwise_add_grad(const Instruction& instr, const DecomposerContext& context) { - CHECK_EQ(instr->inputs.size(), 3UL) - << " 3 input tensors for " << instr->op_type; - CHECK_EQ(instr->outputs.size(), 2UL) - << "2 output tensors for " << instr->op_type; + PADDLE_ENFORCE_EQ( + instr->inputs.size(), + 3UL, + phi::errors::InvalidArgument( + "The size of inputs in elementwise_add_grad is incorrect. " + "Expected size is 3, but receive %d. ", + instr->inputs.size())); + PADDLE_ENFORCE_EQ( + instr->outputs.size(), + 2UL, + phi::errors::InvalidArgument( + "The size of outputs in elementwise_add_grad is incorrect. " + "Expected size is 2, but receive %d. ", + instr->outputs.size())); auto dout = instr->inputs[0]; auto dx = instr->outputs[0]; auto dy = instr->outputs[1]; int axis = instr.GetAttrs("axis"); if (axis < 0 && dx->shape.size() < dy->shape.size()) { - LOG(FATAL) << "Please make sure x'rank greater than or equal to y'rank " - "when axis = -1"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Please make sure x'rank greater than or equal to y'rank " + "when axis = -1")); } axis = axis >= 0 ? axis : dx->shape.size() - dy->shape.size(); auto* builder = context.builder(); diff --git a/paddle/cinn/frontend/decomposer/test_helper.h b/paddle/cinn/frontend/decomposer/test_helper.h index 4a7bb9b2f8091..072ca29151147 100644 --- a/paddle/cinn/frontend/decomposer/test_helper.h +++ b/paddle/cinn/frontend/decomposer/test_helper.h @@ -89,8 +89,8 @@ void CopyFromVector(const std::vector& vec, #ifdef CINN_WITH_CUDA cudaMemcpy(data, vec.data(), numel * sizeof(T), cudaMemcpyHostToDevice); #else - LOG(FATAL) - << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + PADDLE_THROW(phi::errors::Fatal( + "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check.")); #endif } else { std::copy(vec.begin(), vec.end(), data); diff --git a/paddle/cinn/frontend/decomposer_registry.h b/paddle/cinn/frontend/decomposer_registry.h index a94708db631d5..27cecf54501b7 100644 --- a/paddle/cinn/frontend/decomposer_registry.h +++ b/paddle/cinn/frontend/decomposer_registry.h @@ -38,18 +38,19 @@ class DecomposerContext { // Map the new var to the original var. void MapOutToOrigin(const Variable& new_var, const Variable& ori_var) const { if (new_var->shape != ori_var->shape) { - LOG(FATAL) - << "The output shape should be equal to the original. But received : " - << new_var->id << ".shape=[" << utils::Join(new_var->shape, ", ") - << "] and the original var " << ori_var->id << ".shape=[" - << utils::Join(ori_var->shape, ", ") << "]."; + std::stringstream ss; + ss << "The output shape should be equal to the original. 
But received : " + << new_var->id << ".shape=[" << utils::Join(new_var->shape, ", ") + << "] and the original var " << ori_var->id << ".shape=[" + << utils::Join(ori_var->shape, ", ") << "]."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } if (new_var->type != ori_var->type) { - LOG(FATAL) - << "The output type should be equal to the original. But received : " - << new_var->id << ".type=" << new_var->type - << " and the original var " << ori_var->id - << ".type=" << ori_var->type; + std::stringstream ss; + ss << "The output type should be equal to the original. But received : " + << new_var->id << ".type=" << new_var->type << " and the original var " + << ori_var->id << ".type=" << ori_var->type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } (*var_map_)[new_var->id] = ori_var; } diff --git a/paddle/cinn/frontend/group_cluster/CMakeLists.txt b/paddle/cinn/frontend/group_cluster/CMakeLists.txt new file mode 100644 index 0000000000000..3ade895bb2b6b --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/CMakeLists.txt @@ -0,0 +1,9 @@ +gather_srcs(group_cluster_src SRCS common_utils.cc pattern_node.cc + pattern_graph.cc) + +add_subdirectory(cluster_policy) + +cc_library( + group_cluster + SRCS ${group_cluster_src} + DEPS phi) diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/CMakeLists.txt b/paddle/cinn/frontend/group_cluster/cluster_policy/CMakeLists.txt new file mode 100644 index 0000000000000..7b86c45ca4dd9 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/CMakeLists.txt @@ -0,0 +1,3 @@ +gather_srcs(group_cluster_src SRCS general_topo_policy.cc policy_manager.cc + relative_judge_policy.cc) +add_subdirectory(shardable_axes_policy) diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.cc b/paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.cc new file mode 100644 index 0000000000000..2348701af3d99 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h" + +namespace cinn::frontend::group_cluster::policy { + +bool IsDownstreamNode(const PatternNodePtr start, const PatternNodePtr target) { + if (start == target) return true; + for (const auto& down_node : start->downstream_) { + if (IsDownstreamNode(down_node, target)) return true; + } + return false; +} + +bool IsIndirectDownstreamNode(const PatternNodePtr start, + const PatternNodePtr target) { + for (const auto& node : start->downstream_) { + if (node == target) continue; + if (IsDownstreamNode(node, target)) return true; + } + return false; +} + +bool GeneralTopoPolicy::CanFuse(const PatternNodePtr& first, + const PatternNodePtr& second) { + VLOG(4) << "Start GeneralTopoPolicy"; + return !(IsIndirectDownstreamNode(first, second) || + IsIndirectDownstreamNode(second, first)); +} + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h b/paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h new file mode 100644 index 0000000000000..ae0801a2fe402 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h @@ -0,0 +1,27 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h" + +namespace cinn::frontend::group_cluster::policy { + +class GeneralTopoPolicy final : virtual public Policy { + public: + bool CanFuse(const PatternNodePtr& upstream, + const PatternNodePtr& downstream) override; + std::string Name() { return "GeneralTopoPolicy"; } +}; + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.cc b/paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.cc new file mode 100644 index 0000000000000..edbbe90ec315f --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h" +#include "paddle/common/enforce.h" + +namespace cinn::frontend::group_cluster::policy { + +bool PolicyManager::CanFuse(const PatternNodePtr& upstream, + const PatternNodePtr& downstream) const { + for (const auto& policy : policies_) { + if (!policy->CanFuse(upstream, downstream)) return false; + } + return true; +} + +std::vector PolicyManager::GetFakeReduceIterIdx( + const PatternNodePtr& upstream, const PatternNodePtr& downstream) const { + for (const auto& policy : policies_) { + if (policy->Name() == "RelativeJudgePolicy") { + return policy->GetFakeReduceIterIdx(upstream, downstream); + } + } + return {}; +} + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h b/paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h new file mode 100644 index 0000000000000..414b16f0e725e --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h @@ -0,0 +1,47 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/frontend/group_cluster/pattern_node.h" + +namespace cinn::frontend::group_cluster::policy { + +class Policy { + public: + virtual std::string Name() = 0; + virtual bool CanFuse(const PatternNodePtr& upstream, + const PatternNodePtr& downstream) = 0; + virtual std::vector GetFakeReduceIterIdx( + const PatternNodePtr& upstream, const PatternNodePtr& downstream) { + return {}; + } +}; + +using PolicyPtr = std::shared_ptr; + +class PolicyManager { + public: + explicit PolicyManager(const std::vector& policies) + : policies_(policies) {} + bool CanFuse(const PatternNodePtr& upstream, + const PatternNodePtr& downstream) const; + std::vector GetFakeReduceIterIdx( + const PatternNodePtr& upstream, const PatternNodePtr& downstream) const; + + private: + std::vector policies_; +}; + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.cc b/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.cc new file mode 100644 index 0000000000000..04db9a3401c03 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.cc @@ -0,0 +1,307 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h" + +namespace cinn::frontend::group_cluster::policy { + +bool RelativeJudgePolicy::IsDownstreamStmtDependReduceOp( + pir::Operation* reduce, const StmtPattern& downstream) { + const auto& values = GetPatternInputValues(downstream); + for (const auto& value : reduce->results()) { + if (std::find(values.begin(), values.end(), value) != values.end()) { + return true; + } + } + return false; +} + +std::optional RelativeJudgePolicy::GetDownstreamFromCandidate( + const ReducePattern& upstream, + const std::vector& candidates) { + pir::Operation* reduce = upstream.GetReduceOp(); + for (const auto& candidate : candidates) { + if (IsDownstreamStmtDependReduceOp(reduce, candidate)) { + return candidate; + } + } + return {}; +} + +SplitDims SplitReduceInputDimsIfRelatedWithNonReduceAxis( + const ShardableAxesSignature& signature, pir::Operation* op) { + const auto& v = op->operand_source(0); + const auto& input_names = signature.inputs[0].axis_names; + const auto& output_names = signature.outputs[0].axis_names; + std::set output_names_set(output_names.begin(), + output_names.end()); + auto result = SplitDims(); + int idx = 0; + for (const auto& in : input_names) { + if (output_names_set.count(in) == 0) { + result.non_related.emplace_back(v, idx); + } else { + result.related.emplace_back(v, idx); + } + idx += 1; + } + return result; +} + +SplitDims SplitReduceOutputDimsIfRelatedWithNonReduceAxis( + const ShardableAxesSignature& signature, const pir::Operation* op) { + const auto& v = op->result(0); + const auto& input_names = signature.inputs[0].axis_names; + const auto& output_names = signature.outputs[0].axis_names; + std::set input_names_set(input_names.begin(), input_names.end()); + auto result = SplitDims(); + int idx = 0; + for (const auto& name : output_names) { + if (input_names_set.count(name) == 0) { + result.non_related.emplace_back(v, idx); + } else { + result.related.emplace_back(v, idx); + } + idx += 1; + } + return result; +} + +bool RelativeJudgePolicy::IsBroadcastEdge( + const std::vector& upstream_out_dims, + const std::vector& downstream_reduce_dims) { + VLOG(4) << "IsBroadcastEdge: upstream_out_dims.size()" + << upstream_out_dims.size(); + VLOG(4) << "IsBroadcastEdge: downstream_reduce_dims.size()" + << downstream_reduce_dims.size(); + + for (const auto& downstream_reduce_dim : downstream_reduce_dims) { + for (const auto& upstream_out_dim : upstream_out_dims) { + VLOG(4) << "upstream_out_dim: " << upstream_out_dim.DebugStr() + << " downstream_reduce_dim: " << downstream_reduce_dim.DebugStr(); + if (IsRelated(upstream_out_dim, downstream_reduce_dim)) { + return false; + } + } + } + + VLOG(4) << "IsBroadcastEdge"; + return true; +} + +bool RelativeJudgePolicy::ReduceTreeGrownCanMerge( + const PatternNodePtr& upstream, const PatternNodePtr& downstream) { + const auto& upstream_tree = + std::get(upstream->stmt_pattern_); + VLOG(4) << "upstream->stmt_pattern_:" + << OpsDebugStr(GetOpsInPattern(upstream_tree)); + const auto& downstream_tree = + std::get(downstream->stmt_pattern_); + VLOG(4) << "downstream->stmt_pattern_" + << OpsDebugStr(GetOpsInPattern(downstream_tree)); + const auto& maybe_downstream_op = GetDownstreamFromCandidate( + upstream_tree.GetRootPattern(), downstream_tree.reduce_patterns_); + int idx = 0; + for (const auto& r_pattern : downstream_tree.reduce_patterns_) { + idx += 1; + VLOG(4) << "downstream_tree.reduce_patterns_" + << "[" << idx << "]" << 
OpsDebugStr(GetOpsInPattern(r_pattern)); + } + if (!maybe_downstream_op.has_value()) { + VLOG(4) << "can't find candidate from patterns. can fuse return false."; + return false; + } + const pir::Value& reduce_out_value = + upstream_tree.GetRootPattern().GetReduceOp()->result(0); + pir::Operation* downstream_reduce_op = + maybe_downstream_op.value().GetReduceOp(); + const auto& split_reduce_dim_result = + SplitReduceInputDimsIfRelatedWithNonReduceAxis( + axes_info_.GetSignature(downstream_reduce_op), downstream_reduce_op); + VLOG(4) << split_reduce_dim_result.DebugStr(); + const auto& upstream_output_dims = GetAllValueDimFromValue(reduce_out_value); + auto res = IsBroadcastEdge(upstream_output_dims, + split_reduce_dim_result.non_related); + VLOG(4) << "ReduceTreeGrownCanMerge: " << res; + return res; +} + +SplitDims RelativeJudgePolicy::SplitDimsWithRelationship( + const std::vector& targets, + const std::vector& related_with) { + VLOG(4) << "SplitDimsWithRelationship"; + auto result = SplitDims(); + bool is_related = false; + for (auto& target_dim : targets) { + is_related = false; + for (auto& related_dim : related_with) { + if (IsRelated(related_dim, target_dim)) is_related = true; + } + if (is_related) { + result.related.push_back(target_dim); + } else { + result.non_related.push_back(target_dim); + } + } + + return result; +} + +bool DimsEqual(const std::vector& first, + const std::vector& second) { + const auto GetDimInfo = + [](const std::vector& dims) -> std::unordered_map { + std::unordered_map result; + for (const auto& dim : dims) { + VLOG(4) << "dim: " << dim.DebugStr(); + size_t value = dim.GetNumericValue(); + VLOG(4) << "value: " << value; + if (result.find(value) == result.end()) { + result[value] = 1; + } else { + result[value] += 1; + } + } + return result; + }; + VLOG(4) << "GetDimInfo"; + const std::unordered_map& first_dims = GetDimInfo(first); + VLOG(4) << "GetDimInfo"; + const std::unordered_map& second_dims = GetDimInfo(second); + if (first_dims.size() != second_dims.size()) return false; + for (const auto& [dim_value, count] : first_dims) { + if (second_dims.find(dim_value) == second_dims.end() || + second_dims.at(dim_value) != count) + return false; + } + return true; +} + +bool RelativeJudgePolicy::ReducePlusTrivialCanMerge( + const PatternNodePtr& upstream, const PatternNodePtr& downstream) { + VLOG(4) << "RT can fuse"; + + // const auto& split_reduce_dims_result = + // SplitReduceInputDimsIfRelatedWithNonReduceAxis( + // axes_info_.GetSignature(upstream->sink_op_), upstream->sink_op_); + + // VLOG(4) << split_reduce_dims_result.DebugStr(); + + // const auto& upstream_reduce_dims = split_reduce_dims_result.non_related; + // const auto& upstream_non_reduce_dims = split_reduce_dims_result.related; + + // TODO(wuzhanfei) fix bug in relation that if has multi path in graph + // test_rms_norm can test + const auto& split_reduce_input_dims_result = + SplitReduceInputDimsIfRelatedWithNonReduceAxis( + axes_info_.GetSignature(upstream->sink_op_), upstream->sink_op_); + VLOG(4) << split_reduce_input_dims_result.DebugStr(); + const auto& upstream_reduce_dims = split_reduce_input_dims_result.non_related; + + const auto& split_reduce_output_dims_result = + SplitReduceOutputDimsIfRelatedWithNonReduceAxis( + axes_info_.GetSignature(upstream->sink_op_), upstream->sink_op_); + VLOG(4) << split_reduce_input_dims_result.DebugStr(); + const auto& upstream_non_reduce_dims = + split_reduce_output_dims_result.related; + // replace codes upside with original design + + const auto& 
split_trivial_dims_result = SplitDimsWithRelationship( + GetAllValueDimFromValue(downstream->sink_op_->result(0)), + upstream_non_reduce_dims); + + VLOG(4) << split_trivial_dims_result.DebugStr(); + + auto res = + DimsEqual(split_trivial_dims_result.non_related, upstream_reduce_dims); + res = res || IsFlattenDimSmaller(upstream, downstream); + VLOG(4) << "ReducePlusTrivialCanMerge: " << res; + return res; +} + +bool RelativeJudgePolicy::IsFlattenDimSmaller( + const PatternNodePtr& upstream, const PatternNodePtr& downstream) { + const auto& split_reduce_dims_result = + SplitReduceInputDimsIfRelatedWithNonReduceAxis( + axes_info_.GetSignature(upstream->sink_op_), upstream->sink_op_); + const auto& upstream_reduce_dims = split_reduce_dims_result.non_related; + const auto& upstream_non_reduce_dims = split_reduce_dims_result.related; + + const auto& split_trivial_dims_result = SplitDimsWithRelationship( + GetAllValueDimFromValue(downstream->sink_op_->result(0)), + upstream_non_reduce_dims); + + VLOG(4) << "IsFlattenDimSmaller: " + << axes_info_.GetSignature(downstream->sink_op_).DebugStr(); + int rank = axes_info_.GetSignature(downstream->sink_op_) + .outputs[0] + .axis_names.size(); + VLOG(4) << "IsFlattenDimSmaller: " << rank << " " + << split_trivial_dims_result.related.size() << " " + << upstream_non_reduce_dims.size(); + bool res = (rank - split_trivial_dims_result.related.size()) <= + upstream_non_reduce_dims.size(); + VLOG(4) << "IsFlattenDimSmaller: " << res; + return res; +} + +bool RelativeJudgePolicy::CanFuse(const PatternNodePtr& upstream, + const PatternNodePtr& downstream) { + if (upstream->IsReduceTree() && downstream->IsTrivial()) { + return ReducePlusTrivialCanMerge(upstream, downstream); + } + if (upstream->IsReduceTree() && downstream->IsReduceTree()) { + return ReduceTreeGrownCanMerge(upstream, downstream); + } + return true; // other case. 
+} + +std::vector RelativeJudgePolicy::GetFakeReduceIterIdx( + const PatternNodePtr& upstream, const PatternNodePtr& downstream) { + if (!upstream->IsReduceTree() || !downstream->IsTrivial()) { + PADDLE_THROW("Illegal Call GetFakeReduceIterIdx"); + } + + const auto& split_reduce_dims_result = + SplitReduceInputDimsIfRelatedWithNonReduceAxis( + axes_info_.GetSignature(upstream->sink_op_), upstream->sink_op_); + + const auto& upstream_reduce_dims = split_reduce_dims_result.non_related; + const auto& upstream_non_reduce_dims = split_reduce_dims_result.related; + + const auto& split_trivial_dims_result = SplitDimsWithRelationship( + GetAllValueDimFromValue(downstream->sink_op_->result(0)), + upstream_non_reduce_dims); + + const auto& trivial_reorder_dims = split_trivial_dims_result.non_related; + + // CHECK(upstream_reduce_dims.size() == trivial_reorder_dims.size() || + // trivial_reorder_dims.size() == 0); + std::unordered_set visited_dims; + std::vector result; + for (auto& reduce_dim : upstream_reduce_dims) { + for (auto& trivial_dim : trivial_reorder_dims) { + if (visited_dims.find(trivial_dim) == visited_dims.end() && + trivial_dim.GetNumericValue() == reduce_dim.GetNumericValue()) { + visited_dims.emplace(trivial_dim); + result.emplace_back(trivial_dim.idx_); + break; + } + } + } + VLOG(4) << "FakeReduceIterIdx: " << cinn::utils::Join(result, ", "); + return result; +} + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h b/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h new file mode 100644 index 0000000000000..e98b68dc893af --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h @@ -0,0 +1,301 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h" +#include "paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.h" +#include "paddle/cinn/frontend/group_cluster/common_utils.h" + +namespace cinn::frontend::group_cluster::policy { + +struct ValueDim { + pir::Value v_; + size_t idx_; + ValueDim(pir::Value v, size_t idx) : v_(v), idx_(idx) {} + ValueDim() = default; + ValueDim(const ValueDim& v) = default; + bool operator==(const ValueDim& v) const { + return (idx_ == v.idx_) && (v_ == v.v_); + } + + size_t GetNumericValue() const { + return v_.type().dyn_cast().dims().at(idx_); + } + + std::string DebugStr() const { + std::ostringstream oss; + oss << "ValueDim: "; + oss << "Index: " << idx_; + oss << ", "; + v_.defining_op()->Print(oss); + return oss.str(); + } +}; + +struct ValueDimHash { + std::size_t operator()(const ValueDim& p) const { + auto h1 = std::hash{}(p.idx_); + auto h2 = std::hash{}(p.v_); + // Mainly for demonstration purposes, i.e. works but is overly simple + // In the real world, use sth. 
like boost.hash_combine + return h1 ^ (h2 << 1); + } +}; + +using ValueDimRelation = + std::unordered_map, + ValueDimHash>; +// ValueDimRelation[in][out] = True; means f(out) = in is related. + +static std::vector GetAllValueDimFromValue(const pir::Value& v) { + std::vector value_dims; + size_t rank = GetRank(v); + for (size_t i = 0; i < rank; ++i) { + value_dims.emplace_back(v, i); + } + return value_dims; +} + +static std::vector GetAllInputValueDim(pir::Operation* op) { + std::vector value_dims; + for (const auto& v : op->operands()) { + value_dims = ConcatVector(value_dims, GetAllValueDimFromValue(v.source())); + } + return value_dims; +} + +static std::vector GetAllOutputValueDim(pir::Operation* op) { + std::vector value_dims; + for (const auto& v : op->results()) { + value_dims = ConcatVector(value_dims, GetAllValueDimFromValue(v)); + } + return value_dims; +} + +static ValueDimRelation CreateOpRelativenessForElementWise(pir::Operation* op) { + ValueDimRelation res; + for (const auto& v : op->operands()) { + const auto& value_dims = GetAllValueDimFromValue(v.source()); + const auto& out_value_dims = GetAllOutputValueDim(op); + CHECK_EQ(value_dims.size(), out_value_dims.size()); + for (size_t i = 0; i < value_dims.size(); ++i) { + res[value_dims[i]][out_value_dims[i]] = true; + } + } + return res; +} + +static std::vector> GetNonBroadCastDims( + pir::Operation* op) { + std::vector> res; + const auto* shape_analysis = + &pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + + const auto& broad_cast_value = GetBroadcastOpInputOuputValue(op); + CHECK(broad_cast_value.has_value()); + + const auto& [input_value, output_value] = broad_cast_value.value(); + const int input_rank = GetRank(input_value); + const int output_rank = GetRank(output_value); + CHECK_GE(output_rank, input_rank); + + // Compare axis one by one, from back to front. 
+ // The rule of broadcasting: + // https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/tensor_cn.html#id7 + for (int i = 1; i <= input_rank; ++i) { + int input_axis = input_rank - i; + int output_axis = output_rank - i; + if (input_axis < 0 || output_axis < 0) break; + if (shape_analysis->IsProductEqual( + input_value, {input_axis}, output_value, {output_axis})) { + res.emplace_back(input_axis, output_axis); + } + } + + return res; +} + +static ValueDimRelation CreateOpRelativenessForBroadcast(pir::Operation* op) { + ValueDimRelation res; + const auto& in_value = op->operand(0).source(); + const auto& out_value = op->result(0); + for (const auto& t : GetNonBroadCastDims(op)) { + res[ValueDim(in_value, t.first)][ValueDim(out_value, t.second)] = true; + } + return res; +} + +static ValueDimRelation CreateOpRelativenessForDefault(pir::Operation* op) { + ValueDimRelation res; + for (const auto& out_dim : GetAllOutputValueDim(op)) { + for (const auto& in_dim : GetAllInputValueDim(op)) { + res[in_dim][out_dim] = true; + } + } + return res; +} + +static ValueDimRelation CreateOpRelativenessForReduce(pir::Operation* op) { + const auto& reduce_axis_idx = GetReduceAxisIdx(op); + ValueDimRelation res; + const size_t input_rank = GetRank(op->operand_source(0)); + int out_idx = 0; + bool keep_dim = GetReduceOpKeepDims(op); + for (int i = 0; i < input_rank; i++) { + if (std::find(reduce_axis_idx.begin(), reduce_axis_idx.end(), i) != + reduce_axis_idx.end()) { + res[ValueDim(op->operand_source(0), i)] + [ValueDim(op->result(0), out_idx)] = true; + out_idx += 1; + } else { + out_idx += keep_dim; + } + } + return res; +} + +static std::optional CreateOpRelativenessForSpecialOps( + pir::Operation* op) { + if (op->name() == "cinn_op.reshape") { + // Special Elementwise. + return CreateOpRelativenessForDefault(op); + } + if (op->name() == "pd_op.reshape") { + // Special Elementwise. 
+ return CreateOpRelativenessForDefault(op); + } + if (op->name() == "cinn_op.generate_shape") { + return CreateOpRelativenessForDefault(op); + } + if (op->name() == "cinn_op.yield_store") { + return CreateOpRelativenessForDefault(op); + } + return {}; +} + +static ValueDimRelation GetSingleOpRelation(pir::Operation* op) { + VLOG(4) << "GetSingleOpRelation for " << op->name(); + const auto& special_result = CreateOpRelativenessForSpecialOps(op); + if (special_result != std::nullopt) { + return special_result.value(); + } + + CHECK(op->num_results() == 1) + << "Now we do not support op with multi outputs: " << op->name(); + const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); + ValueDimRelation result; + if (kind == hlir::framework::kReduction) { + result = CreateOpRelativenessForReduce(op); + } else if (kind == hlir::framework::kElementWise) { + result = CreateOpRelativenessForElementWise(op); + } else if (kind == hlir::framework::kBroadcast) { + result = CreateOpRelativenessForBroadcast(op); + } else { + result = CreateOpRelativenessForDefault(op); + } + return result; +} + +static std::vector> FlattenRelation( + const ValueDimRelation& axes_relation) { + std::vector> res; + for (const auto& in_dim_pair : axes_relation) { + for (const auto& out_dim_pair : in_dim_pair.second) { + res.emplace_back(in_dim_pair.first, out_dim_pair.first); + } + } + return res; +} + +static ValueDimRelation AnalysisIndexExprRelation( + const std::vector& ops) { + ValueDimRelation res; + + for (size_t i = ops.size(); i >= 1; --i) { + pir::Operation* op = ops[i - 1]; + if (op->name() == "cf.yield") continue; + + const auto& value_dim_relation = GetSingleOpRelation(op); + for (const auto& in_out_pair : FlattenRelation(value_dim_relation)) { + for (const auto& out_relation : res[in_out_pair.second]) { + res[in_out_pair.first][out_relation.first] = true; + } + res[in_out_pair.first][in_out_pair.second] = true; + } + } + return res; +} + +struct SplitDims { + std::vector related; + std::vector non_related; + + std::string DebugStr() const { + std::stringstream ss; + ss << "SplitDims:\nrelated:\n"; + for (const auto& dim : related) { + ss << dim.DebugStr() << "\n"; + } + ss << "non_related:\n"; + for (const auto& dim : non_related) { + ss << dim.DebugStr() << "\n"; + } + return ss.str(); + } +}; + +class RelativeJudgePolicy final : public Policy { + public: + RelativeJudgePolicy(const std::vector& ops, + const pir::ShapeConstraintIRAnalysis* shape_analysis) + : axes_info_(ops, shape_analysis) { + VLOG(4) << "[relative_judge_policy] Start AnalysisIndexExprRelation."; + index_expr_map_ = AnalysisIndexExprRelation(ops); + VLOG(4) << "[relative_judge_policy] End AnalysisIndexExprRelation."; + } + bool CanFuse(const PatternNodePtr& upstream, + const PatternNodePtr& downstream) override; + + std::string Name() { return "RelativeJudgePolicy"; } + + std::vector GetFakeReduceIterIdx( + const PatternNodePtr& upstream, + const PatternNodePtr& downstream) override; + + bool IsRelated(ValueDim in, ValueDim out) { + return index_expr_map_[in].count(out) == 1; + } + + private: + ValueDimRelation index_expr_map_; + ShardableAxesInfoManager axes_info_; + bool ReduceTreeGrownCanMerge(const PatternNodePtr&, const PatternNodePtr&); + bool IsFlattenDimSmaller(const PatternNodePtr& upstream, + const PatternNodePtr& downstream); + bool ReducePlusTrivialCanMerge(const PatternNodePtr&, const PatternNodePtr&); + SplitDims SplitDimsWithRelationship( + const std::vector& targets, + const std::vector& related_with); + std::optional 
GetDownstreamFromCandidate( + const ReducePattern& upstream, + const std::vector& candidates); + bool IsDownstreamStmtDependReduceOp(pir::Operation* reduce, + const StmtPattern& downstream); + bool IsBroadcastEdge(const std::vector& upstream_out_dims, + const std::vector&); +}; + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/CMakeLists.txt b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/CMakeLists.txt new file mode 100644 index 0000000000000..8d3f64fa5bc96 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/CMakeLists.txt @@ -0,0 +1,2 @@ +gather_srcs(group_cluster_src SRCS shardable_axes_base.cc + shardable_axes_policy.cc) diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.cc b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.cc new file mode 100644 index 0000000000000..f14f9b3051de2 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.cc @@ -0,0 +1,306 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
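AnalysisIndexExprRelation above walks the ops from last to first and, for every per-op input-to-output pair, also records every target already reachable from that output, so the final map answers IsRelated queries transitively. A hedged sketch of that composition step, with plain strings standing in for ValueDim:

```cpp
#include <string>
#include <unordered_map>

using Relation =
    std::unordered_map<std::string, std::unordered_map<std::string, bool>>;

// Extend the accumulated relation with one direct edge in -> out.
// Because ops are visited back to front, (*res)[out] already holds everything
// reachable from `out`, so the transitive closure is built incrementally.
void Accumulate(Relation* res, const std::string& in, const std::string& out) {
  for (const auto& [final_dim, related] : (*res)[out]) {
    if (related) (*res)[in][final_dim] = true;  // compose: in -> out -> final_dim
  }
  (*res)[in][out] = true;                        // keep the direct edge as well
}
```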
+ +#pragma once +#include "paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.h" +#include "paddle/cinn/frontend/group_cluster/common_utils.h" + +namespace cinn::frontend::group_cluster::policy { + +ShardableAxes ShardableAxesInfoManager::ReplaceShardableAxesWithRootName( + const ShardableAxes& axes) { + std::vector names; + for (auto name : axes.axis_names) { + names.push_back(name_union_[name]); + } + return ShardableAxes(names); +} + +ShardableAxesSignature ShardableAxesInfoManager::GetSignature( + pir::Operation* op) { + return op_signature_map_[op]; + // TODO(baizhou) fix broadcast signature and enable here + // auto result = ShardableAxesSignature(); + // auto origin_sig = op_signature_map_[op]; + // for (const auto& axes : origin_sig.inputs) { + // result.inputs.emplace_back(ReplaceShardableAxesWithRootName(axes)); + // } + // for (const auto& axes : origin_sig.outputs) { + // result.outputs.emplace_back(ReplaceShardableAxesWithRootName(axes)); + // } + // return result; +} + +ShardableAxes ShardableAxesInfoManager::GetAxes(pir::Value value) { + return ReplaceShardableAxesWithRootName(value_axes_map_[value]); +} + +std::string ShardableAxesInfoManager::GetUniqueName() { + static std::atomic counter = 0; + counter += 1; + return "D" + std::to_string(counter); +} + +std::vector CreateNewNamesWithRank(int64_t rank) { + auto result = std::vector(); + for (int64_t i = 0; i < rank; i++) { + result.emplace_back(ShardableAxesInfoManager::GetUniqueName()); + } + return result; +} + +ShardableAxesSignature CreateDefaultSignature(pir::Operation* op) { + ShardableAxesSignature result = ShardableAxesSignature(); + for (int i = 0; i < op->num_operands(); ++i) { + result.inputs.emplace_back( + CreateNewNamesWithRank(GetRank(op->operand_source(i)))); + } + for (int i = 0; i < op->num_results(); ++i) { + result.outputs.emplace_back(CreateNewNamesWithRank(GetRank(op->result(i)))); + } + return result; +} + +std::optional CreateSignatureForSpecialOps( + pir::Operation* op) { + if (op->isa()) { + return CreateDefaultSignature(op); + } + if (op->name() == "cinn_op.generate_shape") { + return CreateDefaultSignature(op); + } + if (op->name() == "cinn_op.yield_store") { + return CreateDefaultSignature(op); + } + if (op->name() == "cinn_op.reshape") { + return CreateDefaultSignature(op); + } + if (op->name() == "pd_op.reshape") { + return CreateDefaultSignature(op); + } + return std::nullopt; +} + +ShardableAxesSignature CreateSignatureForReduce(pir::Operation* reduce_op) { + CHECK_EQ(reduce_op->num_operands(), 1); + CHECK_EQ(reduce_op->num_results(), 1); + ShardableAxesSignature result = ShardableAxesSignature(); + const size_t input_rank = GetRank(reduce_op->operand_source(0)); + auto input_axes = CreateNewNamesWithRank(input_rank); + + const auto& reduce_axis_idx = GetReduceAxisIdx(reduce_op); + bool keep_dim = GetReduceOpKeepDims(reduce_op); + auto output_axes = std::vector(); + + for (int i = 0; i < input_rank; i++) { + if (std::find(reduce_axis_idx.begin(), reduce_axis_idx.end(), i) != + reduce_axis_idx.end()) { + if (keep_dim) { + output_axes.emplace_back(ShardableAxesInfoManager::GetUniqueName()); + } // else do nothing + } else { + output_axes.emplace_back(input_axes[i]); + } + } + + result.inputs.emplace_back(input_axes); + result.outputs.emplace_back(output_axes); + + return result; +} + +ShardableAxesSignature CreateSignatureForElementWise(pir::Operation* op) { + ShardableAxesSignature result = ShardableAxesSignature(); + + int64_t rank = 
GetRank(op->result(0)); + auto same_axes = CreateNewNamesWithRank(rank); + + for (int i = 0; i < op->num_operands(); ++i) { + CHECK(rank == GetRank(op->operand_source(i))); + result.inputs.emplace_back(same_axes); + } + for (int i = 0; i < op->num_results(); ++i) { + CHECK(rank == GetRank(op->result(i))); + result.outputs.emplace_back(same_axes); + } + return result; +} + +ShardableAxesSignature CreateSignatureForBroadcast( + pir::Operation* op, const pir::ShapeConstraintIRAnalysis* shape_analysis) { + ShardableAxesSignature result = ShardableAxesSignature(); + + const auto& broad_cast_value = GetBroadcastOpInputOuputValue(op); + CHECK(broad_cast_value.has_value()); + + const auto& [input_value, output_value] = broad_cast_value.value(); + const int input_rank = GetRank(input_value); + const int output_rank = GetRank(output_value); + CHECK_GE(output_rank, input_rank); + + // Create axes for operands. For expand op, the second operand is the shape of + // output. + for (int i = 0; i < op->num_operands(); ++i) { + result.inputs.emplace_back( + CreateNewNamesWithRank(GetRank(op->operand_source(i)))); + } + + // Create output axes. Compare axis one by one, from back to front. + // The rule of broadcasting: + // https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/tensor_cn.html#id7 + const auto& input_axis_names = result.inputs[0].axis_names; + std::vector output_axis_names; + for (int i = 1; i <= output_rank; ++i) { + int input_axis = input_rank - i; + int output_axis = output_rank - i; + if ((input_axis >= 0) && + shape_analysis->IsProductEqual( + input_value, {input_axis}, output_value, {output_axis})) { + output_axis_names.emplace_back(input_axis_names[input_axis]); + } else { + output_axis_names.emplace_back(ShardableAxesInfoManager::GetUniqueName()); + } + } + std::reverse(output_axis_names.begin(), output_axis_names.end()); + result.outputs.emplace_back(ShardableAxes(output_axis_names)); + + return result; +} + +ShardableAxesSignature ShardableAxesInfoManager::CreateShardableSignature( + pir::Operation* op) { + auto special_result = CreateSignatureForSpecialOps(op); + if (special_result != std::nullopt) { + VLOG(4) << "[ShardableAxesInfoManager] Create Shardable Axes Signature : \n" + << op->name() << " : " << special_result.value().DebugStr(); + return special_result.value(); + } + + CHECK(op->num_results() == 1) + << "Now we do not support op with multi outputs: " << op->name(); + ShardableAxesSignature result; + const hlir::framework::OpPatternKind kind = GetOpPatternKind(op); + if (kind == hlir::framework::kReduction) { + result = CreateSignatureForReduce(op); + } else if (kind == hlir::framework::kElementWise) { + result = CreateSignatureForElementWise(op); + } else if (kind == hlir::framework::kBroadcast) { + result = CreateSignatureForBroadcast(op, shape_analysis_); + } else { + result = CreateDefaultSignature(op); + } + VLOG(4) << "[ShardableAxesInfoManager] Create Shardable Axes Signature : \n" + << op->name() << " : " << result.DebugStr(); + return result; +} + +ShardableAxesInfoManager::ShardableAxesInfoManager( + const std::vector& ops, + const pir::ShapeConstraintIRAnalysis* shape_analysis) + : ops_(ops), shape_analysis_(shape_analysis) { + for (const auto& op : ops) { + if (op->name() == "cf.yield") continue; + op_signature_map_[op] = CreateShardableSignature(op); + } + + const auto FindRoot = [&](std::string non_root) { + std::string result = non_root; + while (name_union_[result] != result) { + result = name_union_[result]; + } + return result; + }; + + 
const auto CombineAxes = [&](const ShardableAxes& root, + const ShardableAxes& non_root) { + CHECK_EQ(root.axis_names.size(), non_root.axis_names.size()); + for (int i = 0; i < non_root.axis_names.size(); i++) { + name_union_[non_root.axis_names[i]] = FindRoot(root.axis_names[i]); + } + }; + + for (const auto& [op, axes_signature] : op_signature_map_) { + for (int i = 0; i < op->num_operands(); ++i) { + auto value = op->operand_source(i); + auto axes = axes_signature.inputs[i]; + if (value_axes_map_.find(value) == value_axes_map_.end()) { + value_axes_map_[value] = axes; + for (auto& axis_name : axes.axis_names) { + name_union_[axis_name] = axis_name; + } + } else { + CombineAxes(value_axes_map_[value], axes); + } + } + for (int i = 0; i < op->num_results(); ++i) { + auto value = op->result(i); + auto axes = axes_signature.outputs[i]; + if (value_axes_map_.find(value) == value_axes_map_.end()) { + value_axes_map_[value] = axes; + for (auto& axis_name : axes.axis_names) { + name_union_[axis_name] = axis_name; + } + } else { + CombineAxes(value_axes_map_[value], axes); + } + } + } + + VLOG(4) << NameUnionDebugStr(); +} + +std::string ShardableAxes::DebugStr() const { + std::stringstream ss; + for (const auto& name : axis_names) { + ss << name << ", "; + } + return ss.str(); +} + +std::string ShardableAxesSignature::DebugStr() const { + std::stringstream ss; + ss << "ShardableAxes Signature:\n"; + for (int i = 0; i < inputs.size(); i++) { + ss << "input " << i << ": " << inputs[i].DebugStr() << "\n"; + } + for (int i = 0; i < outputs.size(); i++) { + ss << "output " << i << ": " << outputs[i].DebugStr() << "\n"; + } + return ss.str(); +} + +std::string ShardableAxesInfoManager::NameUnionDebugStr() const { + std::stringstream ss; + ss << "[ShardableAxesInfoManager] NameUnion :\n"; + + std::unordered_map> root_to_sons; + for (const auto& [non_root, root] : name_union_) { + if (root_to_sons.find(root) == root_to_sons.end()) { + root_to_sons[root] = std::vector{non_root}; + } else { + root_to_sons[root].push_back(non_root); + } + } + for (const auto& [root, sons] : root_to_sons) { + ss << "Root " << root << ": "; + for (const auto& son : sons) { + ss << son << ", "; + } + ss << "\n"; + } + + return ss.str(); +} + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.h b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.h new file mode 100644 index 0000000000000..b2795f944f938 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.h @@ -0,0 +1,55 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
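The FindRoot / CombineAxes pair in the constructor above is a map-based union-find over axis names (no path compression). A minimal sketch with invented axis names:

```cpp
#include <map>
#include <string>
#include <vector>

std::map<std::string, std::string> name_union = {
    {"D1", "D1"}, {"D2", "D1"}, {"D3", "D3"}, {"D4", "D3"}};

// Follow parent links until a name maps to itself.
std::string FindRoot(std::string name) {
  while (name_union[name] != name) name = name_union[name];
  return name;
}

// Union two axis lists of equal rank: every axis of the non-root signature is
// attached to the root of the corresponding axis of the root signature.
void Combine(const std::vector<std::string>& root,
             const std::vector<std::string>& non_root) {
  for (size_t i = 0; i < non_root.size(); ++i) {
    name_union[non_root[i]] = FindRoot(root[i]);
  }
}
```

After all operands and results have been combined, every axis name can be canonicalized with FindRoot, which is what ReplaceShardableAxesWithRootName relies on.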
+ +#pragma once + +#include "paddle/cinn/frontend/group_cluster/common_utils.h" + +namespace cinn::frontend::group_cluster::policy { + +struct ShardableAxes { + ShardableAxes() : axis_names({}) {} + explicit ShardableAxes(const std::vector& names) + : axis_names(names) {} + std::vector axis_names; + std::string DebugStr() const; +}; + +struct ShardableAxesSignature { + std::vector inputs; + std::vector outputs; + std::string DebugStr() const; +}; + +struct ShardableAxesInfoManager { + ShardableAxesInfoManager( + const std::vector& ops, + const pir::ShapeConstraintIRAnalysis* shape_analysis); + ShardableAxesSignature GetSignature(pir::Operation* op); + ShardableAxes GetAxes(pir::Value value); + ShardableAxesSignature CreateShardableSignature(pir::Operation* op); + ShardableAxes ReplaceShardableAxesWithRootName(const ShardableAxes& axes); + static std::string GetUniqueName(); + std::string NameUnionDebugStr() const; + + private: + const std::vector& ops_; + const pir::ShapeConstraintIRAnalysis* shape_analysis_; + + std::unordered_map op_signature_map_; + std::unordered_map value_axes_map_; + std::unordered_map name_union_; +}; + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_policy.cc b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_policy.cc new file mode 100644 index 0000000000000..17606d0cf771c --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_policy.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
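A hedged usage sketch for the structs declared above (it only compiles inside the Paddle tree): the signature CreateSignatureForReduce would produce for a rank-3 reduce over axis 1 with keep_dim = false. The D* names stand for whatever GetUniqueName happens to return.

```cpp
#include <string>
#include <vector>

#include "paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.h"

namespace policy = cinn::frontend::group_cluster::policy;

policy::ShardableAxesSignature MakeReduceSignatureExample() {
  policy::ShardableAxesSignature sig;
  // Non-reduced axes keep their input names; the reduced axis (D2) does not
  // survive to the output because keep_dim is false.
  sig.inputs.emplace_back(std::vector<std::string>{"D1", "D2", "D3"});
  sig.outputs.emplace_back(std::vector<std::string>{"D1", "D3"});
  return sig;  // DebugStr(): "input 0: D1, D2, D3," / "output 0: D1, D3,"
}
```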
+ +#include "paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_policy.h" + +namespace cinn::frontend::group_cluster::policy { + +bool ShardableAxesRRFusePolicy::IsDownstreamStmtDependReduceOp( + pir::Operation* reduce, const StmtPattern& downstream) { + const auto& values = GetPatternInputValues(downstream); + for (const auto& value : reduce->results()) { + if (std::find(values.begin(), values.end(), value) != values.end()) { + return true; + } + } + return false; +} + +std::optional +ShardableAxesRRFusePolicy::GetDownstreamFromCandidate( + const ReducePattern& upstream, + const std::vector& candidates) { + pir::Operation* reduce = upstream.GetReduceOp(); + for (const auto& candidate : candidates) { + if (IsDownstreamStmtDependReduceOp(reduce, candidate)) { + return candidate; + } + } + return {}; +} + +static std::set GetReduceAxesName( + const ShardableAxesSignature& signature) { + const auto& input_names = signature.inputs[0].axis_names; + const auto& output_names = signature.outputs[0].axis_names; + std::set res(input_names.begin(), input_names.end()); + for (const auto& n : output_names) { + res.erase(n); + } + return res; +} + +bool ShardableAxesRRFusePolicy::ReduceTreeGrownCanMerge( + const PatternNodePtr& upstream, const PatternNodePtr& downstream) { + if (!upstream->IsReduceTree() || !downstream->IsReduceTree()) { + return false; + } + const auto& upstream_tree = + std::get(upstream->stmt_pattern_); + const auto& downstream_tree = + std::get(downstream->stmt_pattern_); + const auto& maybe_downstream_op = GetDownstreamFromCandidate( + upstream_tree.GetRootPattern(), downstream_tree.reduce_patterns_); + if (!maybe_downstream_op.has_value()) { + return false; + } + const pir::Value& reduce_out_value = + upstream_tree.GetRootPattern().GetReduceOp()->result(0); + pir::Operation* downstream_reduce_op = + maybe_downstream_op.value().GetReduceOp(); + const auto& reduce_names = + GetReduceAxesName(axes_info_.GetSignature(downstream_reduce_op)); + for (const auto& n : + axes_info_.GetAxes(downstream_reduce_op->result(0)).axis_names) { + if (reduce_names.count(n) > 0) { + // not meeting the BroadcastEdge condition. + return false; + } + } + return true; +} + +bool ShardableAxesRRFusePolicy::CanFuse(const PatternNodePtr& upstream, + const PatternNodePtr& downstream) { + // TODO(wuzhanfei) shardable axes policy + return ReduceTreeGrownCanMerge(upstream, downstream); +} + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_policy.h b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_policy.h new file mode 100644 index 0000000000000..1917d2f5af4df --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_policy.h @@ -0,0 +1,41 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h" +#include "paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_base.h" + +namespace cinn::frontend::group_cluster::policy { + +class ShardableAxesRRFusePolicy final : public Policy { + public: + ShardableAxesRRFusePolicy( + const std::vector& ops, // NOLINT + const pir::ShapeConstraintIRAnalysis* shape_analysis) // NOLINT + : axes_info_(ops, shape_analysis) {} + bool CanFuse(const PatternNodePtr& upstream, + const PatternNodePtr& downstream) override; + std::string Name() { return "ShardableAxesRRFusePolicy"; } + + private: + bool ReduceTreeGrownCanMerge(const PatternNodePtr&, const PatternNodePtr&); + std::optional GetDownstreamFromCandidate( + const ReducePattern& upstream, + const std::vector& candidates); + ShardableAxesInfoManager axes_info_; + bool IsDownstreamStmtDependReduceOp(pir::Operation* reduce, + const StmtPattern& downstream); +}; + +} // namespace cinn::frontend::group_cluster::policy diff --git a/paddle/cinn/frontend/group_cluster/common_utils.cc b/paddle/cinn/frontend/group_cluster/common_utils.cc new file mode 100644 index 0000000000000..36280069aca18 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/common_utils.cc @@ -0,0 +1,199 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
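A hedged sketch of how ShardableAxesRRFusePolicy could be wired into the PolicyManager introduced earlier, mirroring the wiring group_cluster.h does for RelativeJudgePolicy. It assumes the caller already has the op list and shape analysis, and it only builds inside the Paddle tree.

```cpp
#include <memory>
#include <vector>

#include "paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h"
#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h"
#include "paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_policy.h"

namespace policy = cinn::frontend::group_cluster::policy;

policy::PolicyManager MakeRRPolicyManager(
    const std::vector<pir::Operation*>& ops,
    const pir::ShapeConstraintIRAnalysis* shape_analysis) {
  auto rr = std::make_shared<policy::ShardableAxesRRFusePolicy>(ops, shape_analysis);
  auto topo = std::make_shared<policy::GeneralTopoPolicy>();
  // Both policies must agree before two PatternNodes are fused.
  return policy::PolicyManager({rr, topo});
}
```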
+ +#include "paddle/cinn/frontend/group_cluster/common_utils.h" + +namespace cinn::frontend::group_cluster { + +OpPatternKind GetOpPatternKind(const ::pir::Operation* op) { + return hlir::framework::pir::CompatibleInfo::OpKind(*op); +} + +size_t GetRank(pir::Value value) { + return value.type().dyn_cast().dims().size(); +} + +std::vector GetReduceAxisIdx(pir::Operation* reduce_op) { + const size_t input_rank = GetRank(reduce_op->operand_source(0)); + const auto& attr_val = reduce_op->attributes().at("dim"); + CHECK(attr_val.isa<::pir::ArrayAttribute>()); + const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>(); + std::vector reduce_axis_idx; + for (int i = 0; i < axis_attr.size(); ++i) { + int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data(); + if (axis < 0) { + axis += input_rank; + } + CHECK_GE(axis, 0); + CHECK_LT(axis, input_rank); + reduce_axis_idx.push_back(axis); + } + VLOG(4) << "GetReduceAxisIdx: " << utils::Join(reduce_axis_idx, ","); + return reduce_axis_idx; +} + +bool GetReduceOpKeepDims(pir::Operation* reduce_op) { + const auto& attr_val = reduce_op->attributes().at("keep_dim"); + CHECK(attr_val.isa<::pir::BoolAttribute>()); + return attr_val.dyn_cast<::pir::BoolAttribute>().data(); +} + +std::string GetPatternName(const StmtPattern& s) { + return std::visit([](const auto& impl) { return impl.name(); }, s); +} + +std::string OpsDebugStr(std::vector ops) { + std::stringstream ss; + pir::IrPrinter printer(ss); + for (const auto* op : ops) { + printer.PrintOperation(const_cast(op)); + ss << "\n"; + } + return ss.str(); +} + +std::optional> GetBroadcastOpInputOuputValue( + pir::Operation* op) { + auto* mut_op = const_cast(op); + if (op->isa()) { + auto expand_op = mut_op->dyn_cast(); + return std::make_pair(expand_op.x(), expand_op.out()); + } else if (op->isa()) { + auto broadcast_op = mut_op->dyn_cast(); + return std::make_pair(broadcast_op.x(), broadcast_op.out()); + } else { + CHECK(false) << "Unsupported broadcast op: " << op->name(); + } + return std::nullopt; +} +} // namespace cinn::frontend::group_cluster + +namespace cinn::frontend::group_cluster { + +bool IsTrivialPattern(const StmtPattern& pattern) { + return std::holds_alternative(pattern); +} + +bool IsReducePattern(const StmtPattern& pattern) { + return std::holds_alternative(pattern); +} + +bool IsReduceTreePattern(const StmtPattern& pattern) { + return std::holds_alternative(pattern); +} + +bool IsOpsDependents(const StmtPattern& pattern) { + return std::holds_alternative(pattern); +} + +bool IsUnsupportPattern(const StmtPattern& pattern) { + return std::holds_alternative(pattern); +} + +bool IsReduceTrivialPattern(const StmtPattern& pattern) { + return std::holds_alternative(pattern); +} + +std::unordered_set GetPatternInputValuesIncludeInner( + const StmtPattern& A) { + std::unordered_set result; + for (const auto& op : GetOpsInPattern(A)) { + for (const auto& value : op->operands()) { + result.insert(value.source()); + } + } + return result; +} + +std::unordered_set GetPatternOutputValuesIncludedInner( + const StmtPattern& A) { + std::unordered_set result; + for (const auto& op : GetOpsInPattern(A)) { + for (const auto& value : op->results()) { + result.insert(value); + } + } + return result; +} + +std::unordered_set GetPatternInputValues(const StmtPattern& A) { + auto all_input_values = GetPatternInputValuesIncludeInner(A); + for (const auto& value : GetPatternOutputValuesIncludedInner(A)) { + all_input_values.erase(value); + } + VLOG(4) << "GetPatternInputValues: " << 
all_input_values.size(); + return all_input_values; +} + +std::vector GetOpsInPattern(const StmtPattern& pattern) { + return std::visit([](const auto& impl) { return impl.ops(); }, pattern); +} + +std::string StmtPatternDebugStr(const StmtPattern& stmt) { + std::stringstream ss; + auto all_ops = GetOpsInPattern(stmt); + ss << "StmtPattern, size " << all_ops.size() << " :\n"; + ss << OpsDebugStr(all_ops); + return ss.str(); +} + +StmtPattern MergePattern(const StmtPattern& first, const StmtPattern& second) { + std::vector ops = + MergeVector(GetOpsInPattern(first), GetOpsInPattern(second)); + if (IsUnsupportPattern(first) || IsUnsupportPattern(second)) { + return UnsupportPattern(ops); + } else if (IsReduceTreePattern(first) && IsReduceTreePattern(second)) { + const auto& merged = + ConcatVector(std::get(first).reduce_patterns_, + std::get(second).reduce_patterns_); + return ReduceTreePattern( + merged, std::get(second).GetRootPattern()); + } else if (IsReduceTreePattern(first) && IsTrivialPattern(second)) { + return ReduceTreePlusTrivialPattern(std::get(first), + std::get(second)); + } else if (IsTrivialPattern(first) && IsReducePattern(second)) { + return ReducePattern(ops); + } else if (IsTrivialPattern(first) && IsTrivialPattern(second)) { + return TrivialPattern(ops); + } else if (IsHorizontalFusionPattern(first) && + IsHorizontalFusionPattern(second)) { + return HorizontalFusionPattern(ops); + } else { + // Not Implementation. + CHECK(false) << "Found not support merge!"; + } +} + +bool IsHorizontalFusionPattern(const StmtPattern& pattern) { + return std::holds_alternative(pattern); +} + +StmtPattern ConvertToStmtPattern(pir::Operation* op) { + const auto& kind = GetOpPatternKind(op); + if (kind == hlir::framework::kReduction) { + return ReducePattern({op}); + } else if (kind == hlir::framework::kElementWise || + kind == hlir::framework::kBroadcast || + kind == hlir::framework::kInjective) { + return TrivialPattern({op}); + } else { + return UnsupportPattern({op}); + } +} + +ReducePattern ToReducePattern(const StmtPattern& second) { + return std::get(second); +} + +} // namespace cinn::frontend::group_cluster diff --git a/paddle/cinn/frontend/group_cluster/common_utils.h b/paddle/cinn/frontend/group_cluster/common_utils.h new file mode 100644 index 0000000000000..2430facb703e5 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/common_utils.h @@ -0,0 +1,121 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
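The StmtPattern helpers above lean on std::variant: classification via std::holds_alternative and generic access via std::visit. A self-contained toy (the pattern structs are simplified stand-ins, not the real StmtPattern alternatives):

```cpp
#include <iostream>
#include <string>
#include <variant>
#include <vector>

struct Trivial {
  std::vector<std::string> ops_;
  std::vector<std::string> ops() const { return ops_; }
};
struct Reduce {
  std::vector<std::string> ops_;
  std::vector<std::string> ops() const { return ops_; }
};
using Pattern = std::variant<Trivial, Reduce>;

bool IsReduce(const Pattern& p) { return std::holds_alternative<Reduce>(p); }

// Generic accessor: works for every alternative that exposes ops().
std::vector<std::string> OpsIn(const Pattern& p) {
  return std::visit([](const auto& impl) { return impl.ops(); }, p);
}

int main() {
  Pattern p = Reduce{{"pd_op.sum"}};
  std::cout << IsReduce(p) << " " << OpsIn(p).size() << "\n";  // prints: 1 1
}
```

MergePattern and ConvertToStmtPattern follow the same shape: classify with holds_alternative, then construct the appropriate alternative from the merged op list.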
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "glog/logging.h" + +#include "paddle/cinn/frontend/group_cluster/pattern.h" + +#include "paddle/cinn/common/bfs_walker.h" +#include "paddle/cinn/common/topo_walker.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/op.h" +#include "paddle/cinn/utils/string.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace cinn::frontend::group_cluster { + +using OpPatternKind = cinn::hlir::framework::OpPatternKind; + +OpPatternKind GetOpPatternKind(const ::pir::Operation* op); +size_t GetRank(pir::Value value); +std::vector GetReduceAxisIdx(pir::Operation* reduce_op); +bool GetReduceOpKeepDims(pir::Operation* reduce_op); +std::string OpsDebugStr(std::vector ops); +std::optional> GetBroadcastOpInputOuputValue( + pir::Operation* op); +} // namespace cinn::frontend::group_cluster + +namespace cinn::frontend::group_cluster { + +bool IsTrivialPattern(const StmtPattern& pattern); +bool IsHorizontalFusionPattern(const StmtPattern& pattern); +bool IsReducePattern(const StmtPattern& pattern); +bool IsReduceTreePattern(const StmtPattern& pattern); +bool IsUnsupportPattern(const StmtPattern& pattern); +bool IsReduceTrivialPattern(const StmtPattern& pattern); + +template +void RemoveFromVector(std::vector* vec, T item) { + auto iter = std::find(vec->begin(), vec->end(), item); + if (iter != vec->end()) { + vec->erase(iter); + } +} + +template +std::vector ConcatVector(const std::vector& first, + const std::vector& second) { + std::vector result = first; + result.insert(result.end(), second.begin(), second.end()); + return result; +} + +template +std::vector FilterVector(const std::vector& first, const F& func) { + std::vector result; + for (const auto& i : first) { + if (func(i)) { + result.push_back(i); + } + } + return result; +} + +template +std::set ToSet(const std::vector& input) { + std::set result(input.begin(), input.end()); + return result; +} + +template +bool IsAnyFirstInSecond(const std::vector& first, + const std::vector& second) { + const auto& second_set = ToSet(second); + for (const auto& ele : first) { + if (second_set.count(ele)) { + return true; + } + } + return false; +} + +template +std::vector UniqueVectorBySet(const std::vector& v) { + std::set unique(v.begin(), v.end()); + return std::vector(unique.begin(), unique.end()); +} + +std::vector GetOpsInPattern(const StmtPattern& pattern); +std::string StmtPatternDebugStr(const StmtPattern& pattern); +StmtPattern MergePattern(const StmtPattern& first, const StmtPattern& second); +ReducePattern ToReducePattern(const StmtPattern& second); +std::string GetPatternName(const StmtPattern& s); + +StmtPattern ConvertToStmtPattern(pir::Operation* op); +std::unordered_set GetPatternInputValues(const StmtPattern& A); +} // namespace cinn::frontend::group_cluster diff --git a/paddle/cinn/frontend/group_cluster/group_cluster.h b/paddle/cinn/frontend/group_cluster/group_cluster.h new file mode 100644 index 0000000000000..5a09b5e2ace95 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/group_cluster.h @@ -0,0 +1,83 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/frontend/group_cluster/cluster_policy/general_topo_policy.h" +#include "paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h" +#include "paddle/cinn/frontend/group_cluster/cluster_policy/shardable_axes_policy/shardable_axes_policy.h" +#include "paddle/cinn/frontend/group_cluster/pattern_graph.h" + +namespace cinn::frontend { + +inline std::vector ClusterOps( + const std::vector& origin_ops, + bool with_horizontal_fusion = false) { + CHECK_GT(origin_ops.size(), 0); + VLOG(4) << "Start Cluster Ops!"; + VLOG(4) << "Input Group with size " << origin_ops.size() << " :\n" + << group_cluster::OpsDebugStr(origin_ops); + + std::vector outputs; + const auto& ops = [&] { + std::vector ops; + for (const auto& op : origin_ops) { + if (op->name() == "cf.yield") { // just skip cf.yield. + for (auto& operand : op->operands()) { + outputs.push_back(operand.source()); + } + continue; + } + ops.emplace_back(op); + } + return ops; + }(); + + pir::Program* program = ops.at(0)->GetParentProgram(); + + const auto* shape_analysis = + &pir::ShapeAnalysisManager::Instance().Get(program); + + // const auto& shardable_axes_policy = + // std::make_shared( + // ops, shape_analysis); + VLOG(4) << "Start Create Policies and PolicyManager!"; + const auto& relative_judge_policy = + std::make_shared( + ops, shape_analysis); + + const auto& general_topo_policy = + std::make_shared(); + + auto policy_manager = group_cluster::policy::PolicyManager( + {relative_judge_policy, general_topo_policy}); + + auto topo_manager = group_cluster::policy::PolicyManager( + {relative_judge_policy, general_topo_policy}); + + VLOG(4) << "Start Create PatternGraph"; + group_cluster::PatternGraph graph(ops, outputs, policy_manager, topo_manager); + auto result = graph.ClusterOps(with_horizontal_fusion); + + VLOG(4) << "End Cluster Ops! result size:" << result.size(); + for (const auto& node : result) { + VLOG(4) << "\n" + << node->DebugStr() << "\n" + << group_cluster::StmtPatternDebugStr(node->stmt_pattern_); + } + + return result; +} + +} // namespace cinn::frontend diff --git a/paddle/cinn/frontend/group_cluster/pattern.h b/paddle/cinn/frontend/group_cluster/pattern.h new file mode 100644 index 0000000000000..03947b312565f --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/pattern.h @@ -0,0 +1,123 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
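Editorial note (not part of the patch): the pattern.h header that follows defines ExtendVector and MergeVector, an order-preserving union with de-duplication that MergePattern and ReduceTreePattern::ops() build on. A standalone equivalent, assuming only that T is hashable, looks like this:

#include <unordered_set>
#include <vector>

template <typename T>
std::vector<T> OrderedUnion(const std::vector<T>& first,
                            const std::vector<T>& second) {
  std::vector<T> result = first;
  std::unordered_set<T> seen(first.begin(), first.end());
  for (const T& item : second) {
    if (seen.insert(item).second) {  // true only on the first occurrence
      result.push_back(item);
    }
  }
  return result;
}
// OrderedUnion<int>({1, 2}, {2, 3}) yields {1, 2, 3}: order of first
// appearance is kept, duplicates are dropped.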
+ +#pragma once + +#include +#include +#include +#include "glog/logging.h" +#include "paddle/pir/include/core/operation.h" + +namespace cinn::frontend::group_cluster { + +class TrivialPattern; +class ReducePattern; +class ReduceTreePattern; +class ReduceTreePlusTrivialPattern; +class UnsupportPattern; +class HorizontalFusionPattern; + +template +void ExtendVector(std::vector* first, const std::vector& second) { + std::unordered_set visited = + std::unordered_set(first->begin(), first->end()); + for (auto iter = second.begin(); iter != second.end(); iter++) { + if (visited.find(*iter) == visited.end()) { + visited.emplace(*iter); + first->emplace_back(*iter); + } + } +} + +template +std::vector MergeVector(const std::vector& first, + const std::vector& second) { + std::vector result = std::vector(first); + ExtendVector(&result, second); + return result; +} + +struct TrivialPattern { + explicit TrivialPattern(const std::vector& ops) + : ops_(ops) {} + std::vector ops_; + static std::string name() { return "Trivial"; } + std::vector ops() const { return ops_; } +}; + +struct ReducePattern { + explicit ReducePattern(const std::vector& ops) : ops_(ops) {} + std::vector ops_; + std::vector ops() const { return ops_; } + pir::Operation* GetReduceOp() const { return ops_.back(); } + static std::string name() { return "Reduce"; } +}; + +struct ReduceTreePattern { + explicit ReduceTreePattern(const std::vector& v, + const ReducePattern& root) + : reduce_patterns_(v), root_(root) {} + std::vector reduce_patterns_; + const ReducePattern& GetRootPattern() const { return root_; } + std::vector ops() const { + std::vector result; + for (const auto& reduce_pattern : reduce_patterns_) { + result = MergeVector(result, reduce_pattern.ops()); + } + return result; + } + static std::string name() { return "ReduceTree"; } + + private: + ReducePattern root_; +}; + +struct ReduceTreePlusTrivialPattern { + explicit ReduceTreePlusTrivialPattern(const ReduceTreePattern& tree, + const TrivialPattern& sink_trivial) + : tree(tree), sink_trivial(sink_trivial) {} + ReduceTreePattern tree; + TrivialPattern sink_trivial; + std::vector ops() const { + return MergeVector(tree.ops(), sink_trivial.ops()); + } + static std::string name() { return "ReduceTree+Trivial"; } + std::vector fake_reduce_iter_idx; +}; + +struct UnsupportPattern { + explicit UnsupportPattern(const std::vector& ops) + : ops_(ops) {} + std::vector ops_; + std::vector ops() const { return ops_; } + static std::string name() { return "Unsupport"; } +}; + +struct HorizontalFusionPattern { + explicit HorizontalFusionPattern(const std::vector& ops) + : ops_(ops) {} + std::vector ops_; + std::vector ops() const { return ops_; } + static std::string name() { return "HorizontalFusionPattern"; } +}; + +using StmtPattern = std::variant; + +} // namespace cinn::frontend::group_cluster diff --git a/paddle/cinn/frontend/group_cluster/pattern_graph.cc b/paddle/cinn/frontend/group_cluster/pattern_graph.cc new file mode 100644 index 0000000000000..bbd49d1b17503 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/pattern_graph.cc @@ -0,0 +1,235 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/frontend/group_cluster/pattern_graph.h" + +namespace cinn::frontend::group_cluster { + +std::vector PatternGraph::ClusterOps( + bool with_horizontal_fusion) { + VLOG(4) << "[Group Cluster] Initial Condition: " << GraphInfo(); + + VLOG(4) << "[Group Cluster] Start SinkTrivialPattern"; + SinkTrivialPattern(); + VLOG(4) << "[Group Cluster] After SinkTrivialPattern: " << GraphInfo(); + + // ReducePattern -> ReduceTreePattern + VLOG(4) << "[Group Cluster] Start ReduceLiftReduceTree"; + ReduceLiftReduceTree(); + VLOG(4) << "[Group Cluster] After ReduceLiftReduceTree: " << GraphInfo(); + + // ReduceTreePattern + ReduceTreePattern fusion + VLOG(4) << "[Group Cluster] Start ReduceTreeGrown"; + ReduceTreeGrown(); + VLOG(4) << "[Group Cluster] After ReduceTreeGrown: " << GraphInfo(); + + // ReduceTreePattern + TrivialPattern fusion. + VLOG(4) << "[Group Cluster] Start ReduceTree_Trivial_Fusion"; + ReduceTree_Trivial_Fusion(); + VLOG(4) << "[Group Cluster] After ReduceTree_Trivial_Fusion: " << GraphInfo(); + + // Horizontal fusion. + if (with_horizontal_fusion) { + VLOG(4) << "[Group Cluster] Start HorizontalFusion"; + HorizontalFusion(); + VLOG(4) << "[Group Cluster] After HorizontalFusion: " << GraphInfo(); + } + + return SortByTopoOrder(); +} + +std::vector PatternGraph::SortByTopoOrder() { + // sort all_pattern_nodes_ by topo order. 
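+  // (Editorial note) The loop below is Kahn's algorithm: in-degrees are taken
+  // from each node's upstream_ list, zero-in-degree nodes seed the queue, and
+  // a node is appended to the result only after all of its upstream nodes
+  // have been emitted and have decremented its remaining degree.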
+ std::vector res; + std::list topo_queue; + std::map degree; + for (const auto& node : all_pattern_nodes_) { + degree[node] = node->upstream_.size(); + if (degree[node] == 0) { + topo_queue.push_back(node); + } + } + while (!topo_queue.empty()) { + PatternNodePtr node = topo_queue.front(); + topo_queue.pop_front(); + res.push_back(node); + for (const auto& downstream_op : node->downstream_) { + degree[downstream_op] = degree[downstream_op] - 1; + if (degree[downstream_op] == 0) { + topo_queue.push_back(downstream_op); + } + } + } + return res; +} + +void PatternGraph::SinkTrivialPattern() { + GraphTransformer< + NodePattern, + And>, + IsNotOutputNodeMatcher>, + MergeTrivialPatternOperation>(this); +} + +void PatternGraph::ReduceLiftReduceTree() { + GraphTransformer< + NodePattern, + And, StmtPatternGraphMatcher>, + LiftReduceToReduceTreeOperation>(this); +} + +void PatternGraph::HorizontalFusion() { + GraphTransformer, + LiftToHorizontalFusionPatternOperation>(this); + + GraphTransformer(this); +} + +void PatternGraph::ReduceTreeGrown() { + GraphTransformer, + MergeReduceTreeOperation>(this); +} + +void PatternGraph::ReduceTree_Trivial_Fusion() { + GraphTransformer< + NodePattern, + And, + MergeReduceTreeAndTrivialOperation>(this); +} + +PatternGraph::PatternGraph(const std::vector& ops, + const std::vector& outputs, + const policy::PolicyManager policy_manager, + const policy::PolicyManager topo_manager) + : policy_manager_(policy_manager), + topo_manager_(topo_manager), + outputs_(outputs) { + std::unordered_map op_to_node_map; + + VLOG(4) << "len(outputs) = " << outputs_.size(); + for (const auto& v : outputs) { + VLOG(4) << "output is" << OpsDebugStr({v.defining_op()}); + } + + for (const auto& op : ops) { + PatternNodePtr node = std::make_shared(op); + op_to_node_map[op] = node; + all_pattern_nodes_.emplace(node); + node->sink_op_ = op; + } + + for (pir::Operation* op : ops) { + PatternNodePtr cur_node = op_to_node_map[op]; + + // add upstream nodes + for (int i = 0; i < op->num_operands(); ++i) { + ::pir::Operation* input_op = op->operand_source(i).defining_op(); + if (op_to_node_map.find(input_op) != op_to_node_map.end()) { + PatternNodePtr upstream_node = op_to_node_map[input_op]; + cur_node->upstream_.push_back(upstream_node); + } + } + + // add downstream nodes + for (int i = 0; i < op->num_results(); ++i) { + pir::Value related_value = op->result(i); + for (auto consumer_it = related_value.use_begin(); + consumer_it != related_value.use_end(); + ++consumer_it) { + ::pir::Operation* output_op = consumer_it->owner(); + if (op_to_node_map.find(output_op) != op_to_node_map.end()) { + PatternNodePtr downstream_node = op_to_node_map[output_op]; + cur_node->downstream_.push_back(downstream_node); + } + } + } + } + + VLOG(4) << "PatternGraph Created, pattern node size: " + << all_pattern_nodes_.size(); +} + +void PatternGraph::RemoveNode(const PatternNodePtr& node) { + VLOG(4) << "Start Remove: " << node; + if (all_pattern_nodes_.find(node) != all_pattern_nodes_.end()) { + VLOG(4) << "Removed! 
"; + all_pattern_nodes_.erase(node); + } + + for (PatternNodePtr& upstream : node->upstream_) { + RemoveFromVector(&upstream->downstream_, node); + } + + for (PatternNodePtr& downstream : node->downstream_) { + RemoveFromVector(&downstream->upstream_, node); + } +} + +void PatternGraph::AppendNode(const PatternNodePtr& node) { + all_pattern_nodes_.emplace(node); +} + +std::string PatternGraph::GraphInfo() const { + std::stringstream ss; + ss << "\n========= GraphInfo ==========="; + for (const auto& v : all_pattern_nodes_) { + ss << "\n" << v->DebugStr(); + ss << "\n IsOutput: " << IsOutputNodeMatcher()(*this, v); + } + ss << "\n==============================="; + return ss.str(); +} + +PatternNodePtr PatternGraph::MergeNode(const PatternNodePtr& upstream, + const PatternNodePtr& downstream) { + PatternNodePtr merged_node = + std::make_shared(upstream, downstream); + + // deal with the reference. + ExtendVector(&merged_node->upstream_, upstream->upstream_); + ExtendVector(&merged_node->upstream_, downstream->upstream_); + RemoveFromVector(&merged_node->upstream_, upstream); + + ExtendVector(&merged_node->downstream_, upstream->downstream_); + ExtendVector(&merged_node->downstream_, downstream->downstream_); + RemoveFromVector(&merged_node->downstream_, downstream); + + for (const auto& upstream_node : merged_node->upstream_) { + upstream_node->downstream_.push_back(merged_node); + RemoveFromVector(&upstream_node->downstream_, upstream); + RemoveFromVector(&upstream_node->downstream_, downstream); + } + for (const auto& downstream_node : merged_node->downstream_) { + downstream_node->upstream_.push_back(merged_node); + RemoveFromVector(&downstream_node->downstream_, upstream); + RemoveFromVector(&downstream_node->downstream_, downstream); + } + + const auto vec_unique = [](const std::vector& vec) { + auto set = std::unordered_set(vec.begin(), vec.end()); + return set.size() == vec.size(); + }; + + CHECK(vec_unique(merged_node->upstream_)); + CHECK(vec_unique(merged_node->downstream_)); + + // deal with the graph storage. + AppendNode(merged_node); + return merged_node; +} +} // namespace cinn::frontend::group_cluster diff --git a/paddle/cinn/frontend/group_cluster/pattern_graph.h b/paddle/cinn/frontend/group_cluster/pattern_graph.h new file mode 100644 index 0000000000000..9f151f25558c7 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/pattern_graph.h @@ -0,0 +1,360 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "paddle/cinn/frontend/group_cluster/cluster_policy/policy_manager.h" +#include "paddle/cinn/frontend/group_cluster/cluster_policy/relative_judge_policy.h" +#include "paddle/cinn/frontend/group_cluster/common_utils.h" +#include "paddle/cinn/frontend/group_cluster/pattern_node.h" + +namespace cinn::frontend::group_cluster { + +struct PatternNodePtrHash { + size_t operator()(const PatternNodePtr& node) const { + return std::hash()(node.get()); + } +}; + +struct PatternNodePtrCompare { + bool operator()(const std::shared_ptr& a, + const std::shared_ptr& b) const { + return a.get() == b.get(); + } +}; + +using PatternNodePtrSet = std:: + unordered_set; + +class PatternGraph { + public: + PatternGraph(const std::vector& ops, + const std::vector& outputs, + const policy::PolicyManager policy_manager, + const policy::PolicyManager topo_manager); + + std::vector ClusterOps(bool with_horizontal_fusion = false); + + private: + void SinkTrivialPattern(); + void HorizontalFusion(); + void FuseReducePattern(); + void ReduceLiftReduceTree(); + void ReduceTreeGrown(); + void ReduceTree_Trivial_Fusion(); + + void RemoveNode(const PatternNodePtr& node); + void AppendNode(const PatternNodePtr& node); + std::string GraphInfo() const; + PatternNodePtr MergeNode(const PatternNodePtr& upstream, + const PatternNodePtr& downstream); + std::vector SortByTopoOrder(); + + friend class IsOutputNodeMatcher; + friend class IsNotOutputNodeMatcher; + friend class CanFuseReduceTreeAndTrivialMatcher; + friend class CanFuseReduceTreeMatcher; + + friend class MergeTrivialPatternOperation; + friend class LiftReduceToReduceTreeOperation; + friend class MergeReduceTreeOperation; + friend class MergeReduceTreeAndTrivialOperation; + friend class HorizontalFusionOperation; + friend class LiftToHorizontalFusionPatternOperation; + + public: + PatternNodePtrSet all_pattern_nodes_; + std::vector outputs_; + policy::PolicyManager policy_manager_; + policy::PolicyManager topo_manager_; +}; + +// PatternGraphFusionOperation := (GraphMatcher, GraphOperation) +// SearchAlgorithm := NodePattern | EdgePattern | GraphMatcher +// GraphOperation := Merge2Node | SplitNode | SplitAllAndMergeDownstream + +struct NodePattern {}; +struct EdgePattern {}; +struct GraphPattern {}; // not implemented. +struct NodePairPattern {}; // not implemented. 
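Editorial note (not part of the patch): the matchers and operations defined below are stateless, default-constructible functors; that is what lets combinators in the spirit of And/Or/Not compose them purely at the type level, and lets the GraphTransformer driver re-run FindMatchedNode until no unvisited node matches. A minimal sketch of the composition idiom, with Graph and Node as simplified stand-ins:

#include <iostream>

struct Graph {};
struct Node {
  int downstream_count;
  bool is_output;
};

struct HasNoDownstream {
  bool operator()(const Graph&, const Node& n) {
    return n.downstream_count == 0;
  }
};
struct IsNotOutput {
  bool operator()(const Graph&, const Node& n) { return !n.is_output; }
};

// Composes two matchers at the type level: both are default-constructed on
// the spot and short-circuited with &&, as in the And matcher further below.
template <typename A, typename B>
struct AndMatcher {
  bool operator()(const Graph& g, const Node& n) {
    return A()(g, n) && B()(g, n);
  }
};

int main() {
  Graph g;
  Node sink{0, false};
  std::cout << AndMatcher<HasNoDownstream, IsNotOutput>()(g, sink) << "\n";  // 1
}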
+ +template +struct SearchAlgorithm {}; + +template +struct SearchAlgorithm { + PatternGraph* graph_; + PatternNodePtrSet visited_nodes; + + explicit SearchAlgorithm(PatternGraph* graph) { + VLOG(4) << "Create NodePattern algorithm."; + graph_ = graph; + } + + PatternNodePtr FindMatchedNode() { + for (PatternNodePtr iter_node : graph_->all_pattern_nodes_) { + if (GraphMatcher()(*graph_, iter_node) && + !visited_nodes.count(iter_node)) { + visited_nodes.insert(iter_node); + VLOG(4) << "Find Matched Node: " << iter_node; + return iter_node; + } + } + VLOG(4) << "Can't find matched node any more."; + return nullptr; + } + + void operator()() { + while (true) { + PatternNodePtr node = FindMatchedNode(); + if (node == nullptr) { + break; + } + GraphOperation()(graph_, node); + } + } +}; + +template +struct SearchAlgorithm { + PatternGraph* graph_; + std::set> visited_node_pair; + explicit SearchAlgorithm(PatternGraph* graph) { + VLOG(4) << "Create NodePairPattern algorithm."; + graph_ = graph; + } + std::optional> FindMatchedPair() { + for (PatternNodePtr i : graph_->all_pattern_nodes_) { + for (PatternNodePtr j : graph_->all_pattern_nodes_) { + if (i == j) continue; + const auto& pair = std::make_pair(i, j); + if (GraphMatcher()(*graph_, i, j) && !visited_node_pair.count(pair)) { + visited_node_pair.insert(pair); + VLOG(4) << "Find Matched Node Pair: (" << i << ", " << j << ")"; + return pair; + } + } + } + VLOG(4) << "Can't find matched node any more."; + return {}; + } + void operator()() { + while (true) { + const auto& node = FindMatchedPair(); + if (!node.has_value()) break; + const auto& [i, j] = node.value(); + GraphOperation()(graph_, i, j); + } + } +}; + +// Operation + +struct MergeReduceTreeOperation { + void operator()(PatternGraph* graph, PatternNodePtr node) { + CHECK_EQ(node->downstream_.size(), 1); + auto downstream = node->downstream_.at(0); + auto merged_node = graph->MergeNode(node, downstream); + graph->RemoveNode(downstream); + graph->RemoveNode(node); + VLOG(4) << "MergeReduceTreeOperation: \nupstream " << node->DebugStr() + << "\ndownstream " << downstream->DebugStr() << "\nmerged " + << merged_node->DebugStr(); + } +}; + +struct MergeReduceTreeAndTrivialOperation { + void operator()(PatternGraph* graph, PatternNodePtr node) { + CHECK_EQ(node->downstream_.size(), 1); + auto downstream = node->downstream_.at(0); + auto fake_reduce_iter_idx = + graph->policy_manager_.GetFakeReduceIterIdx(node, downstream); + PatternNodePtr merged_node = graph->MergeNode(node, downstream); + std::get(merged_node->stmt_pattern_) + .fake_reduce_iter_idx = fake_reduce_iter_idx; + graph->RemoveNode(downstream); + graph->RemoveNode(node); + VLOG(4) << "MergeReduceTreeAndTrivialOperation: \nupstream " + << node->DebugStr() << "\ndownstream " << downstream->DebugStr() + << "\nmerged " << merged_node->DebugStr(); + } +}; + +struct LiftReduceToReduceTreeOperation { + void operator()(PatternGraph* graph, PatternNodePtr node) { + const auto& reduce_pattern = ToReducePattern(node->stmt_pattern_); + node->stmt_pattern_ = ReduceTreePattern({reduce_pattern}, reduce_pattern); + VLOG(4) << "LiftReduceToReduceTreeOperation: \nnode " << node->DebugStr(); + } +}; + +struct MergeTrivialPatternOperation { + void operator()(PatternGraph* graph, PatternNodePtr upstream) { + std::vector fusion_candidate = upstream->downstream_; + upstream->downstream_.clear(); + for (const auto& downstream : fusion_candidate) { + if (downstream->IsReduce() || downstream->IsTrivial()) { + auto merged_node = 
graph->MergeNode(upstream, downstream); + graph->RemoveNode(downstream); + VLOG(4) << "MergeTrivialPatternOperation: \nupstream " + << upstream->DebugStr() << "\ndownstream " + << downstream->DebugStr() << "\nmerged " + << merged_node->DebugStr(); + } else { + upstream->downstream_.push_back(downstream); + } + } + if (upstream->downstream_.empty()) { + graph->RemoveNode(upstream); + } + } +}; + +struct LiftToHorizontalFusionPatternOperation { + void operator()(PatternGraph* graph, PatternNodePtr i) { + i->stmt_pattern_ = + HorizontalFusionPattern(GetOpsInPattern(i->stmt_pattern_)); + } +}; + +// Matcher + +template +struct AlwaysTrue { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return true; + } +}; + +template +struct StmtPatternGraphMatcher { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return GetPatternName(node->stmt_pattern_) == StmtPattern::name(); + } +}; + +struct CanFuseRxTMatcher { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return (node->IsReduceTree() && !node->downstream_.empty() && + node->downstream_.at(0)->IsTrivial()); + } +}; + +struct CanFuseReduceTreeMatcher { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return StmtPatternGraphMatcher()(graph, node) && + !node->downstream_.empty() && + node->downstream_.at(0)->IsReduceTree() && + graph.policy_manager_.CanFuse(node, node->downstream_.at(0)); + } +}; + +struct CanFuseReduceTreeAndTrivialMatcher { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return StmtPatternGraphMatcher()(graph, node) && + !node->downstream_.empty() && node->downstream_.at(0)->IsTrivial() && + graph.policy_manager_.CanFuse(node, node->downstream_.at(0)); + } +}; + +struct HorizontalFusionConstrain { + bool operator()(const PatternGraph& graph, + const PatternNodePtr& first, + const PatternNodePtr& second) { + if (!StmtPatternGraphMatcher()(graph, first)) { + return false; + } + if (!StmtPatternGraphMatcher()(graph, second)) { + return false; + } + const auto& first_dim = first->sink_op_->result(0) + .type() + .dyn_cast() + .dims(); + const auto& second_dim = second->sink_op_->result(0) + .type() + .dyn_cast() + .dims(); + return graph.topo_manager_.CanFuse(first, second) && + first_dim == second_dim; + } +}; + +struct HorizontalFusionOperation { + void operator()(PatternGraph* graph, + const PatternNodePtr& i, + const PatternNodePtr& j) { + CHECK(GetPatternName(i->stmt_pattern_) == HorizontalFusionPattern::name()); + CHECK(GetPatternName(j->stmt_pattern_) == HorizontalFusionPattern::name()); + graph->MergeNode(i, j); + graph->RemoveNode(i); + graph->RemoveNode(j); + } +}; + +struct NonSinkNodeMatcher { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return !node->downstream_.empty(); + } +}; + +struct IsOutputNodeMatcher { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + bool res = IsAnyFirstInSecond(node->sink_op_->results(), graph.outputs_); + return res; + } +}; + +struct IsNotOutputNodeMatcher { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + bool res = !IsOutputNodeMatcher()(graph, node); + return res; + } +}; + +template +struct DownstreamSmallerThan { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return node->downstream_.size() < N; + } +}; + +template +struct And { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return A()(graph, node) && 
B()(graph, node); + } +}; + +template +struct Or { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return A()(graph, node) || B()(graph, node); + } +}; + +template +struct Not { + bool operator()(const PatternGraph& graph, const PatternNodePtr& node) { + return !A()(graph, node); + } +}; + +template +void GraphTransformer(PatternGraph* graph) { + VLOG(4) << "Start GraphTransformer..."; + auto alog = SearchAlgorithm(graph); + alog(); +} + +} // namespace cinn::frontend::group_cluster diff --git a/paddle/cinn/frontend/group_cluster/pattern_node.cc b/paddle/cinn/frontend/group_cluster/pattern_node.cc new file mode 100644 index 0000000000000..342fc36847229 --- /dev/null +++ b/paddle/cinn/frontend/group_cluster/pattern_node.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/frontend/group_cluster/pattern_node.h" + +namespace cinn::frontend::group_cluster { + +PatternNode::PatternNode(pir::Operation* op) + : sink_op_(op), stmt_pattern_(ConvertToStmtPattern(op)) {} + +PatternNode::PatternNode(PatternNodePtr fused_up_node, + PatternNodePtr fused_down_node) + : sink_op_(fused_down_node->sink_op_), + stmt_pattern_(MergePattern(fused_up_node->stmt_pattern_, + fused_down_node->stmt_pattern_)) {} + +std::vector PatternNode::GetOps() const { + return GetOpsInPattern(stmt_pattern_); +} + +bool PatternNode::IsTrivial() const { return IsTrivialPattern(stmt_pattern_); } +bool PatternNode::IsReduce() const { return IsReducePattern(stmt_pattern_); } +bool PatternNode::IsReduceTree() const { + return IsReduceTreePattern(stmt_pattern_); +} +bool PatternNode::IsUnsupport() const { + return IsUnsupportPattern(stmt_pattern_); +} +bool PatternNode::IsReduceTrivial() const { + return IsReduceTrivialPattern(stmt_pattern_); +} +std::string PatternNode::DebugStr() const { + std::stringstream ss; + ss << "Node: " << this << ", Pattern: " << GetPatternName(stmt_pattern_) + << "\n -u>: "; + for (const auto& u : upstream_) { + ss << u << ", "; + } + ss << "\n ; + + explicit PatternNode(pir::Operation* op); + explicit PatternNode(PatternNodePtr fused_up_node, + PatternNodePtr fused_down_node); + + bool IsTrivial() const; + bool IsReduce() const; + bool IsReduceTree() const; + bool IsUnsupport() const; + bool IsReduceTrivial() const; + + std::vector GetOps() const; + + StmtPattern stmt_pattern_; + pir::Operation* sink_op_; + + std::vector upstream_; + std::vector downstream_; + + std::string DebugStr() const; +}; + +using PatternNodePtr = PatternNode::PatternNodePtr; +} // namespace cinn::frontend::group_cluster diff --git a/paddle/cinn/frontend/interpreter.cc b/paddle/cinn/frontend/interpreter.cc index 12964fb8e79ad..ff8c4280b754f 100644 --- a/paddle/cinn/frontend/interpreter.cc +++ b/paddle/cinn/frontend/interpreter.cc @@ -97,9 +97,11 @@ hlir::framework::Tensor Interpreter::GetTensor(const std::string& name) { auto it = impl_->var_map_paddle_to_cinn_.find(name); if (it == 
impl_->var_map_paddle_to_cinn_.end()) { - LOG(FATAL) << "No variable called [" << name - << "] found in executor\nThe existing vars: " - << utils::Join(impl_->scope_->var_names(), ", "); + std::stringstream ss; + ss << "No variable called [" << name + << "] found in executor\nThe existing vars: " + << utils::Join(impl_->scope_->var_names(), ", "); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return impl_->scope_->GetTensor(it->second); } diff --git a/paddle/cinn/frontend/net_builder.cc b/paddle/cinn/frontend/net_builder.cc index b9f6135bdd5b5..0388fb6e42e0c 100644 --- a/paddle/cinn/frontend/net_builder.cc +++ b/paddle/cinn/frontend/net_builder.cc @@ -285,8 +285,9 @@ Variable NetBuilder::FillConstant(const std::vector& shape, } else if (type.is_bool()) { value = !cinn::runtime::CheckStringFlagFalse(str_value); } else { - LOG(FATAL) << "FillConstant only support int/float/bool, but here " - << dtype; + std::stringstream ss; + ss << "FillConstant only support int/float/bool, but here " << dtype; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } auto out = CustomInstr("fill_constant", {}, @@ -676,7 +677,9 @@ std::vector UpdatePool2dKernelSize(const std::vector& x_shape, height_axis = 1; width_axis = 2; } else { - LOG(FATAL) << "Unsupport data_format: " << data_format; + std::stringstream ss; + ss << "Unsupport data_format: " << data_format; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } if (global_pooling) { new_ksize[0] = x_shape[height_axis]; @@ -709,7 +712,9 @@ std::vector UpdatePool2dPaddings(const std::vector& paddings, height_axis = 1; width_axis = 2; } else { - LOG(FATAL) << "Unsupport data_format: " << data_format; + std::stringstream ss; + ss << "Unsupport data_format: " << data_format; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } // When padding_algorithm is VALID, set paddings to [0, 0, 0, 0]. // When padding_algorithm is SAME, the calculation formula of padding is as diff --git a/paddle/cinn/frontend/op_mapper_registry.cc b/paddle/cinn/frontend/op_mapper_registry.cc index 883ac8104d9ae..702888ce62bd2 100644 --- a/paddle/cinn/frontend/op_mapper_registry.cc +++ b/paddle/cinn/frontend/op_mapper_registry.cc @@ -83,7 +83,9 @@ Variable OpMapperContext::GetVar(const std::string& origin_name) const { return local_var; } - LOG(FATAL) << "No var called [" << origin_name << "] exists"; + std::stringstream ss; + ss << "No var called [" << origin_name << "] exists"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return Variable(); } diff --git a/paddle/cinn/frontend/op_mappers/common_utils.h b/paddle/cinn/frontend/op_mappers/common_utils.h index 61e9dc2cda93f..58202c991c4c0 100644 --- a/paddle/cinn/frontend/op_mappers/common_utils.h +++ b/paddle/cinn/frontend/op_mappers/common_utils.h @@ -62,10 +62,11 @@ inline T GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, << " here we will return a empty vector."; \ return {}; \ } else { \ - LOG(FATAL) << "Op \"" << op_desc.Type() << "\"'s attribute \"" \ - << name << "\" should be " << #ATTR_TYPE \ - << "S. But here " << static_cast(attr_type) \ - << " Please Check!"; \ + std::stringstream ss; \ + ss << "Op \"" << op_desc.Type() << "\"'s attribute \"" << name \ + << "\" should be " << #ATTR_TYPE << "S. 
But here " \ + << static_cast(attr_type) << " Please Check!"; \ + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); \ } \ } \ } \ @@ -94,8 +95,10 @@ inline bool GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, case AttrType::LONG: return static_cast(op_desc.GetAttr(name)); default: - LOG(FATAL) << "Op " << op_desc.Type() << "'s attribute " << name - << " should be BOOLEAN. Please Check!"; + std::stringstream ss; + ss << "Op " << op_desc.Type() << "'s attribute " << name + << " should be BOOLEAN. Please Check!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } return default_value; @@ -114,8 +117,10 @@ inline int64_t GetAttrOrDefault(const paddle::cpp::OpDesc& op_desc, case AttrType::INT: return static_cast(op_desc.GetAttr(name)); default: - LOG(FATAL) << "Op " << op_desc.Type() << "'s attribute " << name - << " should be LONG. Please Check!"; + std::stringstream ss; + ss << "Op " << op_desc.Type() << "'s attribute " << name + << " should be LONG. Please Check!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } return default_value; @@ -150,8 +155,10 @@ inline std::vector GetAttrOrDefault( return {}; } default: - LOG(FATAL) << "Op " << op_desc.Type() << "'s attribute " << name - << " should be LONGS. Please Check!"; + std::stringstream ss; + ss << "Op " << op_desc.Type() << "'s attribute " << name + << " should be LONGS. Please Check!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } return default_value; diff --git a/paddle/cinn/frontend/op_mappers/paddle/concat.cc b/paddle/cinn/frontend/op_mappers/paddle/concat.cc index 6904cb85f6c6a..d7181f3ac1a60 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/concat.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/concat.cc @@ -63,8 +63,9 @@ void StackOpMapper(const paddle::cpp::OpDesc& op_desc, CHECK_EQ(op_desc.Output("Y").size(), 1UL); out_name = op_desc.Output("Y").front(); } else { - LOG(FATAL) << "The output argument name of [stack] should be 'Out' or 'Y', " - "but here cannot found! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The output argument name of [stack] should be 'Out' or 'Y', " + "but here cannot found! Please check.")); } cinn::utils::ShapeType input_shape(ctx.GetVar(x_names.front())->shape); diff --git a/paddle/cinn/frontend/op_mappers/paddle/elementwise.cc b/paddle/cinn/frontend/op_mappers/paddle/elementwise.cc index 792ae1e922904..63f9316fc9990 100644 --- a/paddle/cinn/frontend/op_mappers/paddle/elementwise.cc +++ b/paddle/cinn/frontend/op_mappers/paddle/elementwise.cc @@ -225,8 +225,9 @@ void PowOpMapper(const paddle::cpp::OpDesc& op_desc, cinn::UniqName(x_name + "_factor"), cinn::common::Type2Str(x->type)); } else { - LOG(FATAL) << "Cannot found [FactorTensor] input or [factor] attribute in " - "paddle.pow! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Cannot found [FactorTensor] input or [factor] attribute in " + "paddle.pow! Please check.")); } VLOG(4) << out_name << " = pow(" << x_name << ", " << y.value()->id << ")"; diff --git a/paddle/cinn/frontend/op_mappers/science/transform.cc b/paddle/cinn/frontend/op_mappers/science/transform.cc index 412ec1ddf8ce1..fa23c354061f0 100644 --- a/paddle/cinn/frontend/op_mappers/science/transform.cc +++ b/paddle/cinn/frontend/op_mappers/science/transform.cc @@ -91,11 +91,13 @@ void SplitOpMapper(const paddle::cpp::OpDesc& op_desc, } else if (sec == -1 && !has_neg) { has_neg = true; } else if (sec == 0) { - LOG(FATAL) << "The attribute 'num_or_sections' of split should not has " - "0 ! 
Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The attribute 'num_or_sections' of split should not has " + "0 ! Please check.")); } else { - LOG(FATAL) << "The attribute 'num_or_sections' of split can only have " - "at most one '-1' ! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The attribute 'num_or_sections' of split can only have " + "at most one '-1' ! Please check.")); } } CHECK(!has_neg && sec_sum == x_shape[axis]) diff --git a/paddle/cinn/frontend/optimize.cc b/paddle/cinn/frontend/optimize.cc index bc3d1388cf368..3440d3f2b6f4f 100644 --- a/paddle/cinn/frontend/optimize.cc +++ b/paddle/cinn/frontend/optimize.cc @@ -172,8 +172,9 @@ std::shared_ptr Optimize( enable_fusion = true; } } else { - LOG(FATAL) << "Pass " << pass - << " unsupported in CINN! Please check.\n"; + std::stringstream ss; + ss << "Pass " << pass << " unsupported in CINN! Please check.\n"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } diff --git a/paddle/cinn/frontend/paddle/compatible_pb.cc b/paddle/cinn/frontend/paddle/compatible_pb.cc index 68ad3ae514ac5..711e78889a9b0 100644 --- a/paddle/cinn/frontend/paddle/compatible_pb.cc +++ b/paddle/cinn/frontend/paddle/compatible_pb.cc @@ -128,7 +128,9 @@ void OpAttrsAnyToCpp(const OpDescType &any_desc, cpp::OpDesc *cpp_desc) { break; } default: - LOG(FATAL) << "Unsupported attr type found " << static_cast(type); + std::stringstream ss; + ss << "Unsupported attr type found " << static_cast(type); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } }; @@ -157,7 +159,9 @@ void OpAttrsCppToAny(const cpp::OpDesc &cpp_desc, OpDescType *any_desc) { IMPL_ONE(LONG, int64_t); IMPL_ONE(LONGS, std::vector); default: - LOG(FATAL) << "Unsupported attr type found: " << static_cast(type); + std::stringstream ss; + ss << "Unsupported attr type found: " << static_cast(type); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } }; #undef IMPL_ONE diff --git a/paddle/cinn/frontend/paddle/model_parser.cc b/paddle/cinn/frontend/paddle/model_parser.cc index c54c772d803fe..086cf11fe34b5 100644 --- a/paddle/cinn/frontend/paddle/model_parser.cc +++ b/paddle/cinn/frontend/paddle/model_parser.cc @@ -42,7 +42,9 @@ int SizeOfType(framework_proto::VarType::Type type) { DO(INT64, int64_t); #undef DO default: - LOG(FATAL) << "unknown data type " << type; + std::stringstream ss; + ss << "unknown data type " << type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return -1; } @@ -90,14 +92,17 @@ void TensorFromStream(std::istream &is, SET_TENSOR(INT64, int64_t, Int(64)); #undef SET_TENSOR default: - LOG(FATAL) << "unknown type " << desc.data_type(); + std::stringstream ss; + ss << "unknown type " << desc.data_type(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } // tensor->set_persistable(true); is.read(static_cast(buf), size); } else if (target.arch == Target::Arch::NVGPU) { #ifdef CINN_WITH_CUDA if (desc.data_type() != Type::VarType_Type_FP32) - LOG(FATAL) << "[CUDA] The type is not fp32!!"; + PADDLE_THROW( + phi::errors::InvalidArgument("[CUDA] The type is not fp32!!")); auto *data = tensor->mutable_data(target); tensor->set_type(Float(32)); std::vector temp(tensor->shape().numel()); @@ -108,7 +113,8 @@ void TensorFromStream(std::istream &is, tensor->shape().numel() * sizeof(float), cudaMemcpyHostToDevice)); #else - LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal( + "To use CUDA backends, you need to set WITH_CUDA ON!")); #endif } else { CINN_NOT_IMPLEMENTED @@ -281,7 
+287,7 @@ void LoadModelPb(const std::string &model_dir, target); break; default: - LOG(FATAL) << "unknown weight type"; + PADDLE_THROW(phi::errors::InvalidArgument("unknown weight type")); } } } diff --git a/paddle/cinn/frontend/paddle/pb/op_desc.h b/paddle/cinn/frontend/paddle/pb/op_desc.h index 82e1477270fa4..222bdda4da2b2 100644 --- a/paddle/cinn/frontend/paddle/pb/op_desc.h +++ b/paddle/cinn/frontend/paddle/pb/op_desc.h @@ -17,6 +17,7 @@ #include "paddle/cinn/frontend/paddle/cpp/op_desc.h" #include "paddle/cinn/frontend/paddle/framework.pb.h" +#include "paddle/common/enforce.h" namespace cinn::frontend::paddle::pb { @@ -106,7 +107,7 @@ class OpDesc : public cpp::OpDescAPI { DEF_ONE(BLOCKS); DEF_ONE(LONGS); default: - LOG(FATAL) << "Unknown attribute type"; + PADDLE_THROW(phi::errors::InvalidArgument("Unknown attribute type")); return static_cast(-1); } #undef DEF_ONE diff --git a/paddle/cinn/frontend/paddle/pb/var_desc.cc b/paddle/cinn/frontend/paddle/pb/var_desc.cc index efee4f211d662..c6069daa1f67d 100644 --- a/paddle/cinn/frontend/paddle/pb/var_desc.cc +++ b/paddle/cinn/frontend/paddle/pb/var_desc.cc @@ -15,9 +15,9 @@ #include "paddle/cinn/frontend/paddle/pb/var_desc.h" #include - #include "paddle/cinn/frontend/paddle/cpp/desc_api.h" #include "paddle/cinn/frontend/paddle/framework.pb.h" +#include "paddle/common/enforce.h" namespace cinn::frontend::paddle::pb { @@ -39,7 +39,7 @@ cpp::VarDescAPI::Type VarDesc::GetType() const { GET_TYPE_CASE_ITEM(PLACE_LIST); GET_TYPE_CASE_ITEM(READER); default: - LOG(FATAL) << "Unknown var type"; + PADDLE_THROW(phi::errors::InvalidArgument("Unknown var type")); return VarDescAPI::Type(); } #undef GET_TYPE_CASE_ITEM @@ -62,7 +62,7 @@ void VarDesc::SetType(VarDescAPI::Type type) { SET_TYPE_CASE_ITEM(PLACE_LIST); SET_TYPE_CASE_ITEM(READER); default: - LOG(FATAL) << "Unknown var type"; + PADDLE_THROW(phi::errors::InvalidArgument("Unknown var type")); } #undef SET_TYPE_CASE_ITEM } @@ -83,9 +83,11 @@ void VarDesc::SetTensorDescNum(size_t num) { return; } break; default: - LOG(FATAL) << "Setting 'sub_tensor_number' is not supported by the type " - "of var %s." - << this->Name(); + std::stringstream ss; + ss << "Setting 'sub_tensor_number' is not supported by the type " + "of var %s." + << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -95,9 +97,11 @@ size_t VarDesc::GetTensorDescNum() const { return desc_->type().reader().lod_tensor_size(); break; default: - LOG(FATAL) << "Getting 'sub_tensor_number' is not supported by the type " - "of var %s." - << this->Name(); + std::stringstream ss; + ss << "Getting 'sub_tensor_number' is not supported by the type " + "of var %s." 
+ << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return 0; } @@ -151,7 +155,9 @@ void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) { SET_DATA_TYPE_CASE_ITEM(FP32); SET_DATA_TYPE_CASE_ITEM(FP64); default: - LOG(FATAL) << "Unknown var type: " << static_cast(data_type); + std::stringstream ss; + ss << "Unknown var type: " << static_cast(data_type); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } #undef SET_DATA_TYPE_CASE_ITEM } @@ -200,7 +206,9 @@ cpp::VarDescAPI::VarDataType VarDesc::GetDataType() const { GET_DATA_TYPE_CASE_ITEM(FP32); GET_DATA_TYPE_CASE_ITEM(FP64); default: - LOG(FATAL) << "Unknown var type: " << static_cast(type); + std::stringstream ss; + ss << "Unknown var type: " << static_cast(type); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return VarDescAPI::Type(); } #undef GET_DATA_TYPE_CASE_ITEM @@ -225,9 +233,10 @@ void VarDesc::SetLoDLevel(int32_t lod_level) { desc_->mutable_type()->mutable_tensor_array()->set_lod_level(lod_level); break; default: - LOG(FATAL) - << "Setting 'lod_level' is not supported by the type of var %s." - << this->Name(); + std::stringstream ss; + ss << "Setting 'lod_level' is not supported by the type of var %s." + << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -249,9 +258,10 @@ void VarDesc::SetLoDLevels(const std::vector &multiple_lod_level) { } } break; default: - LOG(FATAL) - << "Setting 'lod_levels' is not supported by the type of var %s." - << this->Name(); + std::stringstream ss; + ss << "Setting 'lod_levels' is not supported by the type of var %s." + << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -262,9 +272,10 @@ int32_t VarDesc::GetLoDLevel() const { case framework_proto::VarType::LOD_TENSOR_ARRAY: return desc_->type().tensor_array().lod_level(); default: - LOG(FATAL) - << "Getting 'lod_level' is not supported by the type of var %s." - << this->Name(); + std::stringstream ss; + ss << "Getting 'lod_level' is not supported by the type of var %s." + << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return 0; } @@ -280,9 +291,10 @@ std::vector VarDesc::GetLoDLevels() const { return res; break; default: - LOG(FATAL) - << "Getting 'lod_levels' is not supported by the type of var %s." - << this->Name(); + std::stringstream ss; + ss << "Getting 'lod_levels' is not supported by the type of var %s." + << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return std::vector(); } @@ -298,9 +310,10 @@ const framework_proto::VarType::TensorDesc &VarDesc::tensor_desc() const { case framework_proto::VarType::LOD_TENSOR_ARRAY: return desc_->type().tensor_array().tensor(); default: - LOG(FATAL) - << "Getting 'tensor_desc' is not supported by the type of var %s." - << this->Name(); + std::stringstream ss; + ss << "Getting 'tensor_desc' is not supported by the type of var %s." + << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return framework_proto::VarDesc().type().lod_tensor().tensor(); } @@ -317,10 +330,11 @@ std::vector VarDesc::tensor_descs() } return res; default: - LOG(FATAL) - << "Getting 'tensor_descs' is not supported by the type of var " - "%s." - << this->Name(); + std::stringstream ss; + ss << "Getting 'tensor_descs' is not supported by the type of var " + "%s." 
+ << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return std::vector(); } @@ -336,10 +350,12 @@ framework_proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() { case framework_proto::VarType::LOD_TENSOR_ARRAY: return desc_->mutable_type()->mutable_tensor_array()->mutable_tensor(); default: - LOG(FATAL) << "Getting 'mutable_tensor_desc' is not supported by the " - "type of var " - "%s." - << this->Name(); + std::stringstream ss; + ss << "Getting 'mutable_tensor_desc' is not supported by the " + "type of var " + "%s." + << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return nullptr; } @@ -358,10 +374,11 @@ VarDesc::mutable_tensor_descs() { } return res; default: - LOG(FATAL) - << "Getting 'tensor_descs' is not supported by the type of var " - "%s." - << this->Name(); + std::stringstream ss; + ss << "Getting 'tensor_descs' is not supported by the type of var " + "%s." + << this->Name(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return std::vector(); } diff --git a/paddle/cinn/frontend/paddle_model_convertor_test.cc b/paddle/cinn/frontend/paddle_model_convertor_test.cc index 30364c05e417e..5e69cdef80cc2 100644 --- a/paddle/cinn/frontend/paddle_model_convertor_test.cc +++ b/paddle/cinn/frontend/paddle_model_convertor_test.cc @@ -84,7 +84,8 @@ void RunProgram(const Target& target, Program* prog) { } else if (inputs[i]->type.is_bool()) { RandomInput(target, tensor, 0, inputs[i]->shape[0]); } else { - LOG(FATAL) << "Only support float/int/bool! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Only support float/int/bool! Please check.")); } } diff --git a/paddle/cinn/frontend/paddle_model_to_program.cc b/paddle/cinn/frontend/paddle_model_to_program.cc index 52c91216dd901..7249c35f19d26 100644 --- a/paddle/cinn/frontend/paddle_model_to_program.cc +++ b/paddle/cinn/frontend/paddle_model_to_program.cc @@ -104,7 +104,8 @@ void PaddleModelToProgram::AddOpMapper_scale() { if (op_desc.HasAttr("bias")) { // the old model format bias = op_desc.GetAttr("bias"); } else { - LOG(FATAL) << "Didn't find [bias] attr in Scale operator!!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Didn't find [bias] attr in Scale operator!!")); } absl::flat_hash_map attrs; auto out = net_builder_->Scale(x, scale, bias); @@ -243,7 +244,9 @@ void PaddleModelToProgram::AddOpMapper_fill_constant() { DO(INT32, int); #undef DO default: - LOG(FATAL) << "unknown data type " << dtype; + std::stringstream ss; + ss << "unknown data type " << dtype; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } AddVar(TransValidVarName(out_name), out); var_model_to_program_map_[out_name] = out->id; @@ -622,7 +625,9 @@ void PaddleModelToProgram::AddOp(const paddle::cpp::OpDesc& op_desc) { return; } // feed op's output is a input of the model - LOG(FATAL) << "Not supported op [" << op_desc.Type() << "] found"; + std::stringstream ss; + ss << "Not supported op [" << op_desc.Type() << "] found"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } void PaddleModelToProgram::TransposeVar(const std::string& name) { @@ -658,7 +663,8 @@ void PaddleModelToProgram::TransposeVar(const std::string& name) { cudaMemcpyHostToDevice)); #endif #else - LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal( + "To use CUDA backends, you need to set WITH_CUDA ON!")); #endif } else { CINN_NOT_IMPLEMENTED @@ -674,7 +680,9 @@ void PaddleModelToProgram::TransposeVar(const std::string& name) { var->type = Float(32); 
AddVar(name, var, true); } else { - LOG(FATAL) << "No var called [" << name << "] exists"; + std::stringstream ss; + ss << "No var called [" << name << "] exists"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -707,13 +715,16 @@ void PaddleModelToProgram::ReverseHWVar(const std::string& name) { tensor->shape().numel() * sizeof(float), cudaMemcpyHostToDevice)); #else - LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal( + "To use CUDA backends, you need to set WITH_CUDA ON!")); #endif } else { CINN_NOT_IMPLEMENTED } } else { - LOG(FATAL) << "No var called [" << name << "] exists"; + std::stringstream ss; + ss << "No var called [" << name << "] exists"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -736,7 +747,9 @@ Variable PaddleModelToProgram::GetVar(const std::string& name) { return var; } - LOG(FATAL) << "No var called [" << name << "] exists"; + std::stringstream ss; + ss << "No var called [" << name << "] exists"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return Variable(); } diff --git a/paddle/cinn/frontend/pass/fill_constant_rewriter.cc b/paddle/cinn/frontend/pass/fill_constant_rewriter.cc index 81b331042096e..f1a8a9db01e29 100644 --- a/paddle/cinn/frontend/pass/fill_constant_rewriter.cc +++ b/paddle/cinn/frontend/pass/fill_constant_rewriter.cc @@ -37,7 +37,8 @@ namespace pass { else if (absl::holds_alternative(OLD_VALUE)) \ NEW_VALUE = FUNC(absl::get(OLD_VALUE)); \ else \ - LOG(FATAL) << "fill_constant Only support float32/float64/int32/int64"; + PADDLE_THROW(phi::errors::InvalidArgument( \ + "fill_constant Only support float32/float64/int32/int64")); #define MATH_FUNC_REWRITER(op_name) \ { \ diff --git a/paddle/cinn/frontend/pass/transpose_folding_input.cc b/paddle/cinn/frontend/pass/transpose_folding_input.cc index 3c50ce3f2d6c9..1353848ff8985 100644 --- a/paddle/cinn/frontend/pass/transpose_folding_input.cc +++ b/paddle/cinn/frontend/pass/transpose_folding_input.cc @@ -111,7 +111,8 @@ class TransposeFoldingInputPass : public TransposeFoldingBase { : false; dot->SetAttr("trans_b", static_cast(trans_b ^ true)); } else { - LOG(FATAL) << "The matmul should only have two inputs."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The matmul should only have two inputs.")); } // shape has changed, the ignore op should update shape diff --git a/paddle/cinn/frontend/var_type_utils.h b/paddle/cinn/frontend/var_type_utils.h index 85a70ee4f53a9..fa539b1085f86 100644 --- a/paddle/cinn/frontend/var_type_utils.h +++ b/paddle/cinn/frontend/var_type_utils.h @@ -83,9 +83,10 @@ inline cinn::common::Type CppVarType2CommonType( // so here need convert back to unkown type. 
SET_TYPE_CASE_ITEM(RAW, Type) default: - LOG(FATAL) << "Unknown VarDesc type: " - << var_type_names_[static_cast(type)] << "(" - << static_cast(type) << ")"; + std::stringstream ss; + ss << "Unknown VarDesc type: " << var_type_names_[static_cast(type)] + << "(" << static_cast(type) << ")"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } #undef SET_DATA_TYPE_CASE_ITEM return cinn::common::Type(); diff --git a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt index 89e47a59b546b..ba58a034fb4bb 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt @@ -1,78 +1,74 @@ -# TODO(Aurelius84): pir_compiler depends on pd_op_dialect and could -# not found under CINN_ONLY mode -if(NOT CINN_ONLY) - set(CINN_DIALECT_SOURCE_DIR - "${PADDLE_BINARY_DIR}/paddle/cinn/hlir/dialect/operator/ir") +set(CINN_DIALECT_SOURCE_DIR + "${PADDLE_BINARY_DIR}/paddle/cinn/hlir/dialect/operator/ir") - # Generate cinn_op_dialect files defining op using op_gen_file - set(cinn_op_gen_parsed_yaml_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parse_op.py) +# Generate cinn_op_dialect files defining op using op_gen_file +set(cinn_op_gen_parsed_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parse_op.py) - set(cinn_op_gen_file - ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_gen.py) +set(cinn_op_gen_file + ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_gen.py) - set(cinn_op_compat_yaml_file - ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) +set(cinn_op_compat_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) - set(cinn_op_yaml_file - ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/operator/ir/ops.yaml) +set(cinn_op_yaml_file + ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/operator/ir/ops.yaml) - set(parsed_op_dir ${PADDLE_BINARY_DIR}/paddle/cinn/hlir/dialect/generated) +set(parsed_op_dir ${PADDLE_BINARY_DIR}/paddle/cinn/hlir/dialect/generated) - set(cinn_op_parsed_yaml_file ${parsed_op_dir}/ops.parsed.yaml) +set(cinn_op_parsed_yaml_file ${parsed_op_dir}/ops.parsed.yaml) - set(cinn_op_parsed_yaml_files ${cinn_op_parsed_yaml_file}) +set(cinn_op_parsed_yaml_files ${cinn_op_parsed_yaml_file}) - set(cinn_op_namespace cinn,dialect) - set(cinn_op_dialect_name cinn_op) - set(cinn_op_header_file ${CINN_DIALECT_SOURCE_DIR}/cinn_op.h) - set(cinn_op_source_file ${CINN_DIALECT_SOURCE_DIR}/cinn_op.cc) - set(cinn_op_header_file_tmp ${cinn_op_header_file}.tmp) - set(cinn_op_source_file_tmp ${cinn_op_source_file}.tmp) +set(cinn_op_namespace cinn,dialect) +set(cinn_op_dialect_name cinn_op) +set(cinn_op_header_file ${CINN_DIALECT_SOURCE_DIR}/cinn_op.h) +set(cinn_op_source_file ${CINN_DIALECT_SOURCE_DIR}/cinn_op.cc) +set(cinn_op_header_file_tmp ${cinn_op_header_file}.tmp) +set(cinn_op_source_file_tmp ${cinn_op_source_file}.tmp) - set(cinn_op_info_file ${CINN_DIALECT_SOURCE_DIR}/cinn_op_info.cc) - set(cinn_op_info_file_tmp ${cinn_op_info_file}.tmp) +set(cinn_op_info_file ${CINN_DIALECT_SOURCE_DIR}/cinn_op_info.cc) +set(cinn_op_info_file_tmp ${cinn_op_info_file}.tmp) - execute_process( - COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir} - COMMAND ${PYTHON_EXECUTABLE} ${cinn_op_gen_parsed_yaml_file} --op_yaml_path - ${cinn_op_yaml_file} --output_path ${cinn_op_parsed_yaml_file}) +execute_process( + COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir} + COMMAND ${PYTHON_EXECUTABLE} ${cinn_op_gen_parsed_yaml_file} 
--op_yaml_path + ${cinn_op_yaml_file} --output_path ${cinn_op_parsed_yaml_file}) - execute_process( - COMMAND - ${PYTHON_EXECUTABLE} ${cinn_op_gen_file} --op_yaml_files - ${cinn_op_parsed_yaml_files} --op_compat_yaml_file - ${cinn_op_compat_yaml_file} --namespaces ${cinn_op_namespace} - --dialect_name ${cinn_op_dialect_name} --op_def_h_file - ${cinn_op_header_file_tmp} --op_info_file ${cinn_op_info_file_tmp} - --op_def_cc_file ${cinn_op_source_file_tmp}) +execute_process( + COMMAND + ${PYTHON_EXECUTABLE} ${cinn_op_gen_file} --op_yaml_files + ${cinn_op_parsed_yaml_files} --op_compat_yaml_file + ${cinn_op_compat_yaml_file} --namespaces ${cinn_op_namespace} + --dialect_name ${cinn_op_dialect_name} --op_def_h_file + ${cinn_op_header_file_tmp} --op_info_file ${cinn_op_info_file_tmp} + --op_def_cc_file ${cinn_op_source_file_tmp}) - set(generated_files_cinn_op "${cinn_op_header_file}" "${cinn_op_info_file}" - "${cinn_op_source_file}") - foreach(generated_file ${generated_files_cinn_op}) - if(EXISTS "${generated_file}.tmp" AND EXISTS "${generated_file}") - execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different - "${generated_file}.tmp" "${generated_file}") - message("copy if different ${generated_file}.tmp ${generated_file}") - elseif(EXISTS "${generated_file}.tmp") - execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${generated_file}.tmp" - "${generated_file}") - message("copy ${generated_file}.tmp ${generated_file}") - endif() - endforeach() +set(generated_files_cinn_op "${cinn_op_header_file}" "${cinn_op_info_file}" + "${cinn_op_source_file}") +foreach(generated_file ${generated_files_cinn_op}) + if(EXISTS "${generated_file}.tmp" AND EXISTS "${generated_file}") + execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different + "${generated_file}.tmp" "${generated_file}") + message("copy if different ${generated_file}.tmp ${generated_file}") + elseif(EXISTS "${generated_file}.tmp") + execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${generated_file}.tmp" + "${generated_file}") + message("copy ${generated_file}.tmp ${generated_file}") + endif() +endforeach() - cinn_cc_library( - cinn_op_dialect - SRCS - op_dialect.cc - ${cinn_op_source_file} - ${cinn_op_info_file} - generate_shape_util.cc - manual_op.cc - op_attribute.cc - DEPS - op_dialect_vjp - pir) +cinn_cc_library( + cinn_op_dialect + SRCS + op_dialect.cc + ${cinn_op_source_file} + ${cinn_op_info_file} + generate_shape_util.cc + manual_op.cc + op_attribute.cc + DEPS + op_dialect_vjp + pir) - target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_SOURCE_DIR}) -endif() +target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_SOURCE_DIR}) diff --git a/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h b/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h index 61a2ae3268e05..770eeb4b55701 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h +++ b/paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h @@ -22,6 +22,7 @@ #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/pir/include/core/attribute_base.h" #include "paddle/pir/include/core/operation.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" namespace cinn { namespace dialect { @@ -52,6 +53,7 @@ struct GroupInfo { alignment_schedule_info; std::vector reduce_axis; std::vector loop_ranges; + std::vector loop_ranges_expr; private: void Initialize() { @@ -71,6 +73,11 @@ struct GroupInfoAttributeStorage : public pir::AttributeStorage { static std::size_t HashValue(const ParamKey& key) { size_t hash_value = 
std::hash{}(key.group_id); + for (auto op : key.ops) { + hash_value = + pir::detail::hash_combine(hash_value, std::hash()(op)); + } + for (auto d : key.loop_ranges) { hash_value = pir::detail::hash_combine(hash_value, std::hash()(d)); diff --git a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc index a230e032c41e4..0ce1ad6bab5c0 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc @@ -575,7 +575,7 @@ std::vector GetMinimalInputs( [&](pir::Value input_tensor, const std::vector& dim_exprs) { for (const auto& dim_expr : dim_exprs) { - if (dim_expr.isa()) continue; + if (!dim_expr.isa()) continue; if (handled_dim_exprs.insert(dim_expr).second) { first_occurred_input_tensors.insert(input_tensor); } diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc index 54299cc2ff7ff..2dbe30c4447b7 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.cc @@ -24,14 +24,17 @@ #include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/transforms/shape_optimization_pass.h" #include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/op_base.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" namespace cinn { namespace dialect { +using DenseTensorType = paddle::dialect::DenseTensorType; + const char* GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"}; const char* FusionOp::attributes_name[GroupOp::attributes_num] = {"group_info"}; const char* ConcatOp::attributes_name[ConcatOp::attributes_num] = {"axis"}; @@ -78,7 +81,13 @@ pir::Block* GroupOp::block() { return ®ion.front(); } -std::vector GroupOp::GetOperators() { +pir::Block* GroupOp::block() const { + pir::Region& region = (*this)->region(0); + CHECK(!region.empty()); + return ®ion.front(); +} + +std::vector GroupOp::GetOperators() const { std::vector rt_ops; for (auto& op : *block()) { rt_ops.push_back(&op); @@ -98,12 +107,30 @@ void GroupOp::Print(pir::IrPrinter& printer) { printer.PrintOpReturnType(op); os << " {"; for (auto& sub_op : GetOperators()) { - os << "\n"; + os << "\n "; printer.PrintOperation(sub_op); } os << " \n }"; } +bool GroupOp::InferSymbolicShape( + ::pir::ShapeConstraintIRAnalysis* shape_analysis) { + ::pir::InferSymExprForBlock(*block(), shape_analysis); + + for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) { + auto inner_yield_value = block()->back().operand_source(rst_idx); + const auto& shape = + shape_analysis->GetShapeOrDataForValue(inner_yield_value); + shape_analysis->SetShapeOrDataForValue(result(rst_idx), shape); + } + + if (VLOG_IS_ON(4)) { + ::std::cerr << ">>>>>>>>>>>>>>>>>>>> cinn_op.group(op_id: op_" + << block()->back().id() << ") END." 
<< ::std::endl; + } + return true; +} + void FusionOp::Build(pir::Builder& builder, pir::OperationArgument& argument, const std::vector& output_types) { @@ -149,12 +176,29 @@ void FusionOp::Print(pir::IrPrinter& printer) { printer.PrintOpReturnType(op); os << " {"; for (auto& sub_op : GetOperators()) { - os << "\n"; + os << "\n "; printer.PrintOperation(sub_op); } os << " \n }"; } +void YieldStoreOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + pir::Value x, + pir::Type output_type) { + argument.inputs = {x}; + argument.output_types = {output_type}; +} + +void YieldStoreOp::VerifySig() {} + +bool YieldStoreOp::InferSymbolicShape( + pir::ShapeConstraintIRAnalysis* shape_analysis) { + shape_analysis->SetShapeOrDataForValue( + result(0), shape_analysis->GetShapeOrDataForValue(operand_source(0))); + return true; +} + bool ConcatOp::InferSymbolicShape( pir::ShapeConstraintIRAnalysis* shape_analysis) { VLOG(4) << "Infer symbolic shape for cinn_op.concat"; @@ -175,39 +219,31 @@ void ConcatOp::Build(pir::Builder& builder, // NOLINT phi::errors::InvalidArgument( "input size [%d] is less than 0", inputs.size())); - auto first_ele = - inputs[0].type().dyn_cast(); - phi::DDim out_dims = first_ele.dims(); - - if (axis < 0) { - axis += out_dims.size(); - } - - for (size_t idx = 0; idx < inputs.size(); ++idx) { - inputs_type[idx] = inputs[idx].type(); - - if (idx > 0) { - auto dim_i = inputs[idx] - .type() - .dyn_cast() - .dims(); - - out_dims[axis] += dim_i[axis]; + const pir::Type out_type = [&]() { + auto first_ele = inputs[0].type().dyn_cast(); + phi::DDim out_dims = first_ele.dims(); + if (axis < 0) axis += out_dims.size(); + + for (size_t idx = 1; idx < inputs.size(); ++idx) { + inputs_type[idx] = inputs[idx].type(); + auto dim_i = inputs[idx].type().dyn_cast().dims(); + + if (out_dims[axis] > 0 && dim_i[axis] > 0) { + out_dims[axis] += dim_i[axis]; + } else { + out_dims[axis] = -1; + break; + } } - } - - auto out_type = - paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - first_ele.dtype(), - out_dims, - first_ele.data_layout(), - first_ele.lod(), - first_ele.offset()); - + return DenseTensorType::get(pir::IrContext::Instance(), + first_ele.dtype(), + out_dims, + first_ele.data_layout(), + first_ele.lod(), + first_ele.offset()); + }(); argument.output_types.emplace_back(out_type); - PassStopGradientsDefaultly(argument); - argument.AddAttribute( "axis", pir::Int32Attribute::get(pir::IrContext::Instance(), axis)); } @@ -223,7 +259,7 @@ void SplitOp::Build(pir::Builder& builder, // NOLINT std::vector output_type(sections.size()); - auto input_ele = input.type().dyn_cast(); + auto input_ele = input.type().dyn_cast(); if (axis < 0) { axis += input_ele.dims().size(); @@ -232,13 +268,12 @@ void SplitOp::Build(pir::Builder& builder, // NOLINT for (size_t idx = 0; idx < sections.size(); ++idx) { auto out_dims = input_ele.dims(); out_dims[axis] = sections[idx]; - auto out_type = - paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - input_ele.dtype(), - out_dims, - input_ele.data_layout(), - input_ele.lod(), - input_ele.offset()); + auto out_type = DenseTensorType::get(pir::IrContext::Instance(), + input_ele.dtype(), + out_dims, + input_ele.data_layout(), + input_ele.lod(), + input_ele.offset()); argument.output_types.emplace_back(out_type); @@ -284,7 +319,7 @@ void GenerateShapeOp::Build( auto type = pir::Int64Type::get(ctx); auto dim = ::common::make_ddim({static_cast(output_dim_exprs.size())}); - return paddle::dialect::DenseTensorType::get(ctx, type, 
dim); + return DenseTensorType::get(ctx, type, dim); }()}); ::pir::PassStopGradientsDefaultly(argument); } @@ -486,3 +521,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::FusionOp) IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); +IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::YieldStoreOp); diff --git a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h index bb9917cfbfa63..f27908438d3b9 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/manual_op.h +++ b/paddle/cinn/hlir/dialect/operator/ir/manual_op.h @@ -29,7 +29,8 @@ namespace cinn { namespace dialect { -class IR_API GroupOp : public pir::Op { +class IR_API GroupOp + : public pir::Op { public: using Op::Op; static const char *name() { return "cinn_op.group"; } @@ -49,7 +50,10 @@ class IR_API GroupOp : public pir::Op { const cinn::dialect::GroupInfo &group_info); pir::Block *block(); - std::vector GetOperators(); + pir::Block *block() const; + std::vector GetOperators() const; + + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); void VerifySig(); void Print(pir::IrPrinter &printer); // NOLINT @@ -74,11 +78,32 @@ class IR_API FusionOp : public pir::Op { pir::Block *block(); std::vector GetOperators(); + std::vector GetOperators() const; void VerifySig(); void Print(pir::IrPrinter &printer); // NOLINT }; +// YieldStoreOp represents a store operation for +// seperate local variable and ouptut +class IR_API YieldStoreOp + : public pir::Op { + public: + using Op::Op; + static const char *name() { return "cinn_op.yield_store"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value x, + pir::Type output_type); + + void VerifySig(); + + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); +}; + class IR_API ConcatOp : public pir::Op { public: @@ -167,3 +192,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::FusionOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp) IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GenerateShapeOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::YieldStoreOp); diff --git a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc index c07ae5a9b0cad..32a534a397018 100644 --- a/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc +++ b/paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc @@ -56,6 +56,7 @@ void OperatorDialect::initialize() { RegisterOp(); RegisterOp(); RegisterOp(); + RegisterOp(); RegisterOp(); RegisterAttribute(); RegisterAttribute(); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt index 00eecee4d883c..5808789c9adef 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt @@ -1,21 +1,19 @@ -if(NOT CINN_ONLY) +file(GLOB_RECURSE cinn_transforms_srcs "*.cc") - file(GLOB_RECURSE cinn_transforms_srcs "*.cc") +set(cinn_transforms_deps + pir + drr + op_dialect + cinn_op_dialect + op_dialect_vjp + cinn_runtime_dialect + group_cluster + pir_compiler) - set(cinn_transforms_deps - pir - drr - op_dialect - cinn_op_dialect - op_dialect_vjp - cinn_runtime_dialect - 
pir_compiler) +cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS + ${cinn_transforms_deps}) - cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS - ${cinn_transforms_deps}) - - cc_library( - add_cinn_pass - SRCS add_cinn_pass.cc - DEPS op_dialect pir cinn_op_dialect cinnapi pir_transforms cinn_transforms) -endif() +cc_library( + add_cinn_pass + SRCS add_cinn_pass.cc + DEPS op_dialect pir cinn_op_dialect cinnapi pir_transforms cinn_transforms) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc index ff0fa6381c08f..97604471f5ba9 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc @@ -15,6 +15,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/common/ddim.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" @@ -173,6 +174,23 @@ class AddBroadcastToElementwisePattern : public pir::OpRewritePattern { } }; +class DeleteUselessBroadcastPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(cinn::dialect::BroadcastOp broadcast, + pir::PatternRewriter& rewriter) const override { + if (!broadcast->GetParentOp()->isa()) { + rewriter.ReplaceAllUsesWith(broadcast.result(0), + broadcast->operand_source(0)); + rewriter.EraseOp(broadcast); + return true; + } + return false; + } +}; + class AddBroadcastToElementwisePass : public pir::PatternRewritePass { public: AddBroadcastToElementwisePass() @@ -213,6 +231,8 @@ class AddBroadcastToElementwisePass : public pir::PatternRewritePass { context); // bitwise ops + ps.Add>( + context); ps.Add>( context); ps.Add>( @@ -224,7 +244,19 @@ class AddBroadcastToElementwisePass : public pir::PatternRewritePass { } bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; + return op->num_regions() > 0; + } +}; + +class DeleteUselessBroadcastPass : public pir::PatternRewritePass { + public: + DeleteUselessBroadcastPass() + : pir::PatternRewritePass("delete_useless_broadcast_pass", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context); + return ps; } }; @@ -232,6 +264,10 @@ std::unique_ptr CreateAddBroadcastToElementwisePass() { return std::make_unique(); } +std::unique_ptr CreateDeleteUselessBroadcastPass() { + return std::make_unique(); +} + } // namespace ir } // namespace dialect } // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h index d4778a17a1fbd..6b2226d385733 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h @@ -23,6 +23,8 @@ namespace ir { std::unique_ptr CreateAddBroadcastToElementwisePass(); +std::unique_ptr CreateDeleteUselessBroadcastPass(); + } // namespace ir } // namespace dialect } // namespace cinn diff --git 
a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc index 6ded2f5a85c93..3b6b1adcdbda1 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc @@ -23,8 +23,11 @@ #include "paddle/pir/include/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/include/pass/pass_manager.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/check_infer_symbolic_pass.h" @@ -34,19 +37,21 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/substitute_dim_expr_based_on_constraints_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h" -#include "paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/remove_unchanged_reshape_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.h" #include "paddle/fluid/pir/transforms/build_cinn_pass.h" -#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" +#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h" #include "paddle/fluid/pir/transforms/shape_optimization_pass.h" COMMON_DECLARE_bool(print_ir); COMMON_DECLARE_bool(check_infer_symbolic); +PD_DECLARE_bool(group_schedule_tiling_first); namespace cinn::dialect::ir { @@ -70,6 +75,16 @@ bool HasDynamicShape(const pir::Program& program) { } } // namespace +void ApplyPdToCinnPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); + pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass()); + pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pass_manager->Run(program); +} + void ApplyCinnPreprocessPass( ::pir::Program* program, const std::function()>& @@ -77,42 +92,79 @@ void ApplyCinnPreprocessPass( std::shared_ptr pass_manager = CreatePassManager(); bool has_dynamic_shape = HasDynamicShape(*program); - pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass()); if (!has_dynamic_shape && FLAGS_check_infer_symbolic) { pass_manager->AddPass(pir::CreateShapeOptimizationPass()); 
pass_manager->AddPass(cinn::dialect::ir::CreateCheckInferSymbolicPass()); } - pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass()); - pass_manager->AddPass(cinn::dialect::ir::CreateRemoveUnchangedReshapePass()); - pass_manager->AddPass( - cinn::dialect::ir::CreateAddBroadcastToElementwisePass()); - pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); if (has_dynamic_shape) { + pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass()); pass_manager->AddPass(pir::CreateShapeOptimizationPass()); - pass_manager->AddPass(cinn::dialect::ir::CreateSimplifyDimExprPass()); - pass_manager->AddPass( - cinn::dialect::ir::CreateSubstituteDimExprBasedOnConstraintsPass()); - pass_manager->AddPass(cinn::dialect::ir::CreateInsertBroadcastPass()); - pass_manager->AddPass(pir::CreateShapeOptimizationPass()); + pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass()); pass_manager->AddPass( cinn::dialect::ir::CreateFuseShapeOpsIntoGenerateShapeOpPass()); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + } + + pass_manager->Run(program); +} + +void ApplyBuildGroupOpPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); + bool has_dynamic_shape = HasDynamicShape(*program); + if (has_dynamic_shape) { pass_manager->AddPass(pir::CreateShapeOptimizationPass()); } + pass_manager->AddPass(cinn::dialect::ir::CreateRemoveUnchangedReshapePass()); pass_manager->AddPass(pir::CreateBuildCinnPass()); + pass_manager->Run(program); +} + +void ApplyGroupOpPass(::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); pass_manager->AddPass( - cinn::dialect::ir::CreateMoveGenerateShapeOpsToProloguePass()); - pass_manager->AddPass(cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass()); + cinn::dialect::ir::CreateAddBroadcastToElementwisePass()); + if (HasDynamicShape(*program)) { + pass_manager->AddPass(::pir::CreateShapeOptimizationPass()); + pass_manager->AddPass(cinn::dialect::ir::CreateInsertBroadcastPass()); + pass_manager->AddPass( + cinn::dialect::ir::CreateSubstituteDimExprBasedOnConstraintsPass()); + pass_manager->AddPass(cinn::dialect::ir::CreateSimplifyDimExprPass()); + pass_manager->AddPass( + cinn::dialect::ir::CreateFuseShapeOpsIntoGenerateShapeOpPass()); + pass_manager->AddPass( + cinn::dialect::ir::CreateMoveGenerateShapeOpsToProloguePass()); + } + pass_manager->AddPass(cinn::dialect::ir::CreateDynamicReshapeOpPass()); - pass_manager->AddPass(cinn::dialect::ir::CreateReplaceDynamicExpandOpPass()); pass_manager->AddPass(pir::CreateDeadCodeEliminationPass()); + pass_manager->AddPass(cinn::dialect::ir::CreateRemoveUnchangedReshapePass()); pass_manager->Run(program); } +void ApplyDivideGroupOpToFusionOpPass( + ::pir::Program* program, + const std::function()>& + CreatePassManager) { + std::shared_ptr pass_manager = CreatePassManager(); + if (FLAGS_group_schedule_tiling_first) { + pass_manager->AddPass(cinn::dialect::ir::CreateCinnGroupClusterPass()); + pass_manager->AddPass(cinn::dialect::ir::CreateAddStoreInFusionOpPass()); + } else { + pass_manager->AddPass( + cinn::dialect::ir::CreateDivideGroupOpToFusionOpPass()); + } + pass_manager->Run(program); +} + void ApplyCinnLowerPass( ::pir::Program* program, const std::function()>& @@ -130,22 +182,49 @@ void ApplyCinnLowerPass( pass_manager->AddPass(std::move(pass.value())); } + pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass()); if 
(has_dynamic_shape && !force_static_shape) { pass_manager->AddPass( cinn::dialect::ir::CreateLowerCinnDyShapeFusionOpPass()); + } else { + pass_manager->AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass()); } - - pass_manager->AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass()); pass_manager->AddPass( cinn::dialect::ir::CreateSplitGenerateShapeIntoShapeOpsPass()); pass_manager->Run(program); } +template +int64_t GetOpCount(const ::pir::Operation* op) { + int64_t count = 0; + for (auto& region : *op) { + for (auto& block : region) { + for (auto& sub_op : block) { + if (sub_op.isa()) { + count++; + continue; + } + if (sub_op.num_regions() > 0) { + count += GetOpCount(&sub_op); + } + } + } + } + return count; +} + void ApplyCinnPass(::pir::Program* program, const std::function()>& CreatePassManager) { + ApplyPdToCinnPass(program, CreatePassManager); ApplyCinnPreprocessPass(program, CreatePassManager); + ApplyBuildGroupOpPass(program, CreatePassManager); + ApplyGroupOpPass(program, CreatePassManager); + ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager); + LOG(INFO) << "FusionOp count before lowering : *****[ " + << GetOpCount(program->module_op()) + << " ]*****"; ApplyCinnLowerPass(program, CreatePassManager); } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc new file mode 100644 index 0000000000000..e0c52169df0a6 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
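[Editor's note, not part of the patch] The add_cinn_pass.cc hunk above reorganizes ApplyCinnPass into a sequence of small stages (ApplyPdToCinnPass, ApplyCinnPreprocessPass, ApplyBuildGroupOpPass, ApplyGroupOpPass, ApplyDivideGroupOpToFusionOpPass, then ApplyCinnLowerPass), each of which asks the injected factory for a fresh pass manager, registers its passes, and runs them before the next stage begins. The standalone sketch below shows only that staging idiom; Program, PassManager, and the pass names are toy stand-ins for pir::Program, pir::PassManager, and the real passes, not the actual API.

#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins for pir::Program and pir::PassManager (illustration only).
struct Program {
  std::string name = "program";
};
struct PassManager {
  std::vector<std::string> passes;
  void AddPass(const std::string& pass) { passes.push_back(pass); }
  void Run(Program* program) {
    for (const auto& p : passes) {
      std::cout << "run " << p << " on " << program->name << "\n";
    }
  }
};
using PassManagerCreator = std::function<std::shared_ptr<PassManager>()>;

// Each stage owns a fresh pass manager, mirroring ApplyPdToCinnPass and friends.
void ApplyFrontendStage(Program* program,
                        const PassManagerCreator& CreatePassManager) {
  std::shared_ptr<PassManager> pm = CreatePassManager();
  pm->AddPass("pd_to_cinn");
  pm->AddPass("dead_code_elimination");
  pm->Run(program);
}

void ApplyLoweringStage(Program* program,
                        const PassManagerCreator& CreatePassManager) {
  std::shared_ptr<PassManager> pm = CreatePassManager();
  pm->AddPass("lower_fusion_op");
  pm->Run(program);
}

int main() {
  Program program;
  PassManagerCreator CreatePassManager = [] {
    return std::make_shared<PassManager>();
  };
  ApplyFrontendStage(&program, CreatePassManager);  // stage 1 runs, then discards its manager
  ApplyLoweringStage(&program, CreatePassManager);  // stage 2 starts from a clean manager
}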
+ +#include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/pir/include/core/builtin_type_interfaces.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" +#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h" + +namespace cinn { +namespace dialect { +namespace ir { + +class AddYieldStoreInFusionOpPattern + : public pir::OpRewritePattern<::pir::YieldOp> { + public: + using pir::OpRewritePattern<::pir::YieldOp>::OpRewritePattern; + + bool MatchAndRewrite(::pir::YieldOp op, + pir::PatternRewriter& rewriter) const override { + auto& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + for (auto i = 0; i < op->num_operands(); ++i) { + if (op->operand_source(i).use_count() == 1) { + continue; + } + + auto store_op = rewriter.Build( + op->operand_source(i), op->operand_source(i).type()); + auto orignal_base = op->operand_source(i); + op->operand(i).set_source(store_op.result(0)); + + if (shape_analysis.HasShapeOrDataForValue(orignal_base)) { + shape_analysis.SetShapeOrDataForValue( + store_op.result(0), + shape_analysis.GetShapeOrDataForValue(orignal_base)); + } + } + + return true; + } +}; + +class AddStoreInFusionOpPass : public pir::Pass { + public: + AddStoreInFusionOpPass() + : pir::Pass("add_store_in_fusion_op", /*opt_level=*/1) {} + + bool Initialize(pir::IrContext* context) override { + pir::RewritePatternSet ps(context); + ps.Add(context); + + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } + + void Run(pir::Operation* op) override { + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 1; + for (uint32_t i = 0; i < op->num_regions(); ++i) { + for (auto& block : op->region(i)) { + for (auto& op : block) { + if (op.isa()) { + auto fusion_op = op.dyn_cast(); + if (fusion_op.GetOperators().size() == 2 && + fusion_op.GetOperators() + .front() + ->isa()) { + continue; + } + auto [_, num_rewrites] = + pir::ApplyPatternsGreedily(&op, patterns_, cfg); + AddStatistics(num_rewrites); + } + } + } + } + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->num_regions() > 0; + } + + private: + pir::FrozenRewritePatternSet patterns_; +}; + +std::unique_ptr CreateAddStoreInFusionOpPass() { + return std::make_unique(); +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/common/dim_expr_util.h b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h similarity index 62% rename from paddle/cinn/common/dim_expr_util.h rename to paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h index c3eec6be4a125..403e9a13ce38b 100644 --- a/paddle/cinn/common/dim_expr_util.h +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h @@ -14,16 +14,15 @@ #pragma once -#include -#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" -#include "paddle/pir/include/core/builder.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" +#include +#include "paddle/pir/include/pass/pass.h" -namespace cinn::common { +namespace cinn { +namespace dialect { +namespace ir { -symbol::DimExpr SubstituteDimExpr( - const symbol::DimExpr& dim_expr, - const std::unordered_map& - pattern_to_replacement); +std::unique_ptr CreateAddStoreInFusionOpPass(); -} 
+} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc index 9f9856004646f..9fd5a721ac825 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.cc @@ -28,12 +28,14 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/cinn_group_cluster_pass.h" +#include "paddle/cinn/frontend/group_cluster/group_cluster.h" #include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h" #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/common/ddim.h" +#include "paddle/common/flags.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" @@ -47,6 +49,8 @@ #include "paddle/pir/include/pattern_rewrite/pattern_match.h" #include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h" +PD_DECLARE_bool(cinn_new_cluster_op_method); + namespace cinn { namespace dialect { namespace ir { @@ -117,6 +121,7 @@ struct GroupClusterNode { // if kind is reduce, loop ranges equal input dim // if kind id elementwise or broadcast, loop ranges equal output dim std::vector loop_ranges; + std::vector loop_rangs_expr; std::unordered_map<::pir::Operation*, std::vector> alignment_schedule_info; @@ -125,7 +130,7 @@ struct GroupClusterNode { return GetListOutsideInput(ops); } - std::string DebugStr() { + std::string DebugStr() const { std::stringstream ss; ::pir::IrPrinter printer(ss); @@ -155,6 +160,16 @@ struct GroupClusterNode { return ss.str(); } + bool HasYieldOp( + const std::unordered_set<::pir::Operation*>& all_yield_ops) const { + for (const auto& op : ops) { + if (all_yield_ops.find(op) != all_yield_ops.end()) { + return true; + } + } + return false; + } + void MergeNode(const GroupClusterNode& node, const ScheduleInfoNode& inner_sch_node) { std::unordered_set<::pir::Operation*> inner_ops(ops.begin(), ops.end()); @@ -182,6 +197,7 @@ struct GroupClusterNode { if ((node.group_kind == cinn::hlir::framework::kReduction) || (node.group_kind == cinn::hlir::framework::kBroadcast)) { this->loop_ranges = node.loop_ranges; + this->loop_rangs_expr = node.loop_rangs_expr; } if (node.group_kind == cinn::hlir::framework::kReduction) { this->reduce_axis = node.reduce_axis; @@ -189,6 +205,7 @@ struct GroupClusterNode { if ((ops.size() == 1) && (ops.front()->name() == "cinn_op.reshape")) { this->loop_ranges = node.loop_ranges; + this->loop_rangs_expr = node.loop_rangs_expr; } } @@ -232,7 +249,6 @@ std::vector<::pir::Value> GenerateOutputValue( if (outside_need_value.count(op->result(i))) { if (!inserted_val.count(op->result(i))) { temp_out.push_back(op->result(i)); - inserted_val.insert(op->result(i)); } } @@ -252,9 +268,10 @@ cinn::dialect::GroupInfo BuildGroupInfo( const GroupClusterNode& node, const std::unordered_map<::pir::Operation*, std::vector>& new_align_info) { - cinn::dialect::GroupInfo group_info({}); + cinn::dialect::GroupInfo group_info(vec_new_op_list); group_info.group_id = BuildGroupId(vec_new_op_list); group_info.loop_ranges = node.loop_ranges; + group_info.loop_ranges_expr = 
node.loop_rangs_expr; group_info.reduce_axis = node.reduce_axis; group_info.op_pattern_kind = node.group_kind; group_info.alignment_schedule_info = new_align_info; @@ -287,10 +304,13 @@ ::pir::GroupOpsVec CloneOps( auto new_op = op->Clone(*ir_mapping, clone_options); auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + for (size_t i = 0; i < op->num_results(); ++i) { - shape_analysis.SetShapeOrDataForValue( - new_op->result(i), - shape_analysis.GetShapeOrDataForValue(op->result(i))); + if (shape_analysis.HasShapeOrDataForValue(op->result(i))) { + shape_analysis.SetShapeOrDataForValue( + new_op->result(i), + shape_analysis.GetShapeOrDataForValue(op->result(i))); + } } vec_new_op_list.push_back(new_op); @@ -336,6 +356,7 @@ ::pir::Operation* ReplaceWithGroupOp( group_ops.end()); std::vector<::pir::Value> new_output; + for (size_t i = 0; i < output_value.size(); ++i) { new_output.push_back(ir_mapping->Lookup<::pir::Value>(output_value[i])); } @@ -349,7 +370,16 @@ ::pir::Operation* ReplaceWithGroupOp( bool CanFuse(const GroupClusterNode& first, const GroupClusterNode& second, - ScheduleInfoNode* sch_node) { + ScheduleInfoNode* sch_node, + const std::unordered_set<::pir::Operation*>& all_yield_ops) { + if (first.HasYieldOp(all_yield_ops)) { + return false; + } + + if (!first.ops.empty() && + (first.ops.front()->name() == "cinn_op.generate_shape")) { + return true; + } if ((second.ops.size() == 1) && (second.ops.front()->name() == "cinn_op.reshape") && (IsLastReshape(second.ops.front()))) { @@ -398,7 +428,13 @@ bool CanFuse(const GroupClusterNode& first, if (first.loop_ranges != second.loop_ranges) { sch_node->type = hlir::framework::pir::ScheduleAlignType::kBroadcast; - sch_node->axis_info = first.reduce_axis; + for (auto& d : first.reduce_axis) { + if (d < 0) { + sch_node->axis_info.push_back(d + first.loop_ranges.size()); + } else { + sch_node->axis_info.push_back(d); + } + } sch_node->factor_info = first.loop_ranges; } return true; @@ -513,27 +549,111 @@ void GetClusterNodeBasicInfo(::pir::Operation* op, .type() .dyn_cast() .dims()); + + pir::ShapeConstraintIRAnalysis& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + if (shape_analysis.HasShapeOrDataForValue(op->operand_source(0))) { + auto sym_shape = + shape_analysis.GetShapeOrDataForValue(op->operand_source(0)).shape(); + cluster_node->loop_rangs_expr = sym_shape; + for (size_t i = 0; i < cluster_node->loop_ranges.size(); ++i) { + if (cluster_node->loop_ranges[i] < 0 && sym_shape[i].isa()) { + cluster_node->loop_ranges[i] = sym_shape[i].Get(); + } + } + } + + if (cluster_node->reduce_axis.size() == 0) { + for (size_t i = 0; i < cluster_node->loop_ranges.size(); ++i) { + cluster_node->reduce_axis.push_back(i); + } + } + } else if (cluster_node->group_kind == cinn::hlir::framework::kElementWise) { cluster_node->loop_ranges = phi::vectorize(op->result(0) .type() .dyn_cast() .dims()); - - } else if (cluster_node->group_kind == cinn::hlir::framework::kBroadcast) { + pir::ShapeConstraintIRAnalysis& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + if (shape_analysis.HasShapeOrDataForValue(op->result(0))) { + auto sym_shape = + shape_analysis.GetShapeOrDataForValue(op->result(0)).shape(); + cluster_node->loop_rangs_expr = sym_shape; + for (size_t i = 0; i < cluster_node->loop_ranges.size(); ++i) { + if (cluster_node->loop_ranges[i] < 0 && sym_shape[i].isa()) { + cluster_node->loop_ranges[i] = sym_shape[i].Get(); + } + } + } + } else 
if (cluster_node->group_kind == cinn::hlir::framework::kInjective) { cluster_node->loop_ranges = phi::vectorize(op->result(0) .type() .dyn_cast() .dims()); - + } else if (cluster_node->group_kind == cinn::hlir::framework::kBroadcast) { + const std::vector output_shape = [&] { + auto output_shape = + phi::vectorize(op->result(0) + .type() + .dyn_cast() + .dims()); + pir::ShapeConstraintIRAnalysis& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + + if (shape_analysis.HasShapeOrDataForValue(op->result(0))) { + auto shape_info = + shape_analysis.GetShapeOrDataForValue(op->result(0)).shape(); + cluster_node->loop_rangs_expr = shape_info; + for (size_t i = 0; i < shape_info.size(); ++i) { + if (shape_info[i].isa()) { + output_shape[i] = shape_info[i].Get(); + } + } + } + return output_shape; + }(); + cluster_node->loop_ranges = output_shape; sch_node->type = hlir::framework::pir::ScheduleAlignType::kBroadcast; - sch_node->axis_info = - cinn::dialect::ir::GetVectorAttr(op, "broadcast_axes"); - sch_node->factor_info = cinn::dialect::ir::GetVectorAttr(op, "out_shape"); + sch_node->axis_info = [&] { + int x_rank = op->operand_source(0) + .type() + .dyn_cast() + .dims() + .size(); + int out_rank = + op->result(0).type().dyn_cast().dims().size(); + std::vector broadcast_axes(x_rank, 0); + size_t index_gap = out_rank - x_rank; + for (size_t i = 0; i < x_rank; ++i) { + broadcast_axes[i] = i + index_gap; + } + return broadcast_axes; + }(); + sch_node->factor_info = output_shape; + + pir::ShapeConstraintIRAnalysis& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + if (shape_analysis.HasShapeOrDataForValue(op->result(0))) { + auto sym_shape = + shape_analysis.GetShapeOrDataForValue(op->result(0)).shape(); + for (size_t i = 0; i < cluster_node->loop_ranges.size(); ++i) { + if (cluster_node->loop_ranges[i] < 0 && sym_shape[i].isa()) { + cluster_node->loop_ranges[i] = sym_shape[i].Get(); + } + + if (sch_node->factor_info[i] < 0 && sym_shape[i].isa()) { + sch_node->factor_info[i] = sym_shape[i].Get(); + } + } + } + } else if (op->name() == "cinn_op.generate_shape") { + // do nothing for now } else { PADDLE_THROW(phi::errors::Unimplemented( - "only support elementwise, broadcast, reduce type")); + "only support elementwise, broadcast, injective, reduce type")); } } @@ -553,50 +673,106 @@ std::vector<::pir::Operation*> GetPreOps( bool CanOpMergeNode( const std::unordered_map<::pir::Operation*, GroupClusterNode>& op_path_info, ::pir::Operation* pre_op, - ::pir::Operation* cur_op) { + ::pir::Operation* cur_op, + const std::unordered_set<::pir::Operation*>& all_yield_ops) { const auto& node1 = op_path_info.at(pre_op); const auto& node2 = op_path_info.at(cur_op); + + if (node1.HasYieldOp(all_yield_ops) || + all_yield_ops.find(pre_op) != all_yield_ops.end()) { + return false; + } + // reduce can not fuse with any op in first stage if (cinn::hlir::framework::pir::CompatibleInfo::OpKind(*pre_op) == cinn::hlir::framework::kReduction) { return false; } - // TODO(phlrain): need update here - // different loop range can merge, like [128, 128, 1], with [128, 128] - if ((cinn::hlir::framework::pir::CompatibleInfo::OpKind(*cur_op) != - cinn::hlir::framework::kBroadcast) && - (op_path_info.at(cur_op).loop_ranges != - op_path_info.at(pre_op).loop_ranges)) { - return false; + if (cinn::hlir::framework::pir::CompatibleInfo::OpKind(*pre_op) <= + cinn::hlir::framework::kInjective) { + return true; } - - return true; + return false; } -bool 
ShouldOutputPreNode( - const std::unordered_map<::pir::Operation*, GroupClusterNode>& op_path_info, - ::pir::Operation* pre_op, - ::pir::Operation* cur_op) { - if (cinn::hlir::framework::pir::CompatibleInfo::OpKind(*pre_op) == - cinn::hlir::framework::kReduction) { - return false; +namespace horizontal_merge_detail { +template +std::optional> FindMergePair( + const ConditionFunc& condition_fn, + const std::vector& elements) { + for (int i = 0; i < elements.size(); ++i) { + for (int j = i + 1; j < elements.size(); ++j) { + if (condition_fn(elements[i], elements[j])) { + return std::make_pair(i, j); + } + } } + return std::nullopt; +} - // TODO(phlrain): need update here - // different loop range can merge, like [128, 128, 1], with [128, 128] - if ((cinn::hlir::framework::pir::CompatibleInfo::OpKind(*cur_op) != - cinn::hlir::framework::kBroadcast) && - (op_path_info.at(cur_op).loop_ranges != - op_path_info.at(pre_op).loop_ranges)) { - return true; +template +void MergeAndRemove(const MergeFunc& merge_fn, + const std::pair& range, + std::vector* elements) { + const auto& merged = + merge_fn(elements->at(range.first), elements->at(range.second)); + elements->erase(elements->begin() + range.second); + elements->erase(elements->begin() + range.first); + elements->push_back(merged); +} + +template +void FindPatternAndMerge(const ConditionFunc& condition_fn, + const MergeFunc& merge_fn, + std::vector* elements) { + while (true) { + auto merge_pair = FindMergePair(condition_fn, *elements); + if (merge_pair.has_value()) { + VLOG(4) << "FindPatternAndMerge: find and merge!"; + MergeAndRemove(merge_fn, merge_pair.value(), elements); + } else { + break; + } } +} - return false; +bool SameOutputShape(const GroupClusterNode& a, const GroupClusterNode& b) { + return a.loop_ranges == b.loop_ranges; +} + +bool CanHorizontalMerge(const GroupClusterNode& a, const GroupClusterNode& b) { + const auto& IsTrivialKind = [](OpPatternKind kind) { + return kind == OpPatternKind::kElementWise || + kind == OpPatternKind::kBroadcast || + kind == OpPatternKind::kInjective; + }; + return IsTrivialKind(a.group_kind) && IsTrivialKind(b.group_kind) && + SameOutputShape(a, b); +} + +GroupClusterNode HorizontalMerge(const GroupClusterNode& a, + const GroupClusterNode& b) { + GroupClusterNode res = a; + res.MergeNode(b, ScheduleInfoNode()); + return res; } +std::vector HorizontalMergePass( + const std::vector& last_stage_output) { + VLOG(4) << "Before HorizontalMergePass, cluster size is = " + << last_stage_output.size(); + std::vector third_stage_output = last_stage_output; + FindPatternAndMerge(CanHorizontalMerge, HorizontalMerge, &third_stage_output); + VLOG(4) << "After HorizontalMergePass, cluster size is = " + << third_stage_output.size(); + return third_stage_output; +} +} // namespace horizontal_merge_detail + std::vector NodeMergeWithNode( - const std::vector& first_stage_output) { + const std::vector& first_stage_output, + const std::unordered_set<::pir::Operation*>& all_yield_ops) { // stage 2 merge // for now we merge node in same pass // only for vertical fuse @@ -631,7 +807,7 @@ std::vector NodeMergeWithNode( const auto& pre_node = second_stage_output[pre_id]; ScheduleInfoNode sch_node; - auto can_fuse = CanFuse(pre_node, new_node, &sch_node); + auto can_fuse = CanFuse(pre_node, new_node, &sch_node, all_yield_ops); if (can_fuse) { // merge pre node to new_node @@ -658,6 +834,36 @@ std::vector NodeMergeWithNode( return second_stage_output; } +std::vector NewOpMergeWithOp( + cinn::dialect::GroupOp group_op) { + auto 
cluster_result = frontend::ClusterOps(group_op.GetOperators(), true); + std::vector> result; + std::transform(cluster_result.begin(), + cluster_result.end(), + std::back_inserter(result), + [](const frontend::group_cluster::PatternNodePtr node) { + return node->GetOps(); + }); + + // Each stmts corresponds to each fusion op(cluster node). + // Concat all the ops of patterns in the stmts, and make them the op list of + // cluster node. + VLOG(4) << "Start Creating Cluster Nodes!"; + std::vector output_cluster_nodes; + for (const auto& op_set : result) { + GroupClusterNode cluster_node; + for (const auto* op : op_set) { + cluster_node.ops.push_back(const_cast(op)); + auto op_kind = cinn::hlir::framework::pir::CompatibleInfo::OpKind(*op); + cluster_node.group_kind = + cluster_node.group_kind > op_kind ? cluster_node.group_kind : op_kind; + } + output_cluster_nodes.push_back(cluster_node); + } + VLOG(4) << "Finished Creating Cluster Nodes!"; + return output_cluster_nodes; +} + std::vector OpMergeWithOp(cinn::dialect::GroupOp group_op) { // op merge with op auto inner_values = GetInnerGeneValue(group_op.GetOperators()); @@ -670,11 +876,11 @@ std::vector OpMergeWithOp(cinn::dialect::GroupOp group_op) { std::unordered_set<::pir::Operation*> yield_output_ops; std::unordered_set<::pir::Operation*> first_output_ops; + std::unordered_set<::pir::Operation*> all_yield_ops; auto yield_op = op_list.back(); for (size_t i = 0; i < yield_op->num_operands(); ++i) { - if (yield_op->operand_source(i).defining_op()->result(0).use_count() == 1) { - yield_output_ops.insert(yield_op->operand_source(i).defining_op()); - } + all_yield_ops.insert(yield_op->operand_source(i).defining_op()); + yield_output_ops.insert(yield_op->operand_source(i).defining_op()); } // first stage op fuse op @@ -697,19 +903,9 @@ std::vector OpMergeWithOp(cinn::dialect::GroupOp group_op) { continue; } - if (CanOpMergeNode(op_path, pre_op, op)) { + if (CanOpMergeNode(op_path, pre_op, op, all_yield_ops)) { cluster_node.MergePreNode(op_path.at(pre_op), sch_node); } - - // TODO(phlrain): should remove this strategy - if (ShouldOutputPreNode(op_path, pre_op, op)) { - // Can not merge here, should output pre_op cluster Node - if (!first_output_ops.count(pre_op)) { - first_stage_output.push_back(op_path[pre_op]); - first_output_ops.insert(pre_op); - } - continue; - } } op_list.push_back(op); @@ -717,8 +913,10 @@ std::vector OpMergeWithOp(cinn::dialect::GroupOp group_op) { if (yield_output_ops.count(op) || cinn::hlir::framework::pir::CompatibleInfo::OpKind(*op) == cinn::hlir::framework::kReduction) { - // TODO(phlrain): yiled output no nedd to push into first stage output, + // TODO(phlrain): yield output no need to push into first stage output, // Update here + VLOG(4) << "Split Group by yield output ops: " + << yield_output_ops.count(op); if (!first_output_ops.count(op)) { first_stage_output.push_back(op_path[op]); first_output_ops.insert(op); @@ -726,11 +924,16 @@ std::vector OpMergeWithOp(cinn::dialect::GroupOp group_op) { } } + VLOG(4) << "first stage output size " << first_stage_output.size(); return first_stage_output; } std::vector GroupSplit(cinn::dialect::GroupOp group_op) { // stage 1 + if (FLAGS_cinn_new_cluster_op_method) { + return NewOpMergeWithOp(group_op); + } + auto first_stage_output = OpMergeWithOp(group_op); if (first_stage_output.size() <= 1) { @@ -738,12 +941,22 @@ std::vector GroupSplit(cinn::dialect::GroupOp group_op) { } // stage 2 - auto second_stage_output = NodeMergeWithNode(first_stage_output); - + auto yield_op = 
group_op.GetOperators().back(); + std::unordered_set<::pir::Operation*> all_yield_ops; + for (size_t i = 0; i < yield_op->num_operands(); ++i) { + all_yield_ops.insert(yield_op->operand_source(i).defining_op()); + } + auto second_stage_output = + NodeMergeWithNode(first_stage_output, all_yield_ops); if (second_stage_output.size() == 1) { return second_stage_output; } + // Note: horizontal merge will make loop in graph, skip it + // // stage 3 + // auto third_stage_output = + // horizontal_merge_detail::HorizontalMergePass(second_stage_output); + std::vector> pre_ids_info; auto out_id_list = SortNodeList(&second_stage_output, &pre_ids_info); @@ -820,27 +1033,38 @@ class CinnGroupClusterPattern auto all_output_values = BuildValueOrderByYieldOp(split_res, group_op); for (auto& node : split_res) { + if (node.ops.size() == 0) { + continue; + } auto output_values = GenerateOutputValue(node.ops, all_output_values); + VLOG(4) << "cluster node output size: " << output_values.size(); auto uniq_ops = SortByOriginalOrderAndUniq(group_op, node.ops); auto new_group_op = ReplaceWithGroupOp( &rewriter, uniq_ops, node, output_values, &ir_mapping); + auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get( + group_op->GetParentProgram()); // update ir mapping for (size_t i = 0; i < output_values.size(); ++i) { ir_mapping.Add(output_values[i], new_group_op->result(i)); + if (shape_analysis.HasShapeOrDataForValue(output_values[i])) { + shape_analysis.SetShapeOrDataForValue( + new_group_op->result(i), + shape_analysis.GetShapeOrDataForValue(output_values[i])); + } } - for (size_t i = 0; i < output_values.size(); ++i) { auto find_it = all_output_values.find(output_values[i]); if ((find_it != all_output_values.end()) && (find_it->second < group_op->num_results())) { - // id < num_results means yiled input + // id < num_results means yield input rewriter.ReplaceAllUsesWith(group_op.result(find_it->second), new_group_op->result(i)); } } } + rewriter.EraseOp(group_op); return true; @@ -861,7 +1085,7 @@ class CinnGroupClusterPass : public pir::PatternRewritePass { } bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; + return op->num_regions() > 0; } }; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc index cab96a8bd27f9..2bebdf4c2149f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/dynamic_reshape_pass.cc @@ -28,14 +28,30 @@ bool ReplaceOpWithReshapeOp(pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis, pir::PatternRewriter& rewriter) { // NOLINT pir::Value output = op->result(0); - // The value of shape attribute is fake, we only use the output shape info - // in shape analysis. 
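[Editor's note, not part of the patch] The horizontal_merge_detail helpers added in the cinn_group_cluster_pass.cc hunk above (FindMergePair, MergeAndRemove, FindPatternAndMerge) implement a generic fixed-point loop: find any pair of cluster nodes the condition accepts, merge them, and repeat until no pair matches; the patch itself keeps the horizontal-merge stage commented out in GroupSplit because it can introduce loops in the graph. The standalone sketch below shows the same loop on plain integers; MergeToFixedPoint and the parity example are illustrative only.

#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// Generic "find a mergeable pair, merge it, repeat" loop, mirroring
// horizontal_merge_detail::FindPatternAndMerge in cinn_group_cluster_pass.cc.
template <typename T, typename CanMergeFn, typename MergeFn>
void MergeToFixedPoint(const CanMergeFn& can_merge,
                       const MergeFn& merge,
                       std::vector<T>* elements) {
  while (true) {
    std::optional<std::pair<int, int>> hit;
    for (int i = 0; i < static_cast<int>(elements->size()) && !hit; ++i) {
      for (int j = i + 1; j < static_cast<int>(elements->size()); ++j) {
        if (can_merge((*elements)[i], (*elements)[j])) {
          hit = std::make_pair(i, j);
          break;
        }
      }
    }
    if (!hit) break;
    T merged = merge((*elements)[hit->first], (*elements)[hit->second]);
    elements->erase(elements->begin() + hit->second);  // erase the larger index first
    elements->erase(elements->begin() + hit->first);
    elements->push_back(merged);
  }
}

int main() {
  // Toy use: repeatedly merge two integers of the same parity by adding them.
  std::vector<int> xs = {1, 2, 3, 4, 5};
  MergeToFixedPoint(
      [](int a, int b) { return a % 2 == b % 2; },
      [](int a, int b) { return a + b; },
      &xs);
  for (int x : xs) std::cout << x << " ";  // one odd and one even value remain
  std::cout << "\n";
}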
- std::vector shape( - output.type().dyn_cast().dims().size(), 1); - shape[0] = -1; + // Try to Get more detail output info + const auto& GetOutputShape = [&]() -> std::vector { + std::vector shape = phi::vectorize( + output.type().dyn_cast().dims()); + + if (shape_analysis->HasShapeOrDataForValue(op->result(0))) { + const auto& shape_info = + shape_analysis->GetShapeOrDataForValue(op->result(0)).shape(); + int temp_dim = -1; + + for (size_t i = 0; i < shape_info.size(); ++i) { + if (shape_info[i].isa()) { + shape[i] = shape_info[i].Get(); + } else { + shape[i] = temp_dim; + temp_dim = 1; + } + } + } + return shape; + }; - auto cinn_reshape = - rewriter.Build(op->operand_source(0), shape); + auto cinn_reshape = rewriter.Build( + op->operand_source(0), GetOutputShape()); shape_analysis->SetShapeOrDataForValue( cinn_reshape.result(0), shape_analysis->GetShapeOrDataForValue(output)); @@ -97,43 +113,23 @@ class DynamicUnsqueezeOpPattern } }; -class DynamicReshapeOpPass : public pir::Pass { +class DynamicReshapeOpPass : public pir::PatternRewritePass { public: DynamicReshapeOpPass() - : pir::Pass("cinn_dynamic_reshape_op_pass", /*opt_level=*/1) {} + : pir::PatternRewritePass("cinn_dynamic_reshape_op_pass", 1) {} - bool Initialize(pir::IrContext* context) override { + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { pir::RewritePatternSet ps(context); - ps.Add(context); + // Comment out the DynamicReshapeOpPattern to use pd_op.reshape in + // cinn.group ps.Add(context); ps.Add(context); ps.Add(context); - patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); - return true; - } - - void Run(pir::Operation* op) override { - pir::GreedyRewriteConfig cfg; - cfg.use_top_down_traversal = true; - cfg.max_iterations = 10; - for (uint32_t i = 0; i < op->num_regions(); ++i) { - for (auto& block : op->region(i)) { - for (auto& op : block) { - if (op.isa()) { - auto [_, num_rewrites] = - pir::ApplyPatternsGreedily(&op, patterns_, cfg); - AddStatistics(num_rewrites); - } - } - } - } + return ps; } bool CanApplyOn(pir::Operation* op) const override { - return op->num_regions() > 0; + return op->isa() && op->num_regions() > 0; } - - private: - pir::FrozenRewritePatternSet patterns_; }; std::unique_ptr CreateDynamicReshapeOpPass() { diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index f396e79925a37..11361d34300ef 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -16,15 +16,18 @@ #include #include #include "paddle/cinn/common/bfs_walker.h" +#include "paddle/cinn/common/topo_walker.h" #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/core/builtin_dialect.h" 
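[Editor's note, not part of the patch] The GetOutputShape lambda in the dynamic_reshape_pass.cc hunk above replaces the previous placeholder shape attribute: statically known dimensions keep their concrete values, the first symbolically unknown dimension is encoded as -1, and any further unknown dimensions fall back to 1 (the temp_dim trick), presumably because a reshape attribute may carry at most one -1. The sketch below restates that rule with std::optional standing in for symbol::DimExpr; it is an illustration, not Paddle code.

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// First unknown dimension -> -1, later unknown dimensions -> 1, known
// dimensions keep their values (mirrors the temp_dim logic in GetOutputShape).
std::vector<int64_t> BuildReshapeShapeAttr(
    const std::vector<std::optional<int64_t>>& symbolic_dims) {
  std::vector<int64_t> shape;
  int64_t next_unknown = -1;
  for (const auto& dim : symbolic_dims) {
    if (dim.has_value()) {
      shape.push_back(*dim);
    } else {
      shape.push_back(next_unknown);
      next_unknown = 1;
    }
  }
  return shape;
}

int main() {
  // e.g. symbolic shape [S0, 128, S1] becomes the attribute [-1, 128, 1].
  for (int64_t d : BuildReshapeShapeAttr({std::nullopt, 128, std::nullopt})) {
    std::cout << d << " ";
  }
  std::cout << "\n";
}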
#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" @@ -56,8 +59,8 @@ std::vector FindSourceDenseTensorOfDimTensor( // find input dimension tensor; pir::Operation* owner = value.defining_op(); if (owner == nullptr) return; - for (int i = 0; i < owner->num_operands(); ++i) { - Visit(owner->operand_source(i)); + for (auto input_value : pir::GetUsedExternalValue(*owner)) { + Visit(input_value); } }; const auto& IsDimTensorOrListDimExpr = symbol::Overloaded{ @@ -107,8 +110,12 @@ bool MakeGenerateShapeOpAttribute( std::vector* output_dim_expr_attrs, GenerateShapeOp::SymbolBindings* symbol_bindings) { const auto& shape_or_data_dim_exprs = ShapeOrDataDimExprs4Value(output_shape); - CHECK(shape_or_data_dim_exprs.data().has_value()); - const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value(); + ExprVec data_vec = + paddle::dialect::details::GetExprVecFromData(shape_or_data_dim_exprs); + // CHECK(shape_or_data_dim_exprs.data().has_value()); + CHECK(data_vec.size()); + // const auto& out_dim_exprs = shape_or_data_dim_exprs.data().value(); + const auto& out_dim_exprs = data_vec; return MakeGenerateShapeOpAttribute(ir_context, ShapeOrDataDimExprs4Value, out_dim_exprs, @@ -118,6 +125,145 @@ bool MakeGenerateShapeOpAttribute( symbol_bindings); } +std::unordered_set GetOpSetFromOutputToInputsValue( + const std::vector& input_values, pir::Value output_value) { + std::unordered_set op_set; + const std::unordered_set input_value_set(input_values.begin(), + input_values.end()); + auto VisitNextOp = [&](pir::Operation* node, + const std::function& Visit) { + for (uint32_t i = 0; i < node->num_operands(); ++i) { + pir::Value in_value = node->operand_source(i); + if (!in_value || !in_value.type()) continue; + if (input_value_set.count(in_value)) continue; + if (op_set.count(in_value.defining_op())) continue; + + Visit(in_value.defining_op()); + } + }; + common::BfsWalker walker(VisitNextOp); + walker(output_value.defining_op(), [&](pir::Operation* op) { + if (!op) return; + op_set.insert(op); + }); + return op_set; +} + +std::vector GetSubGraphFromOutputToInputsValue( + const std::vector& input_values, pir::Value output_value) { + const std::unordered_set& op_set = + GetOpSetFromOutputToInputsValue(input_values, output_value); + auto VisitUpstreamOp = + [&](pir::Operation* node, + const std::function& Visit) { + for (uint32_t i = 0; i < node->num_operands(); ++i) { + pir::Value in_value = node->operand_source(i); + if (!in_value || !in_value.type()) continue; + if (in_value.defining_op() == nullptr) continue; + if (op_set.count(in_value.defining_op()) == 0) continue; + Visit(in_value.defining_op()); + } + }; + auto VisitDownstreamOp = + [&](pir::Operation* node, + const std::function& Visit) { + for (uint32_t i = 0; i < node->num_results(); ++i) { + for (auto iter = node->result(i).use_begin(); + iter != node->result(i).use_end(); + ++iter) { + if (op_set.count(iter->owner())) { + Visit(iter->owner()); + } + } + } + }; + common::TopoWalker walker(VisitUpstreamOp, + VisitDownstreamOp); + + const std::vector input_ops = [&] { + const std::unordered_set input_value_set(input_values.begin(), + input_values.end()); + auto IsInputOp = [&](pir::Operation* op) { + for (uint32_t i = 0; i < op->num_operands(); ++i) { + if (input_value_set.count(op->operand_source(i)) == 0) { + return false; + } + } + return true; + }; + std::vector input_ops; + for (auto* op : op_set) { + if (IsInputOp(op)) { + input_ops.push_back(op); + } + } + return 
input_ops; + }(); + std::vector ops; + walker(input_ops.begin(), input_ops.end(), [&](pir::Operation* node) { + if (!node) return; + ops.push_back(node); + }); + return ops; +} + +void InferSymbolicShapeForSubgraph( + const std::vector& ops, + pir::ShapeConstraintIRAnalysis* shape_analysis) { + for (auto* op : ops) { + auto infer_symbolic_shape_interface = + op->dyn_cast(); + if (infer_symbolic_shape_interface) { + infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " DOES NOT have InferSymbolicShapeInterface!")); + } + } +} + +void UpdateLocalShapeAnalysis( + const std::vector& input_tensors, + pir::Value shape, + const std::unordered_map& dim_expr_map, + const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value, + pir::ShapeConstraintIRAnalysis* shape_analysis) { + // init inputs value's dim expr + auto CreateExprsByExprMap = + [&](const std::vector& dim_exprs) { + std::vector new_shape; + new_shape.reserve(dim_exprs.size()); + for (const auto& dim_expr : dim_exprs) { + auto iter = dim_expr_map.find(dim_expr); + if (iter == dim_expr_map.end()) { + new_shape.push_back(dim_expr); + } else { + new_shape.push_back(iter->second); + } + } + return new_shape; + }; + + for (const auto& input_tensor : input_tensors) { + const auto& shape_or_data = ShapeOrDataDimExprs4Value(input_tensor); + std::vector new_shape = + CreateExprsByExprMap(shape_or_data.shape()); + if (shape_or_data.data()) { + std::vector new_data = + CreateExprsByExprMap(shape_or_data.data().value()); + shape_analysis->SetShapeOrDataForValue( + input_tensor, symbol::TensorShapeOrDataDimExprs(new_shape, new_data)); + } else { + shape_analysis->SetShapeOrDataForValue( + input_tensor, symbol::TensorShapeOrDataDimExprs(new_shape)); + } + } + // infer new symbol shape for shape value + std::vector sub_graph_ops = + GetSubGraphFromOutputToInputsValue(input_tensors, shape); + InferSymbolicShapeForSubgraph(sub_graph_ops, shape_analysis); +} + std::optional GetOutOfRewrittenGenerateShapeOp( pir::Value shape, pir::PatternRewriter* rewriter, @@ -125,10 +271,61 @@ std::optional GetOutOfRewrittenGenerateShapeOp( std::vector input_tensors = FindSourceDenseTensorOfDimTensor(shape, ShapeOrDataDimExprs4Value); if (input_tensors.empty()) return std::nullopt; + const std::unordered_map dim_expr_map = + [&] { + std::unordered_map dim_expr_map; + int64_t local_dim_expr_id = 0; + for (auto input_tensor : input_tensors) { + const auto& shape_or_data = ShapeOrDataDimExprs4Value(input_tensor); + for (const auto& dim_expr : shape_or_data.shape()) { + if (!dim_expr.isa() && dim_expr_map.count(dim_expr) == 0) { + dim_expr_map[dim_expr] = + symbol::DimExpr("SS" + std::to_string(local_dim_expr_id++)); + } + } + if (shape_or_data.data()) { + for (const auto& dim_expr : shape_or_data.data().value()) { + if (!dim_expr.isa() && + dim_expr_map.count(dim_expr) == 0) { + dim_expr_map[dim_expr] = + symbol::DimExpr("SS" + std::to_string(local_dim_expr_id++)); + } + } + } + } + return dim_expr_map; + }(); + + const bool has_complex_dim_expr = [&]() { + bool has_complex_dim_expr = false; + for (const auto& kv : dim_expr_map) { + if (!kv.first.isa() && !kv.first.isa()) { + has_complex_dim_expr = true; + break; + } + } + return has_complex_dim_expr; + }(); + pir::ShapeConstraintIRAnalysis shape_analysis; + if (has_complex_dim_expr) { + UpdateLocalShapeAnalysis(input_tensors, + shape, + dim_expr_map, + ShapeOrDataDimExprs4Value, + &shape_analysis); + } + + auto LocalDimExprs4Value = 
[&](pir::Value value) { + if (has_complex_dim_expr) { + return shape_analysis.GetShapeOrDataForValue(value); + } + return ShapeOrDataDimExprs4Value(value); + }; + std::vector output_dim_expr_attrs{}; GenerateShapeOp::SymbolBindings symbol_bindings{}; bool success = MakeGenerateShapeOpAttribute(rewriter->ir_context(), - ShapeOrDataDimExprs4Value, + LocalDimExprs4Value, shape, /*origin inputs*/ input_tensors, /*minimal inputs*/ &input_tensors, @@ -206,7 +403,7 @@ class FuseShapeOpsIntoGenerateShapeOpPass : public pir::PatternRewritePass { } bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; + return op->num_regions() > 0; } }; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/check_infer_symbolic_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/check_infer_symbolic_pass.cc index 3ab2e8c7c7a3d..953e268b27a80 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/check_infer_symbolic_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/check_infer_symbolic_pass.cc @@ -118,7 +118,7 @@ void CompareStaticAndDynamicValueShape( std::vector> dynamic_value_shape = GetDynamicValueShape(value, shape_analysis); if (static_value_shape != dynamic_value_shape) { - VLOG(4) << "CheckInferSymbolic failed, in the fellowing program, the " + VLOG(4) << "CheckInferSymbolic failed, in the following program, the " << op_index << "th op : the shape is not equal\nthe static shape is: " << SprintShape(static_value_shape) << ", and the dynamic shape is: " diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc index 325421d92abe6..588312cc80114 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_0d_to_1d_pass.cc @@ -19,9 +19,11 @@ #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h" namespace cinn { namespace dialect { @@ -35,13 +37,14 @@ class FullOpPattern : public pir::OpRewritePattern { bool Match(paddle::dialect::FullOp op) const override { return op.attribute("shape") - .dyn_cast() - .data() - .size() == 0; + .dyn_cast() + .data() + .size() == 0 && + op.out().type().dyn_cast().dims().size() == 0; } void Rewrite(paddle::dialect::FullOp op, - pir::PatternRewriter &rewriter) const override { + pir::PatternRewriter& rewriter) const override { float factor = op->attribute("value").dyn_cast<::pir::FloatAttribute>().data(); phi::DataType dtype = op->attribute("dtype") @@ -58,20 +61,131 @@ class FullOpPattern : public pir::OpRewritePattern { } }; +class SliceOpPattern : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool Match(paddle::dialect::SliceOp op) const override { + const auto& tensor_type = + op.result(0).type().dyn_cast(); + + return tensor_type.dims().size() == 0; + } + + void Rewrite(paddle::dialect::SliceOp op, + pir::PatternRewriter& rewriter) const override { + std::vector 
vec_dims; + pir::Attribute attr_dims = + pir::ArrayAttribute::get(pir::IrContext::Instance(), vec_dims); + + op->set_attribute("decrease_axis", attr_dims); + } +}; + +class SumOpPattern : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool Match(paddle::dialect::SumOp op) const override { + const auto& tensor_type = + op.result(0).type().dyn_cast(); + return tensor_type.dims().size() == 0; + } + + void Rewrite(paddle::dialect::SumOp op, + pir::PatternRewriter& rewriter) const override { + std::vector axis{}; + const auto& dtype = op->attribute("dtype") + .dyn_cast() + .data(); + auto new_reduce_op = rewriter.Build( + op.operand_source(0), axis, dtype, /*keepdim=*/true); + auto reshape_op = rewriter.Build( + new_reduce_op.result(0), /*shape=*/std::vector({1})); + rewriter.ReplaceAllUsesWith(op.result(0), reshape_op.result(0)); + rewriter.EraseOp(op); + } +}; + +pir::DenseTensorType Make1DTensorType(const pir::DenseTensorType& tensor_type) { + return pir::DenseTensorType::get(pir::IrContext::Instance(), + tensor_type.dtype(), + {1}, + tensor_type.data_layout(), + tensor_type.lod(), + tensor_type.offset()); +} + +void ConvertValue0DTo1D(pir::Value operand) { + auto ConvertVectorType0DTo1D = + [](const pir::VectorType& vector_tensor_type) -> std::vector { + std::vector types; + for (std::size_t i = 0; i < vector_tensor_type.size(); ++i) { + CHECK(vector_tensor_type[i].isa()); + const auto& dense_type = + vector_tensor_type[i].dyn_cast(); + types.push_back(dense_type.dims().size() == 0 + ? Make1DTensorType(dense_type) + : vector_tensor_type[i]); + } + return types; + }; + + if (const auto& tensor_type = + operand.type().dyn_cast()) { + if (tensor_type.dims().size() == 0) { + operand.set_type(Make1DTensorType(tensor_type)); + } + } else if (const auto& vector_tensor_type = + operand.type().dyn_cast()) { + pir::Builder builder(pir::IrContext::Instance()); + std::vector inputs_type = + ConvertVectorType0DTo1D(vector_tensor_type); + operand.set_type(builder.vec_type(inputs_type)); + } else { + VLOG(4) << "Unsupported operand type: " << operand.type(); + } +} + +class WhileOpPattern : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool Match(paddle::dialect::WhileOp op) const override { + for (const auto& value : op.block_args()) { + if (const auto& tensor_type = + value.type().template dyn_cast()) { + if (tensor_type.dims().size() == 0) { + return true; + } + } + } + return false; + } + + void Rewrite(paddle::dialect::WhileOp op, + pir::PatternRewriter& rewriter) const override { + for (pir::Value value : op.block_args()) { + ConvertValue0DTo1D(value); + } + } +}; + class CombineOpPattern : public pir::OpRewritePattern { public: using pir::OpRewritePattern::OpRewritePattern; bool Match(pir::CombineOp op) const override { - auto out_type = op.result(0).type().dyn_cast(); - for (auto type : out_type.data()) { - if (HasZeroDim(type)) return true; + for (std::size_t i = 1; i < op->operands().size(); ++i) { + if (op.operand_source(i).type() != op.operand_source(0).type()) { + return true; + } } return false; } void Rewrite(pir::CombineOp op, - pir::PatternRewriter &rewriter) const override { + pir::PatternRewriter& rewriter) const override { pir::Builder builder(rewriter.ir_context()); const std::vector inputs_type = [&]() { @@ -83,30 +197,68 @@ class CombineOpPattern : public pir::OpRewritePattern { }(); op.result(0).set_type(builder.vec_type(inputs_type)); } - - private: - bool HasZeroDim(pir::Type type) 
const { - if (!type) return false; - const auto dense_tensor_type = type.dyn_cast(); - return dense_tensor_type && (dense_tensor_type.dims().size() == 0U); - } }; -class Convert0DTo1DPass : public pir::PatternRewritePass { +class Convert0DTo1DPass : public pir::Pass { public: - Convert0DTo1DPass() : pir::PatternRewritePass("convert_0D_to_1D", 1) {} + Convert0DTo1DPass() : pir::Pass("convert_0D_to_1D", 1) {} - pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + bool Initialize(pir::IrContext* context) override { pir::RewritePatternSet ps(context); ps.Add(context); ps.Add(context); + ps.Add(context); + ps.Add(context); + ps.Add(context); + patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); + return true; + } - return ps; + void Run(pir::Operation* op) override { + for (uint32_t i = 0; i < op->num_regions(); ++i) { + ApplyPatternOnOperation(op->region(i)); + for (const auto& block : op->region(i)) { + ConvertBlock0DTo1D(block); + } + } + } + + void ApplyPatternOnOperation(pir::Region& region) { // NOLINT + pir::GreedyRewriteConfig cfg; + cfg.use_top_down_traversal = true; + cfg.max_iterations = 10; + const auto& [_, num_rewrites] = + pir::ApplyPatternsGreedily(region, patterns_, cfg); + AddStatistics(num_rewrites); } - bool CanApplyOn(pir::Operation *op) const override { + bool CanApplyOn(pir::Operation* op) const override { return op->isa() && op->num_regions() > 0; } + + void ConvertOperation0DTo1D(const pir::Operation& op) { // NOLINT + for (std::size_t i = 0; i < op.num_operands(); ++i) { + ConvertValue0DTo1D(op.operand_source(i)); + } + for (std::size_t i = 0; i < op.num_results(); ++i) { + ConvertValue0DTo1D(op.result(i)); + } + } + + void ConvertBlock0DTo1D(const pir::Block& block) { + for (auto& op : block) { + ConvertOperation0DTo1D(op); + for (std::size_t i = 0; i < op.num_regions(); ++i) { + ApplyPatternOnOperation(op.region(i)); + for (auto& inner_block : op.region(i)) { + ConvertBlock0DTo1D(inner_block); + } + } + } + } + + private: + pir::FrozenRewritePatternSet patterns_; }; } // namespace diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc index 21c5047c998c9..d1550a2bdf257 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_dynamic_to_static_dim_pass.cc @@ -24,15 +24,14 @@ #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/dialect/shape/utils/dim_expr.h" - -#include "paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" PD_DECLARE_string(cinn_convert_dynamic_dim_to_static_dim); namespace { template -void ForEachRawDyanmicToStaticDimPair(const DoEachT& DoEach) { +void ForEachRawDynamicToStaticDimPair(const DoEachT& DoEach) { const std::string& env_var = FLAGS_cinn_convert_dynamic_dim_to_static_dim; size_t start = 0; while (true) { @@ -43,7 +42,7 @@ void ForEachRawDyanmicToStaticDimPair(const DoEachT& DoEach) { } } -std::optional> ParseRawDyanmicToStaticDimPair( +std::optional> ParseRawDynamicToStaticDimPair( const std::string& raw_pair) { size_t pos = raw_pair.find(":", 0); if (pos == std::string::npos) return std::nullopt; @@ -70,8 +69,8 @@ std::optional> ParseRawDyanmicToStaticDimPair( std::unordered_map 
GetDynamicToStaticDimFlag() { std::unordered_map map; - ForEachRawDyanmicToStaticDimPair([&](const std::string& raw_pair) { - if (auto pair = ParseRawDyanmicToStaticDimPair(raw_pair)) { + ForEachRawDynamicToStaticDimPair([&](const std::string& raw_pair) { + if (auto pair = ParseRawDynamicToStaticDimPair(raw_pair)) { map.insert(pair.value()); } }); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc index dd6c2d2e74905..e67cb5aacabfa 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc @@ -14,13 +14,15 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.h" -#include "paddle/cinn/common/dim_expr_util.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" #include "paddle/cinn/runtime/flags.h" +#include "paddle/common/enforce.h" #include "paddle/common/flags.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" #include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h" PD_DECLARE_string(cinn_convert_static_dim_to_dynamic_dim); @@ -30,7 +32,7 @@ namespace cinn::dialect::ir { namespace { template -void ForEachRawStaticDimToDyanmicPair(const DoEachT& DoEach) { +void ForEachRawStaticDimToDynamicPair(const DoEachT& DoEach) { const std::string& env_var = FLAGS_cinn_convert_static_dim_to_dynamic_dim; size_t start = 0; while (true) { @@ -41,7 +43,7 @@ void ForEachRawStaticDimToDyanmicPair(const DoEachT& DoEach) { } } -std::optional> ParseRawStaticDimToDyanmicPair( +std::optional> ParseRawStaticDimToDynamicPair( const std::string& raw_pair) { size_t pos = raw_pair.find(":", 0); if (pos == std::string::npos) return std::nullopt; @@ -66,10 +68,10 @@ std::optional> ParseRawStaticDimToDyanmicPair( return std::pair{int64_t{constant}, symbol}; } -std::unordered_map GetStaticDimToDyanmicFromFlag() { +std::unordered_map GetStaticDimToDynamicFromFlag() { std::unordered_map map; - ForEachRawStaticDimToDyanmicPair([&](const std::string& raw_pair) { - if (auto pair = ParseRawStaticDimToDyanmicPair(raw_pair)) { + ForEachRawStaticDimToDynamicPair([&](const std::string& raw_pair) { + if (auto pair = ParseRawStaticDimToDynamicPair(raw_pair)) { map.insert(pair.value()); } }); @@ -81,7 +83,7 @@ using GlobalStaticDimToDynamicMapT = std::optional CalcGlobalStaticDimToDynamicMap() { std::unordered_map map = - GetStaticDimToDyanmicFromFlag(); + GetStaticDimToDynamicFromFlag(); if (map.empty()) return std::nullopt; auto DividedByOther = [&](int64_t constant) { for (const auto& [other_constant, _] : map) { @@ -378,7 +380,7 @@ struct StaticDimToDynamicConverter { symbol::TensorShapeOrDataDimExprs(old)}; } } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } template diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc index 886cc29efa5b1..8f64980baf1c8 100644 --- 
a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.cc @@ -124,13 +124,13 @@ class GroupOpPattern : public pir::OpRewritePattern { auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(group_op->GetParentProgram()); // Record map info for yield value to each fusion_op's result - std::unordered_map<::pir::Value, ::pir::Value> fusion_yiled_values; + std::unordered_map<::pir::Value, ::pir::Value> fusion_yield_values; const auto& TryReplaceOperandSource = [&](::pir::Operation* op) { for (auto& operand : op->operands()) { const auto value = operand.source(); - if (fusion_yiled_values.find(value) != fusion_yiled_values.end()) { - operand.set_source(fusion_yiled_values.at(value)); + if (fusion_yield_values.find(value) != fusion_yield_values.end()) { + operand.set_source(fusion_yield_values.at(value)); } } }; @@ -158,9 +158,9 @@ class GroupOpPattern : public pir::OpRewritePattern { auto fusion_op = CreateFusionOp(vec_outs, group); for (size_t i = 0; i < fusion_op.num_results(); ++i) { - CHECK(fusion_yiled_values.insert({vec_outs[i], fusion_op.result(i)}) + CHECK(fusion_yield_values.insert({vec_outs[i], fusion_op.result(i)}) .second) - << "fusion_yiled_values already has key!"; + << "fusion_yield_values already has key!"; const auto& shape_expr = shape_analysis.GetShapeOrDataForValue(vec_outs[i]); shape_analysis.SetShapeOrDataForValue(fusion_op.result(i), shape_expr); @@ -216,5 +216,3 @@ std::unique_ptr<::pir::Pass> CreateDivideGroupOpToFusionOpPass() { } // namespace ir } // namespace dialect } // namespace cinn - -// REGISTER_IR_PASS(cinn_group_lowering, DivideGroupOpToFusionOpPass); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc index 7ee55cc7c9396..79b8a70d28acc 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_pass.cc @@ -30,6 +30,7 @@ #include "paddle/common/flags.h" #include "paddle/cinn/common/is_reachable_predicator.h" +#include "paddle/common/enforce.h" PD_DECLARE_bool(enhance_vertical_fusion_with_recompute); @@ -431,7 +432,7 @@ template struct HorizontalFuseUtil { using KindKeyT = std::pair; - static bool DetectFusabilityByKind(FusePassCtxT* ctx, + static bool DetectFusibilityByKind(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { const KindKeyT kind_pair(src.kind(), dst.kind()); @@ -590,7 +591,7 @@ class DefaultInputFusePass final : public InputFusePass { bool fusionable = false; for (auto& groups : fusionable_consumers) { auto& last = groups.back(); - if (!HorizontalFuseUtil::DetectFusabilityByKind( + if (!HorizontalFuseUtil::DetectFusibilityByKind( ctx, candidate, last)) { continue; } @@ -681,7 +682,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { bool fusionable = false; for (auto& groups : fusionable_consumers) { auto& last = groups.back(); - if (!HorizontalFuseUtil::DetectFusabilityByKind( + if (!HorizontalFuseUtil::DetectFusibilityByKind( ctx, candidate, last)) { continue; } @@ -752,7 +753,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { std::vector candidates; for (size_t i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); - if (!DetectFusabilityByKind(ctx, producer, 
consumer)) { + if (!DetectFusibilityByKind(ctx, producer, consumer)) { break; } candidates.push_back(consumer); @@ -764,7 +765,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { for (size_t i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); - if (!DetectFusabilityByKind(ctx, producer, consumer)) { + if (!DetectFusibilityByKind(ctx, producer, consumer)) { continue; } if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { @@ -776,7 +777,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { } using KindKeyT = std::pair; - bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, + bool DetectFusibilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { const KindKeyT kind_pair(src.kind(), dst.kind()); @@ -941,7 +942,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { std::vector candidates; for (size_t i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); - if (!DetectFusabilityByKind(ctx, producer, consumer)) { + if (!DetectFusibilityByKind(ctx, producer, consumer)) { continue; } unsafe_candidates.push_back(consumer); @@ -960,7 +961,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { } using KindKeyT = std::pair; - bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, + bool DetectFusibilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) const { const KindKeyT kind_pair(src.kind(), dst.kind()); @@ -1139,11 +1140,11 @@ class GeneralFusionMergePassHelper { while (DoGeneralRecomputeAndVerticalFusion()) { } - DoPrologueGenerateShapeOpGroupFustion(); + DoPrologueGenerateShapeOpGroupFusion(); } - void DoPrologueGenerateShapeOpGroupFustion() { - VLOG(3) << "DoPrologueGenerateShapeOpGroupFustion...!"; + void DoPrologueGenerateShapeOpGroupFusion() { + VLOG(3) << "DoPrologueGenerateShapeOpGroupFusion...!"; bool updated = false; for (size_t idx = 0; idx < fusion_groups_.size(); ++idx) { auto producer = fusion_groups_[idx]; @@ -1296,7 +1297,7 @@ class GeneralFusionMergePassHelper { } } if (is_ring) { - LOG(FATAL) << "Exists Ring, Please Check!"; + PADDLE_THROW(phi::errors::Fatal("Exists Ring, Please Check!")); } } } @@ -1328,7 +1329,7 @@ class GeneralFusionMergePassHelper { bool GeneralHorizontalFuse(const GroupPtr& producer) { VLOG(3) << "GeneralHorizontalFuse handling producer : " << producer->group_id; - const auto& GetFusableConsumerGroupLists = + const auto& GetFusibleConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; const auto& MarkFusible = [&](const OpGroupList& candidates) { @@ -1339,8 +1340,8 @@ class GeneralFusionMergePassHelper { EnableFusedHorizontalGroups(&fuse_ctx); return tagged_lists; }; - const auto& GetFusableConsumerGroupList = [&]() -> std::vector { - const auto& group_lists = GetFusableConsumerGroupLists(); + const auto& GetFusibleConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = GetFusibleConsumerGroupLists(); if (group_lists.empty()) { return std::vector{}; } @@ -1355,7 +1356,7 @@ class GeneralFusionMergePassHelper { return ret; }; - const auto& group_lists = GetFusableConsumerGroupList(); + const auto& group_lists = GetFusibleConsumerGroupList(); if (group_lists.empty()) { return false; } @@ -1387,7 +1388,7 @@ class GeneralFusionMergePassHelper { bool CallGeneralInputFusePass( const std::unordered_set& consumers) { VLOG(3) << "CallGeneralInputFusePass...!"; - const auto& GetFusableConsumerGroupLists = + const auto& GetFusibleConsumerGroupLists = 
[&]() -> std::vector { std::vector tagged_lists; const auto& MarkFusible = [&](const OpGroupList& candidates) { @@ -1402,8 +1403,8 @@ class GeneralFusionMergePassHelper { EnableFusedInputGroups(&fuse_ctx); return tagged_lists; }; - const auto& GetFusableConsumerGroupList = [&]() -> std::vector { - const auto& group_lists = GetFusableConsumerGroupLists(); + const auto& GetFusibleConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = GetFusibleConsumerGroupLists(); if (group_lists.empty()) { return std::vector{}; } @@ -1418,7 +1419,7 @@ class GeneralFusionMergePassHelper { return ret; }; - const auto& group_lists = GetFusableConsumerGroupList(); + const auto& group_lists = GetFusibleConsumerGroupList(); if (group_lists.empty()) { return false; } @@ -1613,7 +1614,7 @@ class GeneralFusionMergePassHelper { bool GeneralVerticalFuse(const GroupPtr& producer) { VLOG(3) << "GeneralVerticalFuse...!"; using GroupSets = std::vector>; - const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { + const auto& GetFusibleConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; const auto& MarkFusible = [&](const OpGroupPtr& first, const OpGroupPtr& second) { @@ -1625,9 +1626,9 @@ class GeneralFusionMergePassHelper { return tagged_sets; }; - auto GetFusableConsumerGroupSet = + auto GetFusibleConsumerGroupSet = [&]() -> std::unordered_set { - const auto& group_sets = GetFusableConsumerOpGroupSets(); + const auto& group_sets = GetFusibleConsumerOpGroupSets(); if (group_sets.empty()) { return {}; } @@ -1639,7 +1640,7 @@ class GeneralFusionMergePassHelper { }; bool update = false; - auto consumer_groups = GetFusableConsumerGroupSet(); + auto consumer_groups = GetFusibleConsumerGroupSet(); if (consumer_groups.size()) { SelectConsumerToFuse(producer, &consumer_groups); } @@ -1868,7 +1869,7 @@ class GeneralFusionMergePassHelper { VLOG(3) << "GeneralRecomputeFuse handling producer : " << producer->group_id; using GroupSets = std::set>; - const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { + const auto& GetFusibleConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; const auto& MarkFusible = [&](const OpGroupPtr& first, const OpGroupPtr& second) { @@ -1880,9 +1881,9 @@ class GeneralFusionMergePassHelper { return tagged_sets; }; - auto GetFusableConsumerGroupSet = + auto GetFusibleConsumerGroupSet = [&]() -> std::unordered_set { - const auto& group_sets = GetFusableConsumerOpGroupSets(); + const auto& group_sets = GetFusibleConsumerOpGroupSets(); if (group_sets.empty()) { return {}; } @@ -1894,7 +1895,7 @@ class GeneralFusionMergePassHelper { }; bool update = false; - auto consumer_groups = GetFusableConsumerGroupSet(); + auto consumer_groups = GetFusibleConsumerGroupSet(); if (consumer_groups.size() > 0) { CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) << "Recompute requires fuse all consumers!"; @@ -2220,7 +2221,7 @@ class GeneralFusionMergePassHelper { GroupList GeneralFusionMergePassInternal(const GroupList& group_list) { if (group_list.size() <= 1) { - VLOG(3) << "Don't do Fusoin Merge Pass...!"; + VLOG(3) << "Don't do Fusion Merge Pass...!"; return group_list; } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h index f6c17ae28ebfb..f04ee9212f9f3 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h +++ 
b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h @@ -146,7 +146,7 @@ inline bool horizontal_elementwise_fuse_reduce( auto ele_node_shape = GetValueShape((*ele_group->master_ops.begin())->result(0)); int32_t size_ele = ::common::product(ele_node_shape); - // TODO(phlrain): seems extrame danger herem, why compare multi Master Node? + // TODO(phlrain): seems extreme danger here, why compare multi Master Node? for (auto* master : reduce_group->master_ops) { auto master_node_shape = GetValueShape(master->result(0)); int32_t size_master = ::common::product(master_node_shape); @@ -349,7 +349,7 @@ inline bool horizontal_relation(const std::shared_ptr& first, }; auto selected_nodes = select_node_set(second_set, op_pattern_kind); - auto check_depency = [&](::pir::Operation* node) { + auto check_dependency = [&](::pir::Operation* node) { std::queue<::pir::Operation*> candidates; std::unordered_set<::pir::Operation*> visited_set; candidates.push(node); @@ -381,7 +381,7 @@ inline bool horizontal_relation(const std::shared_ptr& first, }; for (auto node : selected_nodes) { - if (check_depency(node)) { + if (check_dependency(node)) { return false; } } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc index b2dfea14d4d67..f395a1fb3e28b 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.cc @@ -67,22 +67,32 @@ class GroupOpGenerateShapeOpsPattern } }; -class MoveGenerateShapeOpsToProloguePass : public pir::PatternRewritePass { +class MoveGenerateShapeOpsToProloguePass : public pir::Pass { public: MoveGenerateShapeOpsToProloguePass() - : pir::PatternRewritePass("move_generate_shape_ops_to_prologue", 1) {} + : pir::Pass("move_generate_shape_ops_to_prologue", /*opt_level=*/1) {} - pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { - pir::RewritePatternSet ps(context); - ps.Add(context); - return ps; + void Run(pir::Operation* op) override { + auto group_op = op->dyn_cast(); + CHECK(group_op); + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + pir::ShapeConstraintIRAnalysis& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(group_op->GetParentProgram()); + ShapeOrDataDimExprsAccessor dim_exprs_accessor{ + .GetShapeOrDataDimExprs = + [&](pir::Value value) -> const symbol::ShapeOrDataDimExprs& { + return shape_analysis.GetShapeOrDataForValue(value); + }, + .SetShapeOrDataDimExprs = + [&](pir::Value value, + const symbol::ShapeOrDataDimExprs& dim_exprs) { + shape_analysis.SetShapeOrDataForValue(value, dim_exprs); + }}; + MoveGenerateShapeOpsToPrologue(ctx, group_op.block(), dim_exprs_accessor); } bool CanApplyOn(pir::Operation* op) const override { - if (!(op->isa() && op->num_regions() > 0)) return false; - auto* program = op->GetParentProgram(); - VLOG(4) << "Before MoveGenerateShapeOpsToProloguePass: " << *program; - return true; + return op->isa() && op->num_regions() > 0; } }; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h index 41dd5c9089c71..4fbe41385ec62 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h +++ 
b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h @@ -181,7 +181,7 @@ inline bool reduce_fuse_reduce( inline bool is_horizontal_relation(::pir::Operation* producer, const std::shared_ptr& consumer) { - auto check_depency = [&](::pir::Operation* op) { + auto check_dependency = [&](::pir::Operation* op) { std::queue<::pir::Operation*> candidates; std::unordered_set<::pir::Operation*> visited_set; candidates.push(op); @@ -192,7 +192,7 @@ inline bool is_horizontal_relation(::pir::Operation* producer, // visit all producer op for (size_t i = 0; i < candidate->num_operands(); ++i) { auto tmp_op = candidate->operand_source(i).defining_op(); - // check depency. + // check dependency. if (producer == tmp_op) { return true; } @@ -216,7 +216,7 @@ inline bool is_horizontal_relation(::pir::Operation* producer, consumer->op_pattern_kind) { continue; } - if (check_depency(op)) { + if (check_dependency(op)) { return false; } } @@ -246,6 +246,11 @@ inline bool horizontal_or_vertical_reduce_relation( // check producer has same shape with reducer op. auto reduce_shape = ::common::vectorize(GetFirstInputShape(reducer)); auto reduce_axes = GetVectorAttr(reducer, "dim"); + if (reduce_axes.empty()) { + for (size_t i = 0; i < reduce_shape.size(); ++i) { + reduce_axes.push_back(i); + } + } for (auto& axis : reduce_axes) { // if axis = -1, set as shape.size() - 1 @@ -271,22 +276,22 @@ inline bool horizontal_or_vertical_reduce_relation( return false; } - int succesive_reduce_dimension = reduce_shape.at(reduce_axes.back()); + int successive_reduce_dimension = reduce_shape.at(reduce_axes.back()); for (int idx = reduce_axes.size() - 2; idx >= 0; --idx) { if (reduce_axes[idx] == reduce_axes[idx + 1] - 1) { - succesive_reduce_dimension *= reduce_shape[reduce_axes[idx]]; + successive_reduce_dimension *= reduce_shape[reduce_axes[idx]]; continue; } break; } // helper->target_ == cinn::common::DefaultNVGPUTarget() - // succesive_reduce_dimension <= helper->target_.max_num_threads() + // successive_reduce_dimension <= helper->target_.max_num_threads() // TODO(phlrain): support is_gpu_target and max_thread bool is_gpu_target = true; int max_thread = 32 * 1024; return is_gpu_target - ? (succesive_reduce_dimension <= max_thread ? true : false) + ? (successive_reduce_dimension <= max_thread ? 
true : false) : true; } diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc index e8d8355872cd2..5d3baeb21f92a 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.cc @@ -19,7 +19,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" namespace cinn { namespace dialect { @@ -28,11 +28,14 @@ namespace ir { namespace { template -void VisitEachOp(pir::ModuleOp module_op, const DoEachT& DoEach) { - for (uint32_t i = 0; i < module_op->num_regions(); i++) { - for (pir::Block& block : module_op->region(i)) { - for (pir::Operation& op : block) { - DoEach(op); +void VisitEachOp(pir::Operation* op, const DoEachT& DoEach) { + for (uint32_t i = 0; i < op->num_regions(); i++) { + for (pir::Block& block : op->region(i)) { + for (pir::Operation& sub_op : block) { + DoEach(sub_op); + if (sub_op.num_regions() > 0) { + VisitEachOp(&sub_op, DoEach); + } } } } @@ -90,24 +93,36 @@ symbol::ShapeOrDataDimExprs SimplifyShapeOrData( return std::visit(lambdas, shape_or_data.variant()); } -void SimplifyDimExpr(pir::ModuleOp module_op) { +void SimplifyDimExpr(pir::Operation* module_op) { VLOG(4) << "SimplifyDimExpr start"; - pir::ShapeConstraintIRAnalysis shape_analysis = - pir::ShapeAnalysisManager::Instance().Get(module_op.program()); + pir::ShapeConstraintIRAnalysis* shape_analysis = + &pir::ShapeAnalysisManager::Instance().Get( + module_op->dyn_cast().program()); + VisitEachOp(module_op, [&](pir::Operation& op) { VisitEachValue(op, [&](pir::Value value) { - if (!shape_analysis.HasShapeOrDataForValue(value)) { + if (!shape_analysis->HasShapeOrDataForValue(value)) { VLOG(4) << "SimplifyDimExpr: shape_analysis can't find ShapeOrData for " "value of the op:" << op.name(); } else { const symbol::ShapeOrDataDimExprs& shape_or_data = - shape_analysis.GetShapeOrDataForValue(value); + shape_analysis->GetShapeOrDataForValue(value); + VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data; symbol::ShapeOrDataDimExprs simplified_shape_or_data = SimplifyShapeOrData(shape_or_data); - shape_analysis.SetShapeOrDataForValue(value, simplified_shape_or_data); + VLOG(8) << op.name() + << " simplified_shape_or_data: " << simplified_shape_or_data; + shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data); } }); + if (op.num_results() > 0) { + pir::shape::SetShapeAttrForOp( + &op, shape_analysis->GetShapeOrDataForValue(op.result(0))); + } else { + pir::shape::SetShapeAttrForOp( + &op, shape_analysis->GetShapeOrDataForValue(op.operand_source(0))); + } // TODO(JiaWenxuan): simplify the attribute "sym_shape_str" of the op }); VLOG(4) << "SimplifyDimExpr end"; @@ -117,10 +132,7 @@ class SimplifyDimExprPass : public pir::Pass { public: SimplifyDimExprPass() : pir::Pass("simplify_dim_expr_pass", 1) {} - void Run(pir::Operation* op) override { - pir::ModuleOp module_op = op->dyn_cast(); - SimplifyDimExpr(module_op); - } + void Run(pir::Operation* op) override { SimplifyDimExpr(op); } bool CanApplyOn(pir::Operation* op) const override { return op->isa() && op->num_regions() > 0; diff 
--git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc new file mode 100644 index 0000000000000..f859c09400c16 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.cc @@ -0,0 +1,168 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.h" + +#include "build/paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "build/paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace cinn { +namespace dialect { +namespace ir { + +namespace { + +class FusionOpPattern : public pir::OpRewritePattern { + public: + explicit FusionOpPattern(::pir::IrContext* context) + : pir::OpRewritePattern(context) {} + + bool MatchAndRewrite(cinn::dialect::FusionOp fusion_op, + pir::PatternRewriter& rewriter) const override { + // Fallback only when FusionOp has two operators inside: AnySingleOp + + // cf.yield + if (fusion_op.GetOperators().size() > 2) { + return false; + } + PADDLE_ENFORCE_EQ( + fusion_op.GetOperators().size(), + 2, + phi::errors::InvalidArgument( + "fusion_op should have two operators inside, but got %d", + fusion_op.GetOperators().size())); + PADDLE_ENFORCE( + fusion_op.GetOperators()[1]->isa<::pir::YieldOp>(), + phi::errors::InvalidArgument( + "The last operator of fusion_op must be YieldOp, but got %s", + fusion_op.GetOperators()[1]->name())); + + auto* program = fusion_op->GetParentProgram(); + auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get( + fusion_op->GetParentProgram()); + std::optional paddle_op = + FallBackOp(fusion_op.GetOperators()[0], rewriter); + if (!paddle_op.has_value()) { + return false; + } + + for (size_t i = 0; i < fusion_op.num_results(); ++i) { + rewriter.ReplaceAllUsesWith(fusion_op.result(i), + paddle_op.value()->result(i)); + if (shape_analysis.HasShapeOrDataForValue(fusion_op.result(i))) { + shape_analysis.SetShapeOrDataForValue( + paddle_op.value()->result(i), + shape_analysis.GetShapeOrDataForValue(fusion_op.result(i))); + } else { + LOG(WARNING) << "No shape_data for " + << fusion_op.result(i).defining_op()->name() << "_result_" + << i << ", this may cause error in dynamic shape"; + } + } + + rewriter.EraseOp(fusion_op); + return true; + } + + private: + typedef pir::Operation* (FusionOpPattern::*CinnOpHandler)( + pir::Operation*, pir::PatternRewriter&) const; + + pir::Operation* ReshapeOpPattern( + pir::Operation* op, + pir::PatternRewriter& rewriter) const { // NOLINT + PADDLE_ENFORCE(op->isa(), + 
phi::errors::InvalidArgument( + "Input should be cinn::dialect::ReshapeOp, but got %s", + op->name())); + auto reshape_op = op->dyn_cast(); + + const std::vector vec_out_shape = [&]() { + auto out_shape_attr = reshape_op.attribute("shape") + .dyn_cast() + .AsVector(); + PADDLE_ENFORCE_GT(out_shape_attr.size(), + 0, + phi::errors::InvalidArgument( + "The shape attribute should not be empty")); + + std::vector ret; + std::transform( + out_shape_attr.begin(), + out_shape_attr.end(), + std::back_inserter(ret), + [](const auto& attr) { + return attr.template dyn_cast<::pir::Int32Attribute>().data(); + }); + return ret; + }(); + + auto paddle_reshape = rewriter.Build( + reshape_op->operand_source(0), vec_out_shape); + return paddle_reshape; + } + + const std::unordered_map& op_handler_map() const { + static std::unordered_map handler_map = { + {cinn::dialect::ReshapeOp::name(), &FusionOpPattern::ReshapeOpPattern}, + }; + return handler_map; + } + + std::optional FallBackOp( + pir::Operation* op, + pir::PatternRewriter& rewriter) const { // NOLINT + auto it = op_handler_map().find(op->name()); + if (it == op_handler_map().end()) { + LOG(WARNING) << "No fallback handler for op: " << op->name(); + return std::nullopt; + } + return (this->*(it->second))(op, rewriter); + } +}; + +class SingleOpFallbackToPhiPass : public pir::PatternRewritePass { + public: + SingleOpFallbackToPhiPass() + : pir::PatternRewritePass("single_op_fallback_to_phi", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + context->GetOrRegisterDialect(); + context->GetOrRegisterDialect(); + context->GetOrRegisterDialect(); + + pir::RewritePatternSet ps(context); + ps.Add(context); + + return ps; + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->num_regions() > 0; + } +}; + +} // namespace + +std::unique_ptr<::pir::Pass> CreateSingleOpFallbackToPhiPass() { + return std::make_unique(); +} + +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.h b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.h new file mode 100644 index 0000000000000..9b35400dc245f --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
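Reviewer note on FusionOpPattern::FallBackOp in the new single_op_fallback_to_phi pass above: the fallback path dispatches on the op name through a static map of pointer-to-member handlers and returns std::nullopt when no handler is registered, so MatchAndRewrite simply declines to match and the group stays on the regular CINN lowering path. The sketch below reproduces that dispatch structure in isolation; the dispatcher class, handler body, and op-name strings are hypothetical stand-ins chosen for this example, not Paddle API.

// Minimal sketch of a name-keyed handler map built from pointer-to-member
// functions, mirroring the CinnOpHandler/op_handler_map shape in the diff.
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>

class FallbackDispatcher {
 public:
  // Look up a handler by op name; unknown ops yield std::nullopt.
  std::optional<std::string> FallBack(const std::string& op_name) const {
    const auto& handlers = HandlerMap();
    auto it = handlers.find(op_name);
    if (it == handlers.end()) {
      std::cerr << "No fallback handler for op: " << op_name << "\n";
      return std::nullopt;
    }
    return (this->*(it->second))();  // invoke the member-function pointer
  }

 private:
  // Pointer-to-member-function type, analogous to the CinnOpHandler typedef.
  using Handler = std::string (FallbackDispatcher::*)() const;

  std::string HandleReshape() const { return "pd_op.reshape"; }

  static const std::unordered_map<std::string, Handler>& HandlerMap() {
    static const std::unordered_map<std::string, Handler> map = {
        {"cinn_op.reshape", &FallbackDispatcher::HandleReshape},
    };
    return map;
  }
};

int main() {
  FallbackDispatcher d;
  std::cout << d.FallBack("cinn_op.reshape").value_or("<none>") << "\n";
  std::cout << d.FallBack("cinn_op.unknown").value_or("<none>") << "\n";
  return 0;
}

Returning std::nullopt instead of throwing keeps unsupported single-op fusion groups untouched, which matches the pattern's early "return false" behavior in the diff.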
+ +#pragma once + +#include +#include "paddle/pir/include/pass/pass.h" + +namespace cinn { +namespace dialect { +namespace ir { +std::unique_ptr<::pir::Pass> CreateSingleOpFallbackToPhiPass(); +} // namespace ir +} // namespace dialect +} // namespace cinn diff --git a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/substitute_dim_expr_based_on_constraints_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/substitute_dim_expr_based_on_constraints_pass.cc index 68372afa3e9ca..97570459eebc1 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/group_merge/substitute_dim_expr_based_on_constraints_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/group_merge/substitute_dim_expr_based_on_constraints_pass.cc @@ -16,8 +16,10 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/substitute_dim_expr_based_on_constraints_pass.h" -#include "paddle/cinn/common/dim_expr_util.h" #include "paddle/cinn/common/union_find.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" namespace cinn { namespace dialect { @@ -26,23 +28,24 @@ namespace ir { namespace { template -void VisitEachOp(pir::ModuleOp module_op, const DoEachT& DoEach) { - for (uint32_t i = 0; i < module_op->num_regions(); i++) { - for (pir::Block& block : module_op->region(i)) { - for (pir::Operation& op : block) { - DoEach(op); +void VisitEachOp(pir::Operation* op, const DoEachT& DoEach) { + DoEach(op); + for (auto& region : *op) { + for (auto& block : region) { + for (auto& op_in_block : block) { + DoEach(&op_in_block); } } } } template -void VisitEachValue(const pir::Operation& op, const DoEachT& DoEach) { - for (std::size_t i = 0; i < op.num_operands(); ++i) { - DoEach(op.operand_source(i)); +void VisitEachValue(const pir::Operation* op, const DoEachT& DoEach) { + for (std::size_t i = 0; i < op->num_operands(); ++i) { + DoEach(op->operand_source(i)); } - for (std::size_t i = 0; i < op.num_results(); ++i) { - DoEach(op.result(i)); + for (std::size_t i = 0; i < op->num_results(); ++i) { + DoEach(op->result(i)); } } @@ -56,8 +59,9 @@ symbol::TensorShapeOrDataDimExprs SubstituteTensorShapeOrData( substitution_pattern) -> std::vector { std::vector substituted_dim_expr{}; for (const symbol::DimExpr& dim_expr : original_dim_expr) { - substituted_dim_expr.push_back( - cinn::common::SubstituteDimExpr(dim_expr, substitution_pattern)); + const auto& tmp_dim_expr = + symbol::SubstituteDimExpr(dim_expr, substitution_pattern); + substituted_dim_expr.push_back(symbol::SimplifyDimExpr(tmp_dim_expr)); } return substituted_dim_expr; }; @@ -95,10 +99,26 @@ symbol::ShapeOrDataDimExprs SubstituteShapeOrData( return std::visit(lambdas, shape_or_data.variant()); } +int GetDimExprPriority(const symbol::DimExpr& dim_expr) { + return std::visit( + symbol::Overloaded{ + [&](std::int64_t) { return 0; }, + [&](const std::string&) { return 1; }, + [&](const symbol::Negative&) { return 2; }, + [&](const symbol::Reciprocal&) { return 2; }, + [&](const symbol::Add&) { return 2; }, + [&](const symbol::Mul&) { return 2; }, + [&](const symbol::Max&) { return 2; }, + [&](const symbol::Min&) { return 2; }, + [&](const symbol::Broadcast&) { return 2; }, + }, + dim_expr.variant()); +} + std::unordered_map GetDimExprSubstitution( pir::ShapeConstraintIRAnalysis* shape_analysis) { const std::vector& dim_expr_constraints = - shape_analysis->CreateDimExprBuilder().constraints(); + 
shape_analysis->DimExprBuilder().constraints(); const cinn::common::UnionFindSet& union_find_set = [&]() { cinn::common::UnionFindSet union_find_set; for (const auto& constraint : dim_expr_constraints) { @@ -119,9 +139,8 @@ std::unordered_map GetDimExprSubstitution( CHECK(!dim_expr_cluster.empty()); auto dim_expr_root = dim_expr_cluster[0]; for (const auto& dim_expr : dim_expr_cluster) { - if (std::holds_alternative(dim_expr)) { + if (GetDimExprPriority(dim_expr) < GetDimExprPriority(dim_expr_root)) { dim_expr_root = dim_expr; - break; } } for (const auto& dim_expr : dim_expr_cluster) { @@ -133,26 +152,41 @@ std::unordered_map GetDimExprSubstitution( return substitution_pattern; } -void SubstituteDimExprBasedOnConstraints(pir::ModuleOp module_op) { +void SubstituteDimExprBasedOnConstraints(pir::Operation* region_op) { VLOG(4) << "SubstituteDimExprBasedOnConstraints start"; - pir::ShapeConstraintIRAnalysis shape_analysis = - pir::ShapeAnalysisManager::Instance().Get(module_op.program()); + pir::ShapeConstraintIRAnalysis* shape_analysis = + &pir::ShapeAnalysisManager::Instance().Get(region_op->GetParentProgram()); const std::unordered_map& - substitution_pattern = GetDimExprSubstitution(&shape_analysis); - VisitEachOp(module_op, [&](pir::Operation& op) { + substitution_pattern = GetDimExprSubstitution(shape_analysis); + + VisitEachOp(region_op, [&](pir::Operation* op) { VisitEachValue(op, [&](pir::Value value) { - if (!shape_analysis.HasShapeOrDataForValue(value)) { - VLOG(4) << "Can not find ShapeOrData for value of op(" << op.name() + if (!shape_analysis->HasShapeOrDataForValue(value)) { + VLOG(4) << "Can not find ShapeOrData for value of op(" << op->name() << ") in shape_analysis"; } else { const symbol::ShapeOrDataDimExprs& origin_shape_or_data = - shape_analysis.GetShapeOrDataForValue(value); + shape_analysis->GetShapeOrDataForValue(value); + VLOG(8) << op->name() + << " origin_shape_or_data: " << origin_shape_or_data; const symbol::ShapeOrDataDimExprs& substituted_shape_or_data = SubstituteShapeOrData(origin_shape_or_data, substitution_pattern); - shape_analysis.SetShapeOrDataForValue(value, substituted_shape_or_data); + VLOG(8) << op->name() + << " substituted_shape_or_data: " << substituted_shape_or_data; + shape_analysis->SetShapeOrDataForValue(value, + substituted_shape_or_data); } }); - // TODO(JiaWenxuan): substitute the attribute "sym_shape_str" of the op + if (op->num_regions() > 0) { + return; + } + if (op->num_results() > 0) { + pir::shape::SetShapeAttrForOp( + op, shape_analysis->GetShapeOrDataForValue(op->result(0))); + } else { + pir::shape::SetShapeAttrForOp( + op, shape_analysis->GetShapeOrDataForValue(op->operand_source(0))); + } }); VLOG(4) << "SubstituteDimExprBasedOnConstraints end"; } @@ -163,12 +197,11 @@ class SubstituteDimExprBasedOnConstraintsPass : public pir::Pass { : pir::Pass("substitute_dim_expr_based_on_constraints_pass", 1) {} void Run(pir::Operation* op) override { - pir::ModuleOp module_op = op->dyn_cast(); - SubstituteDimExprBasedOnConstraints(module_op); + SubstituteDimExprBasedOnConstraints(op); } bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; + return op->num_regions() > 0; } }; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc index f7eea680a3b61..6ef8dd56edebc 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc +++ 
b/paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.cc @@ -15,6 +15,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h" #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/common/ddim.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" @@ -35,11 +36,19 @@ namespace { pir::Value GetOutputDimTensor(pir::PatternRewriter* rewriter, pir::Value x, - pir::Value y) { - pir::Value x_shape = rewriter->Build(x).out(); - pir::Value y_shape = rewriter->Build(y).out(); - return rewriter->Build(x_shape, y_shape) - .out(); + pir::Value y, + pir::ShapeConstraintIRAnalysis* shape_analysis) { + pir::Operation* x_shape_op = rewriter->Build(x); + pir::Operation* y_shape_op = rewriter->Build(y); + pir::Operation* shape_broadcast_op = + rewriter->Build(x_shape_op->result(0), + y_shape_op->result(0)); + for (auto* op : std::vector{x_shape_op, y_shape_op, shape_broadcast_op}) { + auto infer_symbolic_shape_interface = + op->dyn_cast(); + infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis); + } + return shape_broadcast_op->result(0); } bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { @@ -51,12 +60,14 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { const auto& y_shape = shape_analysis.GetShapeOrDataForValue(y); const auto& out_shape = shape_analysis.GetShapeOrDataForValue(op->result(0)); - bool has_insert_broadcast = false; + if (x_shape == y_shape) { + return false; + } - pir::Value output_dim_tensor = GetOutputDimTensor(rewriter, x, y); + pir::Value output_dim_tensor = + GetOutputDimTensor(rewriter, x, y, &shape_analysis); if (x_shape.shape() != out_shape.shape() || x_shape.data() != out_shape.data()) { - has_insert_broadcast = true; pir::Value broadcasted_x = rewriter->Build(x, output_dim_tensor).out(); op->operand(0).set_source(broadcasted_x); @@ -64,13 +75,12 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) { } if (y_shape.shape() != out_shape.shape() || y_shape.data() != out_shape.data()) { - has_insert_broadcast = true; pir::Value broadcasted_y = rewriter->Build(y, output_dim_tensor).out(); op->operand(1).set_source(broadcasted_y); shape_analysis.SetShapeOrDataForValue(broadcasted_y, out_shape); } - return has_insert_broadcast; + return true; } } // namespace @@ -111,7 +121,13 @@ class InsertBroadcastPass : public pir::PatternRewritePass { ps.Add>(context); ps.Add>(context); + // logical ops + ps.Add>(context); + ps.Add>(context); + ps.Add>(context); + // bitwise ops + ps.Add>(context); ps.Add>(context); ps.Add>(context); ps.Add>(context); @@ -120,7 +136,7 @@ class InsertBroadcastPass : public pir::PatternRewritePass { } bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; + return op->isa() && op->num_regions() > 0; } }; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc similarity index 56% rename from paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.cc rename to paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc index a2393a09fae21..7068221d77fe5 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.cc @@ -12,47 
+12,34 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - -#include "paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.h" - -#include - -#include "paddle/cinn/adt/generate_map_expr.h" -#include "paddle/cinn/common/broadcast_tree.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h" #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" -#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" -#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" -#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" -#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" -#include "paddle/cinn/hlir/framework/pir/group.h" -#include "paddle/cinn/hlir/framework/pir/utils.h" -#include "paddle/cinn/hlir/framework/pir_compiler.h" -#include "paddle/cinn/runtime/flags.h" -#include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" -#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/pir/include/core/program.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" -#include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h" -#include "paddle/pir/include/pass/pass_registry.h" #include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h" -PD_DECLARE_bool(cinn_enable_map_expr); +using OpLoweringGroup = cinn::hlir::framework::pir::OpLoweringGroup; +using OpLoweringGroupPtr = std::shared_ptr; +using cinn::dialect::ir::details::CompileGroupAsOpAttribute; +using cinn::dialect::ir::details::GetBlockOutsideInput; namespace { - -using Group = cinn::hlir::framework::pir::Group; -using GroupPtr = std::shared_ptr; -using cinn::hlir::framework::pir::CompatibleInfo; +std::vector GetOpOuputValues(const pir::Operation* op) { + std::vector outputs; + outputs.reserve(op->num_results()); + for (size_t i = 0; i < op->num_results(); ++i) { + outputs.push_back(op->result(i)); + } + return outputs; +} using ShapeOrDataDimExprs4ValueT = std::function; -bool SameInputOutputShape( +static bool SameInputOutputShape( paddle::dialect::ExpandOp expand_op, const ShapeOrDataDimExprs4ValueT& ShapeOrDataDimExprs4Value) { const auto& x = ShapeOrDataDimExprs4Value(expand_op.x()); @@ -65,6 +52,76 @@ bool SameInputOutputShape( return x.shape() == out.shape(); } +void CompileGroupToJitKernelOp( + pir::PatternRewriter& rewriter, // NOLINT + std::unordered_map* group_map) { + // prepare attribute for jit_kernel_op + std::vector group_list; + group_list.reserve(group_map->size()); + for (const auto& [_, group] : *group_map) { + group_list.push_back(group); + } + auto op_attr_map = CompileGroupAsOpAttribute(group_list); + VLOG(4) << "The size of group_map is : " << group_map->size(); + for (auto& [block, group] : *group_map) { + std::vector output_types; + const auto& group_output_values = group->output_values(); + for (size_t i = 0; i < group_output_values.size(); ++i) { + output_types.push_back(group_output_values[i].type()); + } + auto& yield_op = block->back(); + CHECK(yield_op.isa()) << "Last op of block should be yield"; + rewriter.set_insertion_point(&yield_op); + const auto& group_inputs = GetBlockOutsideInput(group->ops()); + auto jit_kernel_op = rewriter.Build( + 
group_inputs, op_attr_map.at(group), output_types); + CHECK(jit_kernel_op.num_results() == group_output_values.size()); + for (size_t i = 0; i < jit_kernel_op.num_results(); ++i) { + rewriter.ReplaceAllUsesWith(group_output_values[i], + jit_kernel_op.result(i)); + } + + // Delete origin group ops + std::vector group_ops; + for (auto iter = block->rbegin(); iter != block->rend(); iter++) { + if (!iter->isa()) { + group_ops.push_back(&(*iter)); + } + } + for (auto* op : group_ops) { + if (op->use_empty()) { + op->Erase(); + } + } + } +} + +void UpdateGroupShapeExprs( + const OpLoweringGroupPtr& new_group, + const OpLoweringGroupPtr& origin_group, + const pir::IrMapping& ir_mapping, + const cinn::common::BroadcastLeaf& value_dim_exprs_list, + const std::unordered_map& value_to_dim_expr_idx) { + for (const auto& [origin_val, new_val] : ir_mapping.GetMap()) { + const auto& shape_dim_expr = + value_dim_exprs_list->at(value_to_dim_expr_idx.at(origin_val)); + const auto& origin_shape_or_data = + origin_group->GetShapeOrDataExprs(origin_val); + if (origin_shape_or_data.data()) { + new_group->SetShapeOrDataExprs( + new_val, + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs( + std::vector{shape_dim_expr.size()}, + shape_dim_expr)}); + } else { + new_group->SetShapeOrDataExprs( + new_val, + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(shape_dim_expr)}); + } + } +} + // Returns true if success bool EraseOneExpand( pir::Block* block, @@ -99,7 +156,7 @@ void EraseUnnecessaryExpandsInBlock( void ReplaceExpandWithBroadcast(pir::IrContext* ir_context, pir::Block* block, - const GroupPtr& group) { + const OpLoweringGroupPtr& group) { std::vector op_list; for (auto& op : *block) { op_list.push_back(&op); @@ -140,29 +197,6 @@ void ReplaceExpandWithBroadcast(pir::IrContext* ir_context, } } -std::vector GetBlockOutsideInput( - const std::vector& op_list) { - std::vector vec_res; - std::unordered_set<::pir::Value> block_inner_output; - for (size_t k = 0; k < op_list.size(); ++k) { - for (size_t i = 0; i < op_list[k]->num_results(); ++i) { - block_inner_output.insert(op_list[k]->result(i)); - } - } - - std::unordered_set<::pir::Value> insert_value; - for (size_t k = 0; k < op_list.size(); ++k) { - for (size_t i = 0; i < op_list[k]->num_operands(); ++i) { - if (!block_inner_output.count(op_list[k]->operand_source(i)) && - !insert_value.count(op_list[k]->operand_source(i))) { - vec_res.push_back(op_list[k]->operand_source(i)); - insert_value.insert(op_list[k]->operand_source(i)); - } - } - } - return vec_res; -} - std::tuple BroadcastableToCondValue( const symbol::Broadcastable& broadcastable_condition, pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT @@ -226,53 +260,27 @@ std::tuple BroadcastableToCondValue( lhs_eq_rhs_cond, lhs_eq_one_cond, rhs_eq_one_cond); } -GroupPtr CloneGroup(const GroupPtr& group, - pir::Block* block, - pir::IrMapping* ir_mapping) { - return group->Clone(block, *ir_mapping); -} - -void UpdateGroupShapeExprs( - const GroupPtr& new_group, - const GroupPtr& origin_group, - const pir::IrMapping& ir_mapping, - const cinn::common::BroadcastLeaf& value_dim_exprs_list, - const std::unordered_map& value_to_dim_expr_idx) { - for (const auto& [origin_val, new_val] : ir_mapping.GetMap()) { - const auto& shape_dim_expr = - value_dim_exprs_list->at(value_to_dim_expr_idx.at(origin_val)); - const auto& origin_shape_or_data = - origin_group->GetShapeOrDataExprs(origin_val); - if (origin_shape_or_data.data()) { - new_group->SetShapeOrDataExprs( - new_val, - 
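BroadcastableToCondValue, shown as unchanged context above, reduces a symbolic constraint Broadcastable(lhs, rhs) to three runtime booleans: lhs == rhs, lhs == 1, rhs == 1. A minimal sketch of that case analysis on concrete extents follows; the names are hypothetical, and the real function builds pir boolean Values rather than evaluating eagerly.

// Standalone sketch, not the pir API.
#include <cassert>
#include <cstdint>

struct BroadcastConds {
  bool lhs_eq_rhs;
  bool lhs_eq_one;
  bool rhs_eq_one;
};

BroadcastConds MakeBroadcastConds(std::int64_t lhs, std::int64_t rhs) {
  // The three flags correspond to lhs_eq_rhs_cond, lhs_eq_one_cond and
  // rhs_eq_one_cond returned by BroadcastableToCondValue.
  return BroadcastConds{lhs == rhs, lhs == 1, rhs == 1};
}

int main() {
  BroadcastConds c = MakeBroadcastConds(1, 8);
  assert(!c.lhs_eq_rhs && c.lhs_eq_one && !c.rhs_eq_one);
  // In the pass these conditions decide which leaf of the broadcast tree
  // (and therefore which cloned group) executes at runtime.
  return 0;
}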
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs( - std::vector{shape_dim_expr.size()}, - shape_dim_expr)}); - } else { - new_group->SetShapeOrDataExprs( - new_val, - symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(shape_dim_expr)}); - } - } +OpLoweringGroupPtr CloneGroup(const OpLoweringGroupPtr& group, + pir::Block* block, + pir::IrMapping* ir_mapping) { + return group->Clone(block, ir_mapping); } void SetLeafBlockByGroupView( - const GroupPtr& origin_group, + const OpLoweringGroupPtr& origin_group, const cinn::common::BroadcastLeaf& value_dim_exprs_list, const std::unordered_map& value_to_dim_expr_idx, pir::Builder& builder, // NOLINT pir::Block* block, - std::unordered_map* group_map) { + std::unordered_map* group_map) { pir::IrMapping ir_mapping; - auto origin_group_inputs = GetBlockOutsideInput(origin_group->ops); + auto origin_group_inputs = GetBlockOutsideInput(origin_group->ops()); for (auto input : origin_group_inputs) { ir_mapping.Add(input, input); } auto new_group = CloneGroup(origin_group, block, &ir_mapping); - CHECK_EQ(origin_group->ops.size(), new_group->ops.size()); + CHECK_EQ(origin_group->ops().size(), new_group->ops().size()); UpdateGroupShapeExprs(new_group, origin_group, ir_mapping, @@ -290,15 +298,6 @@ void SetLeafBlockByGroupView( group_map->insert({block, new_group}); } -std::vector GetOpOuputValues(const pir::Operation* op) { - std::vector outputs; - outputs.reserve(op->num_results()); - for (size_t i = 0; i < op->num_results(); ++i) { - outputs.push_back(op->result(i)); - } - return outputs; -} - void InsertYieldOpForCondBlock(pir::Operation* cond_op, pir::Builder& builder) { // NOLINT if (cond_op) { @@ -310,14 +309,14 @@ void InsertYieldOpForCondBlock(pir::Operation* cond_op, // Visit broadcast_tree by dfs pir::Operation* CreateConditionBlock( const cinn::common::BroadcastTree& broadcast_tree, - const GroupPtr& origin_group, + const OpLoweringGroupPtr& origin_group, pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT const std::unordered_map& value_to_dim_expr_idx, const std::vector& group_inputs, const std::vector& output_types, pir::Builder& builder, // NOLINT pir::Block* block, - std::unordered_map* group_map) { + std::unordered_map* group_map) { if (broadcast_tree.Has()) { const auto& broadcast_leaf = broadcast_tree.Get(); @@ -392,45 +391,23 @@ pir::Operation* CreateConditionBlock( } } -std::unordered_map> -CompileGroupAsOpAttribute( - const std::shared_ptr& pir_compiler, - const std::vector& group_list) { - auto fn_ptr_res = pir_compiler->BuildCUDAJITInfo(group_list); - - std::unordered_map> - result; - for (size_t i = 0; i < group_list.size(); ++i) { - std::unordered_map op_attrs{ - {cinn::dialect::JitKernelOp::kAttrName, - cinn::dialect::CINNKernelInfoAttribute::get(pir::IrContext::Instance(), - fn_ptr_res[i])}, - }; - result.insert({group_list[i], op_attrs}); - } - return result; -} - void SimplyConditionBlock( pir::PatternRewriter& rewriter, // NOLINT - std::unordered_map* group_map) { + std::unordered_map* group_map) { VLOG(4) << "simply condition block"; using DoEachMutBlockGroupT = - std::function; + std::function; const auto& ForEachMutBlockGroup = [&](const DoEachMutBlockGroupT& DoEach) { for (auto& [block, group] : *group_map) { DoEach(block, group); std::vector group_new_ops; group_new_ops.reserve(block->size()); - std::unordered_set group_ops_set; for (auto& op : *block) { if (!op.isa()) { group_new_ops.push_back(&op); - group_ops_set.insert(&op); } } - group->ops = group_new_ops; - group->ops_set = 
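CreateConditionBlock visits the broadcast tree depth-first: a leaf clones the group into the current block (SetLeafBlockByGroupView above), while an inner node becomes a conditional whose branches recurse into the subtrees. The toy recursion below is only meant to show the shape of that traversal, not the pir::IfOp construction itself; ToyTree and Lower are hypothetical.

// Illustrative sketch only.
#include <iostream>
#include <memory>
#include <string>

struct ToyTree {
  // Leaf: the name of the specialised (cloned) group to run.
  std::string leaf_group;
  // Branch: condition text plus two subtrees (null for leaves).
  std::string cond;
  std::unique_ptr<ToyTree> then_tree, else_tree;
};

std::string Lower(const ToyTree& t, int indent = 0) {
  std::string pad(indent, ' ');
  if (!t.then_tree) return pad + "run " + t.leaf_group + ";\n";
  return pad + "if (" + t.cond + ") {\n" + Lower(*t.then_tree, indent + 2) +
         pad + "} else {\n" + Lower(*t.else_tree, indent + 2) + pad + "}\n";
}

int main() {
  ToyTree leaf_a{"group_lhs_eq_rhs"};
  ToyTree leaf_b{"group_rhs_eq_one"};
  ToyTree root;
  root.cond = "S0 == S1";
  root.then_tree = std::make_unique<ToyTree>(std::move(leaf_a));
  root.else_tree = std::make_unique<ToyTree>(std::move(leaf_b));
  std::cout << Lower(root);  // prints nested if/else over the two groups
}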
group_ops_set; + group->SetOps(group_new_ops); } }; ForEachMutBlockGroup([&](auto* block, const auto& group) { @@ -440,68 +417,72 @@ void SimplyConditionBlock( }; EraseUnnecessaryExpandsInBlock(block, rewriter, GetShapeOrDataForValue); }); - ForEachMutBlockGroup([&](auto* block, const auto& group) { - ReplaceExpandWithBroadcast(rewriter.ir_context(), block, group); - }); } +} // namespace -void CompileGroupToJitKernelOp( - const std::vector& group_inputs, - const std::shared_ptr& pir_compiler, - pir::PatternRewriter& rewriter, // NOLINT - std::unordered_map* group_map) { - // prepare attribute for jit_kernel_op - std::vector group_list; - group_list.reserve(group_map->size()); - for (const auto& [_, group] : *group_map) { - group_list.push_back(group); - } - auto op_attr_map = CompileGroupAsOpAttribute(pir_compiler, group_list); - VLOG(4) << "The size of group_map is : " << group_map->size(); - for (auto& [block, group] : *group_map) { - std::vector output_types; - const auto& group_output_values = group->output_values; - for (size_t i = 0; i < group_output_values.size(); ++i) { - output_types.push_back(group_output_values[i].type()); +namespace cinn::dialect::ir::details { + +std::shared_ptr ConstructBroadcastTree( + const cinn::common::BroadcastLeaf& leaves) { + VLOG(6) << "before constructed. broadcast-leaf: \n" + << ToTxtString(cinn::common::BroadcastTree(leaves)); + auto broadcast_tree = std::make_shared( + cinn::common::ConstructBroadcastTree( + cinn::common::BroadcastLeaf(leaves))); + VLOG(4) << "broadcast-tree: \n" << ToTxtString(*broadcast_tree); + return broadcast_tree; +} + +GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group) { + std::unordered_set value_view; + group->WalkOps([&group, &value_view](pir::Operation* op) { + for (size_t i = 0; i < op->num_operands(); ++i) { + value_view.insert(op->operand_source(i)); } - auto& yield_op = block->back(); - CHECK(yield_op.isa()) << "Last op of block should be yield"; - rewriter.set_insertion_point(&yield_op); - auto jit_kernel_op = rewriter.Build( - group_inputs, op_attr_map.at(group), output_types); - CHECK(jit_kernel_op.num_results() == group_output_values.size()); - for (size_t i = 0; i < jit_kernel_op.num_results(); ++i) { - rewriter.ReplaceAllUsesWith(group_output_values[i], - jit_kernel_op.result(i)); + for (size_t i = 0; i < op->num_results(); ++i) { + value_view.insert(op->result(i)); } + }); - // Delete origin group ops - std::vector group_ops; - for (auto iter = block->rbegin(); iter != block->rend(); iter++) { - if (!iter->isa()) { - group_ops.push_back(&(*iter)); - } - } - for (auto* op : group_ops) { - if (op->use_empty()) { - op->Erase(); - } + GroupDimExprInfo group_dim_expr_info; + for (auto value : value_view) { + const auto& shape_dim_expr = group->GetShapeOrDataExprs(value); + const auto& data_shape = shape_dim_expr.data(); + if (data_shape) { + group_dim_expr_info.all_value_dim_exprs->push_back(*data_shape); + } else { + group_dim_expr_info.all_value_dim_exprs->push_back( + shape_dim_expr.shape()); } + group_dim_expr_info.value_to_dim_expr_idx[value] = + group_dim_expr_info.all_value_dim_exprs->size() - 1; } + return group_dim_expr_info; +} + +bool NeedBroadcastWithCF(const OpLoweringGroupPtr& group) { + GroupDimExprInfo group_dim_expr_info = GetGroupDimExprInfo(group); + const auto& leaves = group_dim_expr_info.all_value_dim_exprs; + return NeedBroadcastWithCF(leaves); +} + +bool NeedBroadcastWithCF(const cinn::common::BroadcastLeaf& leaves) { + std::optional> + broadcastable_condition = 
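GetGroupDimExprInfo above gathers, for every value the group touches, either its data (when present) or its shape into one flat leaf list and remembers each value's index; NeedBroadcastWithCF then only asks whether any broadcastable constraint exists among those leaves. A rough sketch of that bookkeeping with hypothetical toy types:

// Illustrative sketch only; the real code uses pir::Value and symbol::DimExpr.
#include <string>
#include <unordered_map>
#include <vector>

using ToyValue = int;  // stand-in for pir::Value
using ToyDimExprs = std::vector<std::string>;

struct ToyGroupDimExprInfo {
  std::vector<ToyDimExprs> all_value_dim_exprs;
  std::unordered_map<ToyValue, size_t> value_to_dim_expr_idx;
};

ToyGroupDimExprInfo Collect(
    const std::unordered_map<ToyValue, ToyDimExprs>& shapes) {
  ToyGroupDimExprInfo info;
  for (const auto& [value, dims] : shapes) {
    info.all_value_dim_exprs.push_back(dims);
    info.value_to_dim_expr_idx[value] = info.all_value_dim_exprs.size() - 1;
  }
  // NeedBroadcastWithCF would now scan info.all_value_dim_exprs for the
  // first broadcastable constraint between any two leaves.
  return info;
}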
cinn::common::GetFirstCstrBroadcastable(leaves); + return broadcastable_condition.has_value(); } pir::Operation* CompileBroadcastTreeToConditionBlock( - const cinn::common::BroadcastTree& broadcast_tree, - const GroupPtr& group, + const OpLoweringGroupPtr& group, + const BroadcastTree& broadcast_tree, pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT - const std::shared_ptr& pir_compiler, const std::unordered_map& value_to_dim_expr_idx, const std::vector& group_inputs, const std::vector& output_types, pir::PatternRewriter& rewriter) { // NOLINT // 1. broadcast tree to condition op VLOG(4) << "broadcast tree to condition op"; - std::unordered_map group_map; + std::unordered_map group_map; pir::Operation* cond_op = CreateConditionBlock(broadcast_tree, group, shape_analysis, @@ -512,286 +493,16 @@ pir::Operation* CompileBroadcastTreeToConditionBlock( rewriter.block(), &group_map); // 2. simply every condition block - auto* program = group->ops.front()->GetParentProgram(); + auto* program = group->ops().front()->GetParentProgram(); VLOG(6) << "Before simply condition block: " << *program; SimplyConditionBlock(rewriter, &group_map); VLOG(6) << "After simply condition block: " << *program; // 3. compile condition block to jit_kernel_op - CompileGroupToJitKernelOp(group_inputs, pir_compiler, rewriter, &group_map); + CompileGroupToJitKernelOp(rewriter, &group_map); VLOG(6) << "compile condition block to jit_kernel_op: " << *program; return cond_op; } - -pir::Operation* ProcessDyShapeGroup( - const GroupPtr& group, - pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT - const std::shared_ptr& pir_compiler, - pir::PatternRewriter& rewriter) { // NOLINT - std::unordered_set value_view; - group->WalkOps([&group, &value_view](pir::Operation* op) { - for (size_t i = 0; i < op->num_operands(); ++i) { - value_view.insert(op->operand_source(i)); - } - for (size_t i = 0; i < op->num_results(); ++i) { - value_view.insert(op->result(i)); - } - }); - - // construct broadcast tree - VLOG(4) << "construct broadcast tree"; - cinn::adt::List> all_value_dim_exprs; - std::unordered_map value_to_dim_expr_idx; - for (auto value : value_view) { - const auto& shape_dim_expr = group->GetShapeOrDataExprs(value); - const auto& data_shape = shape_dim_expr.data(); - if (data_shape) { - all_value_dim_exprs->push_back(*data_shape); - } else { - all_value_dim_exprs->push_back(shape_dim_expr.shape()); - } - value_to_dim_expr_idx[value] = all_value_dim_exprs->size() - 1; - } - VLOG(6) << "before constructed. 
broadcast-leaf: \n" - << ToTxtString(cinn::common::BroadcastTree(all_value_dim_exprs)); - cinn::common::BroadcastTree broadcast_tree = - cinn::common::ConstructBroadcastTree( - cinn::common::BroadcastLeaf(all_value_dim_exprs)); - VLOG(4) << "broadcast-tree: \n" << ToTxtString(broadcast_tree); - - auto group_inputs = GetBlockOutsideInput(group->ops); - - // has multiple branch - if (broadcast_tree - .Has>()) { - std::vector output_types; - auto group_output_values = group->GetGroupOutputValues(); - for (size_t i = 0; i < group_output_values.size(); ++i) { - output_types.push_back(group_output_values[i].type()); - } - return CompileBroadcastTreeToConditionBlock(broadcast_tree, - group, - shape_analysis, - pir_compiler, - value_to_dim_expr_idx, - group_inputs, - output_types, - rewriter); - } else { // no condition block - // compile group to jit_kernel_op - auto op_attr_map = CompileGroupAsOpAttribute(pir_compiler, {group}); - std::vector output_types; - const auto& group_output_values = group->output_values; - for (size_t i = 0; i < group_output_values.size(); ++i) { - output_types.push_back(group_output_values[i].type()); - } - auto jit_kernel_op = rewriter.Build( - group_inputs, op_attr_map.at(group), output_types); - return jit_kernel_op; - } -} - -std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> -CreateGroupShapeOrDataExprs( - const GroupPtr& group, - pir::ShapeConstraintIRAnalysis& shape_analysis) { // NOLINT - std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> value2shape; - for (auto* op : group->ops) { - for (size_t i = 0; i < op->num_operands(); ++i) { - auto operand = op->operand_source(i); - if (operand && value2shape.find(operand) == value2shape.end() && - shape_analysis.HasShapeOrDataForValue(operand)) { - value2shape.insert( - {operand, shape_analysis.GetShapeOrDataForValue(operand)}); - } - } - for (size_t i = 0; i < op->num_results(); ++i) { - auto result = op->result(i); - if (result && value2shape.find(result) == value2shape.end() && - shape_analysis.HasShapeOrDataForValue(result)) { - value2shape.insert( - {result, shape_analysis.GetShapeOrDataForValue(result)}); - } - } - } - return value2shape; -} - -class FusionOpPattern : public pir::OpRewritePattern { - public: - explicit FusionOpPattern(::pir::IrContext* context) - : pir::OpRewritePattern(context) {} - - bool MatchAndRewrite(cinn::dialect::FusionOp fusion_op, - pir::PatternRewriter& rewriter) const override { - ::pir::IrContext* ctx = ::pir::IrContext::Instance(); - auto target = cinn::common::DefaultNVGPUTarget(); - // TODO(Aurelius84): Remove scope after cleaning PirCompiler useless Build - // Interface - auto scope = std::make_shared(); - auto* program = fusion_op->GetParentProgram(); - auto ir_compiler = cinn::hlir::framework::PirCompilerManager::Create( - *program, target, scope); - auto group = RebuildGroup(fusion_op); - // Because the group is rebuilt, the order of group.output_values generated - // by BuildCUDAJITInfo may not be same with the order bound in the yield op, - // so a mapping is required. 
- - auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get( - fusion_op->GetParentProgram()); - group->set_value_to_shape_or_data_exprs( - CreateGroupShapeOrDataExprs(group, shape_analysis)); - if (FLAGS_cinn_enable_map_expr) { - cinn::adt::TryGenerateMapExprFromGroup(group); - } - - // TODO(zhangyuqin1998): Replace pir::Group with a new structure - pir::Operation* compiled_op = - ProcessGroup(group, shape_analysis, ir_compiler, rewriter); - - for (size_t i = 0; i < fusion_op.num_results(); ++i) { - rewriter.ReplaceAllUsesWith(fusion_op.result(i), compiled_op->result(i)); - if (shape_analysis.HasShapeOrDataForValue(fusion_op.result(i))) { - shape_analysis.SetShapeOrDataForValue( - compiled_op->result(i), - shape_analysis.GetShapeOrDataForValue(fusion_op.result(i))); - } else { - LOG(WARNING) << "No shape_data for " - << fusion_op.result(i).defining_op()->name() << "_result_" - << i; - } - } - - rewriter.EraseOp(fusion_op); - return true; - } - - protected: - virtual pir::Operation* ProcessGroup( - const GroupPtr& group, - pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT - const std::shared_ptr& pir_compiler, - pir::PatternRewriter& rewriter) const { // NOLINT - auto group_inputs = GetBlockOutsideInput(group->ops); - // compile group to jit_kernel_op - auto op_attr_map = CompileGroupAsOpAttribute(pir_compiler, {group}); - std::vector output_types; - const auto& group_output_values = group->output_values; - for (size_t i = 0; i < group_output_values.size(); ++i) { - output_types.push_back(group_output_values[i].type()); - } - auto jit_kernel_op = rewriter.Build( - group_inputs, op_attr_map.at(group), output_types); - return jit_kernel_op; - } - - private: - std::shared_ptr RebuildGroup(cinn::dialect::FusionOp fusion_op) const { - auto group = std::make_shared(); - group->op_pattern_kind = cinn::hlir::framework::OpPatternKind::kElementWise; - - // Rebuild ops of the group - for (auto op : fusion_op.GetOperators()) { - if (!op->isa<::pir::YieldOp>()) { - group->ops.push_back(op); - group->ops_set.insert(op); - group->op_pattern_kind = - static_cast(CompatibleInfo::OpKind(*op)) > - static_cast(group->op_pattern_kind) - ? CompatibleInfo::OpKind(*op) - : group->op_pattern_kind; - } - } - - // Rebuild output_ops and input_ops of the group - auto yield_op = fusion_op.GetOperators().back(); - for (size_t i = 0; i < yield_op->num_operands(); ++i) { - auto in = yield_op->operand_source(i); - group->output_values.push_back(in); - - group->output_ops.insert(in.defining_op()); - } - - // Rebuild other informations - // TODO(zhangyuqin1998): Do we need group.master_ops? 
- return group; - } -}; - -class DyShapeFusionOpPattern : public FusionOpPattern { - public: - using FusionOpPattern::FusionOpPattern; - - protected: - virtual pir::Operation* ProcessGroup( - const GroupPtr& group, - pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT - const std::shared_ptr& pir_compiler, - pir::PatternRewriter& rewriter) const { // NOLINT - return ProcessDyShapeGroup(group, shape_analysis, pir_compiler, rewriter); - } -}; - -class LowerCinnFusionOpPass : public pir::PatternRewritePass { - public: - LowerCinnFusionOpPass() - : pir::PatternRewritePass("lower_cinn_fusion_op", 1) {} - - pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { - context->GetOrRegisterDialect(); - context->GetOrRegisterDialect(); - context->GetOrRegisterDialect(); - - pir::RewritePatternSet ps(context); - ps.Add(context); - - return ps; - } - - bool CanApplyOn(pir::Operation* op) const override { - return op->num_regions() > 0; - } -}; - -class LowerCinnDyShapeFusionOpPass : public pir::PatternRewritePass { - public: - LowerCinnDyShapeFusionOpPass() - : pir::PatternRewritePass("lower_cinn_dynamic_shape_fusion_op", 1) {} - - pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { - context->GetOrRegisterDialect(); - context->GetOrRegisterDialect(); - context->GetOrRegisterDialect(); - - pir::RewritePatternSet ps(context); - ps.Add(context); - - return ps; - } - - bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; - } -}; - -} // namespace - -namespace cinn { -namespace dialect { -namespace ir { - -std::unique_ptr<::pir::Pass> CreateLowerCinnFusionOpPass() { - return std::make_unique(); -} - -std::unique_ptr<::pir::Pass> CreateLowerCinnDyShapeFusionOpPass() { - return std::make_unique(); -} - -} // namespace ir -} // namespace dialect -} // namespace cinn - -// REGISTER_IR_PASS(cinn_group_lowering, LowerCinnFusionOpPass); +} // namespace cinn::dialect::ir::details diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h new file mode 100644 index 0000000000000..0ef058de08ef5 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h @@ -0,0 +1,46 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/cinn/common/broadcast_tree.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h" +#include "paddle/pir/include/pattern_rewrite/pattern_match.h" + +namespace cinn::dialect::ir::details { +using cinn::common::BroadcastTree; + +class BroadcastTreeInfo; + +struct GroupDimExprInfo { + common::BroadcastLeaf all_value_dim_exprs; + std::unordered_map value_to_dim_expr_idx; +}; + +std::shared_ptr ConstructBroadcastTree( + const common::BroadcastLeaf& leaves); + +bool NeedBroadcastWithCF(const OpLoweringGroupPtr& group); +bool NeedBroadcastWithCF(const common::BroadcastLeaf& leaves); +GroupDimExprInfo GetGroupDimExprInfo(const OpLoweringGroupPtr& group); + +pir::Operation* CompileBroadcastTreeToConditionBlock( + const OpLoweringGroupPtr& group, + const BroadcastTree& broadcast_tree, + pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT + const std::unordered_map& value_to_dim_expr_idx, + const std::vector& group_inputs, + const std::vector& output_types, + pir::PatternRewriter& rewriter // NOLINT +); +} // namespace cinn::dialect::ir::details diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc new file mode 100644 index 0000000000000..4ef8a486f21e0 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.cc @@ -0,0 +1,276 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" + +namespace { +using cinn::dialect::ir::details::GetBlockOutsideInput; +using cinn::dialect::ir::details::OpLoweringGroup; +using cinn::dialect::ir::details::OpLoweringGroupPtr; + +bool IsComplicatedDimExpr(const symbol::DimExpr& dim_expr) { + auto lambdas = symbol::Overloaded{ + [](std::int64_t dim_expr) { return false; }, + [](const std::string& dim_expr) { return false; }, + [](const symbol::Negative& dim_expr) { return true; }, + [](const symbol::Reciprocal& dim_expr) { return true; }, + [](const symbol::Add& dim_expr) { return true; }, + [](const symbol::Mul& dim_expr) { return true; }, + [](const symbol::Max& dim_expr) { return true; }, + [](const symbol::Min& dim_expr) { return true; }, + [](const symbol::Broadcast& dim_expr) { return true; }}; + return std::visit(lambdas, dim_expr.variant()); +} + +template +void VisitEachInputValue(const OpLoweringGroupPtr& group, + const DoEachT& DoEach) { + for (pir::Value value : GetBlockOutsideInput(group->ops())) { + DoEach(value); + } +} + +template +void VisitEachDimExprFromTensorShapeOrData( + const symbol::TensorShapeOrDataDimExprs& shape_or_data, + const DoEachT& DoEach) { + for (const auto& dim_expr : shape_or_data.shape()) { + DoEach(dim_expr); + } + if (!shape_or_data.data().has_value()) { + return; + } + for (const auto& dim_expr : shape_or_data.data().value()) { + DoEach(dim_expr); + } +} + +template +void VisitEachDimExpr(const symbol::ShapeOrDataDimExprs& shape_or_data, + const DoEachT& DoEach) { + auto lambdas = symbol::Overloaded{ + [&](const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data) { + VisitEachDimExprFromTensorShapeOrData(tensor_shape_or_data, DoEach); + }, + [&](const symbol::TensorListShapeOrDataDimExprs& tensor_list) { + symbol::TensorListShapeOrDataDimExprs simplified_tensor_list; + for (const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data : + tensor_list) { + VisitEachDimExprFromTensorShapeOrData(tensor_shape_or_data, DoEach); + } + }}; + return std::visit(lambdas, shape_or_data.variant()); +} + +std::unordered_map +CollectSubstituteDimExprMap( + const OpLoweringGroupPtr& group, + pir::ShapeConstraintIRAnalysis& shape_analysis) { // NOLINT + std::unordered_map dim_expr_map; + std::unordered_set base_dim_expr_set; + + VisitEachInputValue(group, [&](::pir::Value value) { + if (!shape_analysis.HasShapeOrDataForValue(value)) { + return; + } + auto& shape_or_data = shape_analysis.GetShapeOrDataForValue(value); + VisitEachDimExpr(shape_or_data, [&](const symbol::DimExpr& dim_expr) { + if (IsComplicatedDimExpr(dim_expr) && + dim_expr_map.find(dim_expr) == dim_expr_map.end()) { + dim_expr_map[dim_expr] = + symbol::DimExpr(shape_analysis.GetNextSymName()); + } + if (dim_expr.isa()) { + base_dim_expr_set.insert(dim_expr.Get()); + } + }); + }); + + const std::unordered_set dim_exprs_no_outer_symbol = [&] { + auto HasOuterBasicSymbol = [&](const symbol::DimExpr& dim_expr) { + for (const auto& symbol : symbol::CollectDimExprSymbols(dim_expr)) { + if (base_dim_expr_set.count(symbol) == 0) { + return true; + } + } + return false; + }; + std::unordered_set result; + for (const auto& kv : dim_expr_map) { + if (IsComplicatedDimExpr(kv.first) && !HasOuterBasicSymbol(kv.first)) { + 
result.insert(kv.first); + } + } + return result; + }(); + for (const auto& dim_expr : dim_exprs_no_outer_symbol) { + dim_expr_map.erase(dim_expr); + } + + return dim_expr_map; +} + +bool IsShapeOrDataNeedSubstitute( + const symbol::ShapeOrDataDimExprs& shape_or_data, + const std::unordered_map& dim_expr_map) { + bool ret = false; + VisitEachDimExpr(shape_or_data, [&](const symbol::DimExpr& dim_expr) { + if (dim_expr_map.find(dim_expr) != dim_expr_map.end()) { + ret = true; + } + }); + return ret; +} + +symbol::TensorShapeOrDataDimExprs SubstituteTensorShapeOrData( + const symbol::TensorShapeOrDataDimExprs& shape_or_data, + const std::unordered_map& dim_expr_map) { + const auto& SimplifyDimExpr = + [&](const std::vector& original_dim_expr) + -> std::vector { + std::vector simplified_dim_expr{}; + for (const symbol::DimExpr& dim_expr : original_dim_expr) { + simplified_dim_expr.push_back(symbol::SimplifyDimExpr( + symbol::SubstituteDimExpr(dim_expr, dim_expr_map))); + } + return simplified_dim_expr; + }; + + std::vector simplified_shape = + SimplifyDimExpr(shape_or_data.shape()); + if (!shape_or_data.data().has_value()) { + return symbol::ShapeOrData(simplified_shape); + } + std::vector simplified_data = + SimplifyDimExpr(shape_or_data.data().value()); + return symbol::ShapeOrData(simplified_shape, + simplified_data); +} + +symbol::ShapeOrDataDimExprs SubstituteShapeOrData( + const symbol::ShapeOrDataDimExprs& shape_or_data, + const std::unordered_map& dim_expr_map) { + auto lambdas = symbol::Overloaded{ + [&](const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data) { + return symbol::ShapeOrDataDimExprs( + SubstituteTensorShapeOrData(tensor_shape_or_data, dim_expr_map)); + }, + [&](const symbol::TensorListShapeOrDataDimExprs& tensor_list) { + symbol::TensorListShapeOrDataDimExprs simplified_tensor_list; + for (symbol::TensorShapeOrDataDimExprs tensor_shape_or_data : + tensor_list) { + simplified_tensor_list.push_back( + SubstituteTensorShapeOrData(tensor_shape_or_data, dim_expr_map)); + } + return symbol::ShapeOrDataDimExprs(simplified_tensor_list); + }}; + return std::visit(lambdas, shape_or_data.variant()); +} + +symbol::ShapeOrDataDimExprs TrySubstitute( + const symbol::ShapeOrDataDimExprs& shape_or_data, + const std::unordered_map& dim_expr_map) { + if (!IsShapeOrDataNeedSubstitute(shape_or_data, dim_expr_map)) { + return shape_or_data; + } + return SubstituteShapeOrData(shape_or_data, dim_expr_map); +} + +void InferSymbolicShapeForOperation( + pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + auto infer_symbolic_shape_interface = + op->dyn_cast(); + if (infer_symbolic_shape_interface) { + infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " DOES NOT have InferSymbolicShapeInterface!")); + } +} + +std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> +GetGroupValue2Shape(const OpLoweringGroupPtr& group, + pir::ShapeConstraintIRAnalysis& shape_analysis) { // NOLINT + std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> value2shape; + for (auto op : group->ops()) { + for (size_t i = 0; i < op->num_operands(); ++i) { + auto operand = op->operand_source(i); + if (operand && value2shape.find(operand) == value2shape.end() && + shape_analysis.HasShapeOrDataForValue(operand)) { + VLOG(6) << "Add value_to_shape_or_data_exprs for " << operand.impl(); + value2shape.insert( + {operand, shape_analysis.GetShapeOrDataForValue(operand)}); + } + } + for (size_t i = 0; i < 
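CollectSubstituteDimExprMap and SubstituteShapeOrData above implement a simple idea: every "complicated" symbolic extent seen on a group input gets a fresh plain symbol, and shapes inside the group are rewritten through that map so the lowered kernel only sees simple symbols. A standalone sketch of that substitution with strings standing in for symbol::DimExpr (all names hypothetical):

// Illustrative sketch only.
#include <string>
#include <unordered_map>
#include <vector>

using ToyDimExpr = std::string;
using ToyShape = std::vector<ToyDimExpr>;

bool IsComplicated(const ToyDimExpr& e) {  // e.g. "Mul(S0, S1)"
  return e.find('(') != std::string::npos;
}

std::unordered_map<ToyDimExpr, ToyDimExpr> CollectSubstitutions(
    const std::vector<ToyShape>& input_shapes) {
  std::unordered_map<ToyDimExpr, ToyDimExpr> map;
  int next_sym = 0;
  for (const ToyShape& shape : input_shapes)
    for (const ToyDimExpr& e : shape)
      if (IsComplicated(e) && !map.count(e))
        map[e] = "T" + std::to_string(next_sym++);  // fresh symbol name
  return map;
}

ToyShape Substitute(const ToyShape& shape,
                    const std::unordered_map<ToyDimExpr, ToyDimExpr>& map) {
  ToyShape out;
  for (const ToyDimExpr& e : shape) {
    auto it = map.find(e);
    out.push_back(it == map.end() ? e : it->second);
  }
  return out;
}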
op->num_results(); ++i) { + auto result = op->result(i); + if (result && value2shape.find(result) == value2shape.end() && + shape_analysis.HasShapeOrDataForValue(result)) { + VLOG(6) << "Add value_to_shape_or_data_exprs for " << result.impl(); + value2shape.insert( + {result, shape_analysis.GetShapeOrDataForValue(result)}); + } + } + } + return value2shape; +} + +} // namespace + +namespace cinn::dialect::ir::details { + +std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> +CreateGroupShapeOrDataExprs( + const OpLoweringGroupPtr& group, + pir::ShapeConstraintIRAnalysis& global_shape_analysis) { // NOLINT + std::unordered_map dim_expr_map = + CollectSubstituteDimExprMap(group, global_shape_analysis); + std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> value2shape; + if (dim_expr_map.size() == 0) { + return GetGroupValue2Shape(group, global_shape_analysis); + } + + pir::ShapeConstraintIRAnalysis local_shape_analysis({}); + + // process input values. + VisitEachInputValue(group, [&](::pir::Value value) { + auto new_shape_expr = TrySubstitute( + global_shape_analysis.GetShapeOrDataForValue(value), dim_expr_map); + local_shape_analysis.SetShapeOrDataForValue(value, new_shape_expr); + value2shape.insert({value, new_shape_expr}); + VLOG(6) << "Add value_to_shape_or_data_exprs for " << value.impl(); + }); + + // process the result values of each op. + for (auto* op : group->ops()) { + InferSymbolicShapeForOperation(op, &local_shape_analysis); + for (size_t i = 0; i < op->num_results(); ++i) { + auto result = op->result(i); + if (result && !value2shape.count(result) && + local_shape_analysis.HasShapeOrDataForValue(result)) { + VLOG(6) << "Add value_to_shape_or_data_exprs for " << result.impl(); + value2shape.insert( + {result, local_shape_analysis.GetShapeOrDataForValue(result)}); + } + } + } + VLOG(5) << group.get() + << " value_to_shape_or_data_exprs.size() : " << value2shape.size(); + return value2shape; +} + +} // namespace cinn::dialect::ir::details diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.h b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.h new file mode 100644 index 0000000000000..7cdb1755f3450 --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.h @@ -0,0 +1,29 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
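CreateGroupShapeOrDataExprs above then works in two passes: it seeds a group-local shape table with the (substituted) shapes of the group's outside inputs, and walks the group's ops in order so each op derives its result shapes from what is already known. A sketch of that flow; ToyOp and its infer callback are hypothetical stand-ins for pir operations and their InferSymbolicShapeInterface.

// Illustrative sketch only.
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

using ToyValue = int;
using ToyShape = std::vector<std::string>;
using ShapeTable = std::unordered_map<ToyValue, ToyShape>;

struct ToyOp {
  std::vector<ToyValue> operands;
  ToyValue result;
  // Per-op rule deriving the result shape from the operand shapes.
  std::function<ToyShape(const std::vector<ToyShape>&)> infer;
};

// Assumes the table was already seeded with the group's input shapes.
void PropagateShapes(const std::vector<ToyOp>& ops, ShapeTable* table) {
  for (const ToyOp& op : ops) {
    std::vector<ToyShape> in_shapes;
    for (ToyValue v : op.operands) in_shapes.push_back(table->at(v));
    (*table)[op.result] = op.infer(in_shapes);
  }
}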
+ +#pragma once +#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h" +#include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h" + +namespace cinn::dialect::ir::details { +using OpLoweringGroup = cinn::hlir::framework::pir::OpLoweringGroup; +using OpLoweringGroupPtr = std::shared_ptr; + +std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> +CreateGroupShapeOrDataExprs( + const OpLoweringGroupPtr& group, + pir::ShapeConstraintIRAnalysis& shape_analysis // NOLINT +); + +} // namespace cinn::dialect::ir::details diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc new file mode 100644 index 0000000000000..0e7ebb8e9499d --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.cc @@ -0,0 +1,228 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/refresh_combine_pattern.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace cinn::dialect::ir::details { + +pir::Operation* ProcessDyShapeGroup( + const OpLoweringGroupPtr& group, + pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT + pir::PatternRewriter& rewriter) { // NOLINT + auto group_inputs = GetBlockOutsideInput(group->ops()); + GroupDimExprInfo group_dim_expr_info = GetGroupDimExprInfo(group); + const auto& leaves = group_dim_expr_info.all_value_dim_exprs; + // has multiple branch + if (NeedBroadcastWithCF(leaves)) { + const auto& value_to_dim_expr_idx = + group_dim_expr_info.value_to_dim_expr_idx; + const std::shared_ptr broadcast_tree = + ConstructBroadcastTree(leaves); + std::vector output_types; + auto group_output_values = group->GetGroupOutputValues(); + for (size_t i = 0; i < group_output_values.size(); ++i) { + output_types.push_back(group_output_values[i].type()); + } + return CompileBroadcastTreeToConditionBlock(group, + *broadcast_tree, + shape_analysis, + value_to_dim_expr_idx, + group_inputs, + output_types, + rewriter); + } else { // no condition block + // compile group to jit_kernel_op + std::vector output_types; + const auto& 
group_output_values = group->output_values(); + for (size_t i = 0; i < group_output_values.size(); ++i) { + auto base_type = + group_output_values[i].type().dyn_cast<::pir::DenseTensorType>(); + auto dim_info = base_type.dims(); + if (shape_analysis.HasShapeOrDataForValue(group_output_values[i])) { + auto shape = group->GetShapeOrDataExprs(group_output_values[i]).shape(); + for (size_t k = 0; k < shape.size(); ++k) { + if (shape[k].isa()) { + dim_info[k] = shape[k].Get(); + } + } + } + auto new_type = ::pir::DenseTensorType::get(pir::IrContext::Instance(), + base_type.dtype(), + dim_info, + base_type.data_layout(), + base_type.lod(), + base_type.offset()); + output_types.push_back(new_type); + } + auto jit_kernel_op = rewriter.Build( + group_inputs, GetJitKernelAttr(group), output_types); + return jit_kernel_op; + } +} +class FusionOpPattern : public pir::OpRewritePattern { + public: + FusionOpPattern(::pir::IrContext* context, const GroupInfoMap& group_infos) + : pir::OpRewritePattern(context), + group_infos_(group_infos) {} + + bool MatchAndRewrite(cinn::dialect::FusionOp fusion_op, + pir::PatternRewriter& rewriter) const override { + ::pir::IrContext* ctx = ::pir::IrContext::Instance(); + auto* program = fusion_op->GetParentProgram(); + auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program); + VLOG(4) << "Program before lowering: \n" + << pir::CustomPrintHelper(*program, shape_analysis.PrintHook()); + + // TODO(zhangyuqin1998): Replace pir::Group with a new structure + OpLoweringGroupPtr group = GetGroup(fusion_op); + pir::Operation* compiled_op = ProcessGroup(group, shape_analysis, rewriter); + + for (size_t i = 0; i < fusion_op.num_results(); ++i) { + rewriter.ReplaceAllUsesWith(fusion_op.result(i), compiled_op->result(i)); + if (shape_analysis.HasShapeOrDataForValue(fusion_op.result(i))) { + shape_analysis.SetShapeOrDataForValue( + compiled_op->result(i), + shape_analysis.GetShapeOrDataForValue(fusion_op.result(i))); + } else { + LOG(WARNING) << "No shape_data for " + << fusion_op.result(i).defining_op()->name() << "_result_" + << i; + } + } + rewriter.EraseOp(fusion_op); + return true; + } + + protected: + virtual OpLoweringGroupPtr GetGroup(cinn::dialect::FusionOp fusion_op) const { + return group_infos_.at(fusion_op.operation()); + } + + virtual pir::Operation* ProcessGroup( + const OpLoweringGroupPtr& group, + pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT + pir::PatternRewriter& rewriter) const { // NOLINT + auto group_inputs = GetBlockOutsideInput(group->ops()); + // compile group to jit_kernel_op + std::vector output_types; + const auto& group_output_values = group->output_values(); + for (size_t i = 0; i < group_output_values.size(); ++i) { + output_types.push_back(group_output_values[i].type()); + } + auto jit_kernel_op = rewriter.Build( + group_inputs, GetJitKernelAttr(group), output_types); + return jit_kernel_op; + } + + private: + const GroupInfoMap& group_infos_; // not owned +}; + +class LowerCinnFusionOpPass : public pir::PatternRewritePass { + public: + LowerCinnFusionOpPass() + : pir::PatternRewritePass("lower_cinn_fusion_op", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + context->GetOrRegisterDialect(); + context->GetOrRegisterDialect(); + + pir::RewritePatternSet ps(context); + ps.Add(context, group_infos_); + return ps; + } + + bool CanApplyOn(pir::Operation* op) const override { + if (op->isa()) { + VLOG(5) << "start to pre-analysis all fusion ops in ModuleOp with static " + "shape mode."; 
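The else-branch of ProcessDyShapeGroup above refines each output type before building the jit_kernel op: wherever the symbolic shape of a group output is a known integer, the corresponding (possibly dynamic) dimension of the result type is pinned to that constant. A minimal sketch of that refinement, with int64_t and -1 standing in for DDim and its dynamic dimensions (names hypothetical):

// Illustrative sketch only.
#include <cstdint>
#include <optional>
#include <vector>

std::vector<std::int64_t> RefineDims(
    std::vector<std::int64_t> dims,  // -1 marks a dynamic extent
    const std::vector<std::optional<std::int64_t>>& symbolic_constants) {
  for (size_t k = 0; k < dims.size() && k < symbolic_constants.size(); ++k) {
    if (symbolic_constants[k].has_value()) {
      dims[k] = *symbolic_constants[k];  // pin the extent that is known
    }
  }
  return dims;
}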
+ FusionOpAnalysis(&group_infos_, /*is_dy_shape=*/false).Run(op); + } + return op->num_regions() > 0; + } + + private: + mutable GroupInfoMap group_infos_; +}; + +class DyShapeFusionOpPattern : public FusionOpPattern { + public: + using FusionOpPattern::FusionOpPattern; + + protected: + virtual pir::Operation* ProcessGroup( + const OpLoweringGroupPtr& group, + pir::ShapeConstraintIRAnalysis& shape_analysis, // NOLINT + pir::PatternRewriter& rewriter) const { // NOLINT + return ProcessDyShapeGroup(group, shape_analysis, rewriter); + } +}; + +class LowerCinnDyShapeFusionOpPass : public pir::PatternRewritePass { + public: + LowerCinnDyShapeFusionOpPass() + : pir::PatternRewritePass("lower_cinn_dynamic_shape_fusion_op", 1) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { + context->GetOrRegisterDialect(); + context->GetOrRegisterDialect(); + + pir::RewritePatternSet ps(context); + ps.Add(context, group_infos_); + ps.Add(context); + + return ps; + } + + bool CanApplyOn(pir::Operation* op) const override { + if (op->isa()) { + VLOG(5) << "start to pre-analysis all fusion ops in ModuleOp with " + "dynamic shape mode."; + FusionOpAnalysis(&group_infos_, /*is_dy_shape=*/true).Run(op); + } + return op->num_regions() > 0; + } + + private: + mutable GroupInfoMap group_infos_; +}; + +} // namespace cinn::dialect::ir::details + +namespace cinn::dialect::ir { +std::unique_ptr<::pir::Pass> CreateLowerCinnFusionOpPass() { + return std::make_unique(); +} + +std::unique_ptr<::pir::Pass> CreateLowerCinnDyShapeFusionOpPass() { + return std::make_unique(); +} + +} // namespace cinn::dialect::ir + +// REGISTER_IR_PASS(cinn_group_lowering, LowerCinnFusionOpPass); diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.h b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.h similarity index 100% rename from paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.h rename to paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.h diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.cc new file mode 100644 index 0000000000000..771ea930db38d --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.cc @@ -0,0 +1,55 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" + +namespace cinn::dialect::ir::details { +using cinn::hlir::framework::PirCompiler; + +void FusionOpAnalysis::GatherGroup(pir::Operation* fusion_op) { + OpLoweringGroupPtr group_ptr = BuildOpLoweringGroup(fusion_op); + VLOG(6) << "Gather Group " << group_ptr->FuncName() + << " for fusion_op : " << fusion_op->id(); + group_infos_->insert({fusion_op, group_ptr}); +} + +void FusionOpAnalysis::RunImpl(pir::Operation* op) { + if (op->isa()) { + GatherGroup(op); + return; + } + for (uint32_t i = 0; i < op->num_regions(); ++i) { + for (auto& block : op->region(i)) { + for (auto& op : block) { + RunImpl(&op); + } + } + } +} + +void FusionOpAnalysis::PreCompileGroup() { + std::vector groups; + for (auto& group_info : *group_infos_) { + if (is_dy_shape_ && NeedBroadcastWithCF(group_info.second)) continue; + groups.push_back(group_info.second); + } + // Build and trigger compilaion cache. + VLOG(4) << "Parallel Pre-Compile for Group with size: " << groups.size(); + PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget()); + pir_compiler.Build(groups); +} +} // namespace cinn::dialect::ir::details diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.h b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.h new file mode 100644 index 0000000000000..4c539078ccada --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/pre_analysis.h @@ -0,0 +1,43 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
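FusionOpAnalysis::RunImpl above recurses through nested regions and blocks, records every fusion op it finds, and PreCompileGroup then hands the whole batch to one PirCompiler::Build call so kernels are compiled (and cached) before pattern rewriting starts. A toy version of that traversal; ToyOp and ToyBlock are hypothetical stand-ins for pir::Operation and pir::Block.

// Illustrative sketch only.
#include <vector>

struct ToyBlock;

struct ToyOp {
  bool is_fusion = false;
  std::vector<ToyBlock> blocks;  // blocks of all regions, flattened
};

struct ToyBlock {
  std::vector<ToyOp> ops;
};

void GatherFusionOps(ToyOp* op, std::vector<ToyOp*>* fusion_ops) {
  if (op->is_fusion) {
    fusion_ops->push_back(op);  // a fusion op ends the walk, like GatherGroup
    return;
  }
  for (ToyBlock& block : op->blocks)
    for (ToyOp& inner : block.ops) GatherFusionOps(&inner, fusion_ops);
}
// The collected list is then compiled in one batch, mirroring PreCompileGroup.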
+ +#pragma once +#include +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/broadcast_with_cf.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h" + +namespace cinn::dialect::ir::details { +using OpLoweringGroup = cinn::hlir::framework::pir::OpLoweringGroup; +using OpLoweringGroupPtr = std::shared_ptr; +using GroupInfoMap = std::unordered_map<::pir::Operation*, OpLoweringGroupPtr>; + +class FusionOpAnalysis final { + public: + FusionOpAnalysis(GroupInfoMap* group_infos, bool is_dy_shape) + : group_infos_(group_infos), is_dy_shape_(is_dy_shape) {} + void Run(pir::Operation* module_op) { + RunImpl(module_op); + PreCompileGroup(); + } + + protected: + void RunImpl(pir::Operation* op); + void GatherGroup(pir::Operation* fusion_op); + void PreCompileGroup(); + + private: + GroupInfoMap* group_infos_; // not_owned + bool is_dy_shape_; +}; +} // namespace cinn::dialect::ir::details diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.cc b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.cc new file mode 100644 index 0000000000000..e4724c617dfaf --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.cc @@ -0,0 +1,142 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h" + +#include "paddle/cinn/adt/generate_map_expr.h" +#include "paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.h" +#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/collect_sym_expr.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/jit_kernel_op.h" +#include "paddle/cinn/hlir/dialect/runtime/ir/runtime_dialect.h" +#include "paddle/cinn/hlir/framework/pir/compilation_cache.h" +#include "paddle/cinn/hlir/framework/pir_compiler.h" +#include "paddle/cinn/runtime/flags.h" + +PD_DECLARE_bool(cinn_enable_map_expr); + +namespace cinn::dialect::ir::details { + +using cinn::hlir::framework::CompilationCache; +using cinn::hlir::framework::PirCompiler; +using cinn::hlir::framework::pir::CINNKernelInfo; +using cinn::hlir::framework::pir::CompatibleInfo; + +std::vector GetBlockOutsideInput( + const std::vector& op_list) { + std::vector vec_res; + std::unordered_set<::pir::Value> block_inner_output; + for (size_t k = 0; k < op_list.size(); ++k) { + for (size_t i = 0; i < op_list[k]->num_results(); ++i) { + block_inner_output.insert(op_list[k]->result(i)); + } + } + + std::unordered_set<::pir::Value> insert_value; + for (size_t k = 0; k < op_list.size(); ++k) { + for (size_t i = 0; i < op_list[k]->num_operands(); ++i) { + if (!block_inner_output.count(op_list[k]->operand_source(i)) && + !insert_value.count(op_list[k]->operand_source(i))) { + vec_res.push_back(op_list[k]->operand_source(i)); + insert_value.insert(op_list[k]->operand_source(i)); + } + } + } + return vec_res; +} + +std::unordered_map> +CompileGroupAsOpAttribute(const std::vector& group_list) { + PirCompiler pir_compiler(cinn::common::DefaultNVGPUTarget()); + auto fn_ptr_res = pir_compiler.Build(group_list); + + std::unordered_map> + result; + for (size_t i = 0; i < group_list.size(); ++i) { + std::unordered_map op_attrs{ + {cinn::dialect::JitKernelOp::kAttrName, + cinn::dialect::CINNKernelInfoAttribute::get(pir::IrContext::Instance(), + fn_ptr_res[i])}, + }; + result.insert({group_list[i], op_attrs}); + } + return result; +} + +std::unordered_map GetJitKernelAttr( + const OpLoweringGroupPtr& group) { + auto kernel_info = CompilationCache::Instance().GetKernelInfo(group); + std::unordered_map attrs{ + {cinn::dialect::JitKernelOp::kAttrName, + cinn::dialect::CINNKernelInfoAttribute::get(pir::IrContext::Instance(), + kernel_info)}}; + return attrs; +} + +OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr) { + auto fusion_op = fusion_op_ptr->dyn_cast(); + auto group = std::make_shared(); + group->set_op_pattern_kind( + cinn::hlir::framework::OpPatternKind::kElementWise); + if (fusion_op.attributes().count("group_info")) { + auto attr = fusion_op.attribute("group_info") + .dyn_cast() + .data(); + + group->set_op_pattern_kind(attr.op_pattern_kind); + group->set_loop_ranges(attr.loop_ranges); + group->set_loop_ranges_expr(attr.loop_ranges_expr); + + group->set_reduce_axis(attr.reduce_axis); + group->set_alignment_schedule_info(attr.alignment_schedule_info); + } + + // Rebuild ops of the group + for (auto op : fusion_op.GetOperators()) { + if (!op->isa<::pir::YieldOp>()) { + group->mut_ops().push_back(op); + auto op_pattern_kind = static_cast(CompatibleInfo::OpKind(*op)) > + static_cast(group->op_pattern_kind()) + ? 
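GetBlockOutsideInput above defines a value as an "outside input" of an op list if some op consumes it but no op in the list produces it, preserving the order of first use. The same computation restated with toy types (ToyOp and OutsideInputs are hypothetical stand-ins for pir::Operation and the function above):

// Illustrative sketch only.
#include <unordered_set>
#include <vector>

using ToyValue = int;

struct ToyOp {
  std::vector<ToyValue> operands;
  std::vector<ToyValue> results;
};

std::vector<ToyValue> OutsideInputs(const std::vector<ToyOp>& ops) {
  std::unordered_set<ToyValue> produced, seen;
  for (const ToyOp& op : ops)
    for (ToyValue r : op.results) produced.insert(r);  // block-inner outputs
  std::vector<ToyValue> inputs;
  for (const ToyOp& op : ops)
    for (ToyValue v : op.operands)
      if (!produced.count(v) && seen.insert(v).second) inputs.push_back(v);
  return inputs;
}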
CompatibleInfo::OpKind(*op) + : group->op_pattern_kind(); + group->set_op_pattern_kind(op_pattern_kind); + } + } + + // Rebuild output_ops and input_ops of the group + auto yield_op = fusion_op.GetOperators().back(); + for (size_t i = 0; i < yield_op->num_operands(); ++i) { + auto in = yield_op->operand_source(i); + group->mut_output_values().push_back(in); + group->mut_output_ops().insert(in.defining_op()); + } + + // Because the group is rebuilt, the order of group.output_values generated + // by BuildCUDAJITInfo may not be same with the order bound in the yield op, + // so a mapping is required. + auto& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(fusion_op->GetParentProgram()); + group->set_value_to_shape_or_data_exprs( + CreateGroupShapeOrDataExprs(group, shape_analysis)); + if (FLAGS_cinn_enable_map_expr) { + cinn::adt::TryGenerateMapExprFromGroup(group); + } + // Rebuild other informations + // TODO(zhangyuqin1998): Do we need group.master_ops? + return group; +} + +} // namespace cinn::dialect::ir::details diff --git a/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h new file mode 100644 index 0000000000000..3b3ba4379d57c --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/utils.h @@ -0,0 +1,34 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h" + +namespace cinn::dialect::ir::details { +using OpLoweringGroup = cinn::hlir::framework::pir::OpLoweringGroup; +using OpLoweringGroupPtr = std::shared_ptr; + +std::vector GetBlockOutsideInput( + const std::vector& op_list); + +std::unordered_map> +CompileGroupAsOpAttribute(const std::vector& group_list); + +std::unordered_map GetJitKernelAttr( + const OpLoweringGroupPtr& group); + +OpLoweringGroupPtr BuildOpLoweringGroup(pir::Operation* fusion_op_ptr); + +} // namespace cinn::dialect::ir::details diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index ad6c7b9a060da..3bf32aa91837d 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -17,8 +17,10 @@ #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/refresh_combine_pattern.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/pir/include/core/builtin_dialect.h" #include "paddle/pir/include/core/builtin_op.h" @@ -145,8 +147,8 @@ class ScaleOpPattern : public pir::OpRewritePattern { using pir::OpRewritePattern::OpRewritePattern; bool Match(paddle::dialect::ScaleOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); - return flag; + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); + return !is_denied; } void Rewrite(paddle::dialect::ScaleOp op, @@ -199,17 +201,16 @@ class ReshapeOpPattern using pir::OpRewritePattern::OpRewritePattern; bool Match(paddle::dialect::ReshapeOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); auto scale_factor_gen_op = op->operand_source(1).defining_op(); auto full_op = scale_factor_gen_op->dyn_cast(); - return flag && full_op; + return !is_denied && full_op; } void Rewrite(paddle::dialect::ReshapeOp op, pir::PatternRewriter &rewriter) const override { auto scale_factor_gen_op = op->operand_source(1).defining_op(); - auto full_op = scale_factor_gen_op->dyn_cast(); // scale is generator by full op @@ -243,11 +244,11 @@ class Pool2dOpPattern using pir::OpRewritePattern::OpRewritePattern; bool Match(paddle::dialect::Pool2dOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); auto kernel_size_gen_op = op->operand_source(1).defining_op(); auto full_op = kernel_size_gen_op->dyn_cast(); - return flag && full_op; + return !is_denied && full_op; } void Rewrite(paddle::dialect::Pool2dOp op, @@ -289,14 +290,14 @@ class IsCloseOpPattern using pir::OpRewritePattern::OpRewritePattern; bool Match(paddle::dialect::IscloseOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); auto rtol_op = op->operand_source(2) .defining_op() ->dyn_cast(); auto atol_op = op->operand_source(3) .defining_op() 
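The Match() hunks in pd_to_cinn_pass.cc switch these patterns from an allow-list check (IsSupportCinn) to a deny-list check (!IsDeniedForCinn): a pattern now fires unless the op is explicitly denied for CINN, with any extra structural conditions unchanged. A minimal illustration with a hypothetical predicate and deny-list contents:

// Illustrative sketch only; the deny-list entry is made up.
#include <string>
#include <unordered_set>

bool IsDeniedForCinnToy(const std::string& op_name) {
  static const std::unordered_set<std::string> deny = {"toy.denied_op"};
  return deny.count(op_name) > 0;
}

bool PatternMatches(const std::string& op_name, bool extra_conditions) {
  // Previously: IsSupportCinn(op) && extra_conditions.
  return !IsDeniedForCinnToy(op_name) && extra_conditions;
}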
->dyn_cast(); - return flag && rtol_op && atol_op; + return !is_denied && rtol_op && atol_op; } void Rewrite(paddle::dialect::IscloseOp op, @@ -332,7 +333,7 @@ class SliceOpPattern : public pir::OpRewritePattern { using pir::OpRewritePattern::OpRewritePattern; bool Match(paddle::dialect::SliceOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); auto start_gen_op = op->operand_source(1) .defining_op() ->dyn_cast(); @@ -340,7 +341,7 @@ class SliceOpPattern : public pir::OpRewritePattern { auto end_gen_op = op->operand_source(2) .defining_op() ->dyn_cast(); - return flag && start_gen_op && end_gen_op; + return !is_denied && start_gen_op && end_gen_op; } void Rewrite(paddle::dialect::SliceOp op, @@ -381,9 +382,9 @@ class ConcatOpPattern using pir::OpRewritePattern::OpRewritePattern; bool Match(paddle::dialect::ConcatOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); auto axis_gen_op = op->operand_source(1).defining_op(); - return flag && axis_gen_op->dyn_cast(); + return !is_denied && axis_gen_op->dyn_cast(); } void Rewrite(paddle::dialect::ConcatOp op, @@ -409,8 +410,8 @@ class PowOpPattern : public pir::OpRewritePattern { using pir::OpRewritePattern::OpRewritePattern; bool Match(paddle::dialect::PowOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); - return flag; + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); + return !is_denied; } void Rewrite(paddle::dialect::PowOp op, @@ -429,6 +430,46 @@ class PowOpPattern : public pir::OpRewritePattern { } }; +class ElementwisePowOpPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern< + paddle::dialect::ElementwisePowOp>::OpRewritePattern; + + bool Match(paddle::dialect::ElementwisePowOp op) const override { + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); + auto y_op = op->operand_source(1) + .defining_op() + ->dyn_cast(); + return !is_denied && y_op; + } + + void Rewrite(paddle::dialect::ElementwisePowOp op, + pir::PatternRewriter &rewriter) const override { + auto y_op = op->operand_source(1) + .defining_op() + ->dyn_cast(); + auto factor = + y_op.attribute("value").dyn_cast<::pir::FloatAttribute>().data(); + if (factor == 2.0) { + auto multiply = rewriter.Build( + op->operand_source(0), op->operand_source(0)); + rewriter.ReplaceAllUsesWith(op.result(0), multiply.result(0)); + rewriter.EraseOp(op); + } else if (factor == -0.5) { + auto rsqrt = + rewriter.Build(op->operand_source(0)); + rewriter.ReplaceAllUsesWith(op.result(0), rsqrt.result(0)); + rewriter.EraseOp(op); + } else if (factor == 0.5) { + auto sqrt = + rewriter.Build(op->operand_source(0)); + rewriter.ReplaceAllUsesWith(op.result(0), sqrt.result(0)); + rewriter.EraseOp(op); + } + } +}; + static void ReplaceSliceOp(const cinn::dialect::SplitOp &cinn_split, pir::Operation *slice_op, pir::PatternRewriter &rewriter) { // NOLINT @@ -456,14 +497,14 @@ class SplitOpPattern : public pir::OpRewritePattern { using pir::OpRewritePattern::OpRewritePattern; bool Match(paddle::dialect::SplitOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); auto sections_gen_op = op->operand_source(1) .defining_op() ->dyn_cast(); auto axis_gen_op = op->operand_source(2) 
.defining_op() ->dyn_cast(); - return flag && sections_gen_op && axis_gen_op; + return !is_denied && sections_gen_op && axis_gen_op; } void Rewrite(paddle::dialect::SplitOp op, @@ -528,10 +569,10 @@ class SplitWithNumOpPattern paddle::dialect::SplitWithNumOp>::OpRewritePattern; bool Match(paddle::dialect::SplitWithNumOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); auto axis_gen_op = op->operand_source(1).defining_op(); auto full_op = axis_gen_op->dyn_cast(); - return flag && full_op; + return !is_denied && full_op; } void Rewrite(paddle::dialect::SplitWithNumOp op, @@ -618,11 +659,11 @@ class ExpandOpPattern using pir::OpRewritePattern::OpRewritePattern; bool Match(paddle::dialect::ExpandOp op) const override { - bool flag = CompatibleInfo::IsSupportCinn(*op.operation()); + const bool is_denied = CompatibleInfo::IsDeniedForCinn(*op.operation()); auto out_shape_gen_op = op->operand_source(1) .defining_op() ->dyn_cast(); - return flag && out_shape_gen_op; + return !is_denied && out_shape_gen_op; } void Rewrite(paddle::dialect::ExpandOp op, @@ -712,6 +753,43 @@ class UniformOpPattern : public paddle::drr::DrrPatternBase { } }; +class FullWithTensorOpPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern< + paddle::dialect::FullWithTensorOp>::OpRewritePattern; + + bool MatchAndRewrite(paddle::dialect::FullWithTensorOp op, + pir::PatternRewriter &rewriter) const override { + auto shape = op->operand_source(0); + auto value = op->operand_source(1); + + if (paddle::dialect::TransToPhiDataType( + value.type() + .dyn_cast() + .dtype()) != op.attribute("dtype") + .dyn_cast() + .data()) { + value = rewriter + .Build( + value, + op.attribute("dtype") + .dyn_cast() + .data()) + .result(0); + } + + auto out = + rewriter.Build(value, shape).result(0); + + rewriter.ReplaceAllUsesWith(op.result(0), out); + + rewriter.EraseOp(op); + + return true; + } +}; + PdOpToCinnOpPass::PdOpToCinnOpPass() : pir::PatternRewritePass("pd_to_cinn_pass", 1) {} @@ -725,22 +803,22 @@ pir::RewritePatternSet PdOpToCinnOpPass::InitializePatterns( ps.Add(paddle::drr::Create(context)); ps.Add(paddle::drr::Create(context)); ps.Add(context); - ps.Add(context); + ps.Add(context); ps.Add(context); ps.Add(context); - ps.Add(context); - ps.Add(context); ps.Add(context); - ps.Add(context); + // ps.Add(context); ps.Add(context); ps.Add(context); - // ps.Add(paddle::drr::Create(context)); + ps.Add(context); + ps.Add(context); + ps.Add(context); return ps; } bool PdOpToCinnOpPass::CanApplyOn(pir::Operation *op) const { - return op->isa() && op->num_regions() > 0; + return op->num_regions() > 0; } std::unique_ptr CreatePdOpToCinnOpPass() { diff --git a/paddle/cinn/hlir/dialect/operator/transforms/refresh_combine_pattern.h b/paddle/cinn/hlir/dialect/operator/transforms/refresh_combine_pattern.h new file mode 100644 index 0000000000000..ddfb8bdc34acf --- /dev/null +++ b/paddle/cinn/hlir/dialect/operator/transforms/refresh_combine_pattern.h @@ -0,0 +1,30 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/pattern_rewrite/pattern_match.h" + +class RefreshCombineOpPattern + : public ::pir::OpRewritePattern<::pir::CombineOp> { + public: + using ::pir::OpRewritePattern<::pir::CombineOp>::OpRewritePattern; + bool MatchAndRewrite(pir::CombineOp op, + pir::PatternRewriter& rewriter) const override { + auto new_combine_op = rewriter.Build<::pir::CombineOp>(op.inputs()); + rewriter.ReplaceAllUsesWith(op.result(0), new_combine_op.result(0)); + rewriter.EraseOp(op); + return true; + } +}; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/remove_unchanged_reshape_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/remove_unchanged_reshape_pass.cc index 1f885ef0185e0..a2c09cc14a8dc 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/remove_unchanged_reshape_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/remove_unchanged_reshape_pass.cc @@ -16,8 +16,10 @@ #include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/refresh_combine_pattern.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -32,30 +34,54 @@ namespace cinn { namespace dialect { namespace ir { +using paddle::dialect::details::GetExprVecFromShape; + +bool RemoveOp(pir::Operation* op, pir::PatternRewriter* rewriter) { + const auto& IsDynamicShape = [](const pir::Value& value) -> bool { + return value.type().dyn_cast().IsDynamicShape(); + }; + const auto& GetDims = [](const pir::Value& value) -> decltype(auto) { + return value.type().dyn_cast().dims(); + }; + + pir::Value input = op->operand_source(0); + pir::Value output = op->result(0); + const auto& IsSameShape = [&]() -> bool { + const bool has_dynamic_shape = + IsDynamicShape(input) || IsDynamicShape(output); + if (has_dynamic_shape) { + auto& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + if (shape_analysis.HasShapeOrDataForValue(input) && + shape_analysis.HasShapeOrDataForValue(output)) { + auto input_sym_shape = + GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input)); + auto output_sym_shape = + GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output)); + return input_sym_shape == output_sym_shape; + } + return false; + } + return GetDims(input) == GetDims(output); + }; -class RemoveUnchangedReshapePattern - : public pir::OpRewritePattern { - public: - using pir::OpRewritePattern::OpRewritePattern; + if (IsSameShape()) { + rewriter->ReplaceAllUsesWith(output, input); + rewriter->EraseOp(op); + return true; + } - bool MatchAndRewrite(cinn::dialect::ReshapeOp op, - pir::PatternRewriter &rewriter) const override { - auto in_dim = 
op->operand_source(0) - .type() - .dyn_cast() - .dims(); - auto out_dim = op->result(0) - .type() - .dyn_cast() - .dims(); - - if (in_dim == out_dim) { - rewriter.ReplaceAllUsesWith(op->result(0), op->operand_source(0)); - rewriter.EraseOp(op); - return true; - } + return false; +} - return false; +template +class RemoveUnchangedReshapePattern : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(OPTYPE op, + pir::PatternRewriter& rewriter) const override { + return RemoveOp(op, &rewriter); } }; @@ -65,7 +91,7 @@ class MergeReshapePattern using pir::OpRewritePattern::OpRewritePattern; bool MatchAndRewrite(cinn::dialect::ReshapeOp op, - pir::PatternRewriter &rewriter) const override { + pir::PatternRewriter& rewriter) const override { if (auto pre_shape = op->operand_source(0) .defining_op() ->dyn_cast()) { @@ -83,17 +109,19 @@ class RemoveUnchangedReshapePass : public pir::PatternRewritePass { RemoveUnchangedReshapePass() : pir::PatternRewritePass("remove_unchanged_reshape_pass", 1) {} - pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { pir::RewritePatternSet ps(context); // remove out_shape equal in_shape reshape op - ps.Add(context); + ps.Add>(context); + ps.Add>(context); ps.Add(context); + ps.Add(context); return ps; } - bool CanApplyOn(pir::Operation *op) const override { + bool CanApplyOn(pir::Operation* op) const override { return op->num_regions() > 0; } }; diff --git a/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc index b37ab970da882..3690a91eb4d37 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.cc @@ -33,12 +33,6 @@ class DynamicExpandOpPattern bool MatchAndRewrite(paddle::dialect::ExpandOp op, pir::PatternRewriter& rewriter) const override { - if (!op->operand_source(1) - .defining_op() - ->isa()) { - return false; - } - const ::pir::Operation* broadcast = [&] { int x_rank = op->operand_source(0) .type() @@ -52,7 +46,28 @@ class DynamicExpandOpPattern for (size_t i = 0; i < x_rank; ++i) { broadcast_axes[i] = i + index_gap; } - std::vector out_shape(out_rank, -1); + + pir::ShapeConstraintIRAnalysis& shape_analysis = + pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram()); + + const auto& GetOutputShapeByDimExpr = [&]() -> std::vector { + std::vector out_shape(out_rank, -1); + if (shape_analysis.HasShapeOrDataForValue(op->result(0))) { + VLOG(3) << "found shape dialect"; + auto shape_info = + shape_analysis.GetShapeOrDataForValue(op->result(0)).shape(); + + for (size_t i = 0; i < shape_info.size(); ++i) { + if (shape_info[i].isa()) { + out_shape[i] = shape_info[i].Get(); + } + } + } + return out_shape; + }; + + auto out_shape = GetOutputShapeByDimExpr(); + return rewriter.Build( op->operand_source(0), broadcast_axes, out_shape); }(); @@ -65,6 +80,20 @@ class DynamicExpandOpPattern broadcast->result(0), shape_analysis.GetShapeOrDataForValue(op.result(0))); + if (auto pre_full = broadcast->operand_source(0) + .defining_op() + ->dyn_cast()) { + auto input_dim = pre_full.result(0) + .type() + .dyn_cast() + .dims(); + if (input_dim.size() == 1 && input_dim[0] == 1) { + shape_analysis.SetShapeOrDataForValue( + pre_full->result(0), + 
shape_analysis.GetShapeOrDataForValue(op.result(0))); + } + } + rewriter.ReplaceAllUsesWith(op->result(0), broadcast->result(0)); rewriter.EraseOp(op); @@ -72,41 +101,20 @@ class DynamicExpandOpPattern } }; -class ReplaceDynamicExpandOpPass : public pir::Pass { +class ReplaceDynamicExpandOpPass : public pir::PatternRewritePass { public: ReplaceDynamicExpandOpPass() - : pir::Pass("replace_dynamic_expand_op_pass", /*opt_level=*/1) {} + : pir::PatternRewritePass("replace_dynamic_expand_op_pass", 1) {} - bool Initialize(pir::IrContext* context) override { + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { pir::RewritePatternSet ps(context); ps.Add(context); - patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); - return true; - } - - void Run(pir::Operation* op) override { - pir::GreedyRewriteConfig cfg; - cfg.use_top_down_traversal = true; - cfg.max_iterations = 10; - for (uint32_t i = 0; i < op->num_regions(); ++i) { - for (auto& block : op->region(i)) { - for (auto& op : block) { - if (op.isa()) { - const auto& [_, num_rewrites] = - pir::ApplyPatternsGreedily(&op, patterns_, cfg); - AddStatistics(num_rewrites); - } - } - } - } + return ps; } bool CanApplyOn(pir::Operation* op) const override { - return op->num_regions() > 0; + return op->isa() && op->num_regions() > 0; } - - private: - pir::FrozenRewritePatternSet patterns_; }; std::unique_ptr CreateReplaceDynamicExpandOpPass() { diff --git a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc index dd9df65356a92..19e7f5060eb96 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc @@ -19,13 +19,14 @@ #include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/common/ddim.h" +#include "paddle/common/enforce.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/core/builtin_dialect.h" #include "paddle/pir/include/dialect/shape/utils/dim_expr.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h" #include "paddle/pir/include/pattern_rewrite/pattern_applicator.h" @@ -128,14 +129,16 @@ struct CachedDimExprToValueConverter { pir::Value ConvertToValueImpl( const symbol::Negative& dim_expr) { - LOG(FATAL) << "Dead code. This logical should handled by " - "ConvertToValueImpl(symbol::Add)"; + PADDLE_THROW( + phi::errors::Fatal("Dead code. This logical should handled by " + "ConvertToValueImpl(symbol::Add)")); } pir::Value ConvertToValueImpl( const symbol::Reciprocal& dim_expr) { - LOG(FATAL) << "Dead code. This logical should handled by " - "ConvertToValueImpl(symbol::Mul)"; + PADDLE_THROW( + phi::errors::Fatal("Dead code. 
This logical should handled by " + "ConvertToValueImpl(symbol::Mul)")); } pir::Value ConvertToValueImpl(const symbol::Add& dim_expr) { diff --git a/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt b/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt index 3452dcd74ab9f..7e6183f4c5976 100644 --- a/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt +++ b/paddle/cinn/hlir/dialect/runtime/ir/CMakeLists.txt @@ -1,10 +1,8 @@ -if(NOT CINN_ONLY) - cinn_cc_library( - cinn_runtime_dialect - SRCS - runtime_dialect.cc - jit_kernel_op.cc - DEPS - cinn_op_dialect - pir) -endif() +cinn_cc_library( + cinn_runtime_dialect + SRCS + runtime_dialect.cc + jit_kernel_op.cc + DEPS + cinn_op_dialect + pir) diff --git a/paddle/cinn/hlir/framework/CMakeLists.txt b/paddle/cinn/hlir/framework/CMakeLists.txt index a9385d627828a..ee9af9fb44780 100755 --- a/paddle/cinn/hlir/framework/CMakeLists.txt +++ b/paddle/cinn/hlir/framework/CMakeLists.txt @@ -24,11 +24,7 @@ gather_srcs( visualize_helper.cc compile_error.cc) -# TODO(Aurelius84): pir_compiler depends on op_dialect_vjp and could -# not found under CINN_ONLY mode -if(NOT CINN_ONLY) - cinn_cc_library(pir_compiler SRCS pir_compiler.cc DEPS cinnapi op_dialect_vjp) -endif() +cinn_cc_library(pir_compiler SRCS pir_compiler.cc DEPS cinnapi op_dialect_vjp) if(WITH_CUDA) cinn_nv_test(test_hlir_framework_buffer SRCS buffer_test.cc DEPS cinncore) diff --git a/paddle/cinn/hlir/framework/graph_compiler.cc b/paddle/cinn/hlir/framework/graph_compiler.cc index 4ed9ff14d217b..1cbe88f9d98c5 100644 --- a/paddle/cinn/hlir/framework/graph_compiler.cc +++ b/paddle/cinn/hlir/framework/graph_compiler.cc @@ -422,7 +422,8 @@ std::vector GetFuncFromImpl( } else if (funcs.size() == expr_pack.size()) { funcs_after_schedule = funcs; } else { - LOG(FATAL) << "The number of funcs should not less than expr_pack's"; + PADDLE_THROW(phi::errors::InvalidArgument( + "The number of funcs should not less than expr_pack's")); } CHECK_EQ(funcs_after_schedule.size(), expr_pack.size()); std::vector res; diff --git a/paddle/cinn/hlir/framework/graph_compiler_util.cc b/paddle/cinn/hlir/framework/graph_compiler_util.cc index 7098ea015ce3b..5381055e5410c 100644 --- a/paddle/cinn/hlir/framework/graph_compiler_util.cc +++ b/paddle/cinn/hlir/framework/graph_compiler_util.cc @@ -13,7 +13,7 @@ // limitations under the License. 
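// The recurring change in this file (and in graph_compiler.cc above) replaces
// LOG(FATAL) and the CINN-specific CINN_THROW with the common enforce-style
// macro from paddle/common/enforce.h, so failures are reported through
// Paddle's unified, typed error mechanism. An illustrative call with a
// hypothetical message:
//
//   PADDLE_THROW(phi::errors::InvalidArgument(
//       "The index(%d) is expected to be less than the size of group(%d).",
//       idx, size));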
#include "paddle/cinn/hlir/framework/graph_compiler_util.h" -#include "paddle/cinn/utils/error.h" +#include "paddle/common/enforce.h" namespace cinn { namespace hlir { @@ -128,7 +128,7 @@ std::string CompilationResult::Message(int idx) const { ss << "The index(" << idx << ") is expected to be less than the size of group(" << lowered_funcs_.size() << ")."; - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return messages_[idx]; } @@ -145,7 +145,7 @@ std::vector> CompilationResult::LoweredFuncs() << "Some errors may have occurred during or before the lower " "process.\n" << Message(); - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } } return res; @@ -157,14 +157,14 @@ std::vector CompilationResult::LoweredFuncs(int idx) const { ss << "The index(" << idx << ") is expected to be less than the size of group(" << lowered_funcs_.size() << ")."; - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } if (!lowered_funcs_[idx].has_value()) { std::stringstream ss; ss << "LoweredFuncs of group[" << idx << "] is not generated.\n" << "Some errors may have occurred during or before the lower process.\n" << Message(); - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } return lowered_funcs_[idx].value(); } @@ -180,7 +180,7 @@ std::vector CompilationResult::SourceCodes() const { << "Some errors may have occurred during or before the codegen " "process.\n" << Message(); - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } } return res; @@ -192,7 +192,7 @@ std::string CompilationResult::SourceCode(int idx) const { ss << "The index(" << idx << ") is expected to be less than the size of group(" << lowered_funcs_.size() << ")."; - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } if (!source_codes_[idx].has_value()) { std::stringstream ss; @@ -200,7 +200,7 @@ std::string CompilationResult::SourceCode(int idx) const { << "Some errors may have occurred during or before the codegen " "process.\n" << Message(); - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } return source_codes_[idx].value(); } @@ -216,7 +216,7 @@ std::vector CompilationResult::SourcePtxs() const { << "Some errors may have occurred during or before the nvrtc compile " "process.\n" << Message(); - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } } return res; @@ -228,7 +228,7 @@ std::string CompilationResult::SourcePtx(int idx) const { ss << "The index(" << idx << ") is expected to be less than the size of group(" << lowered_funcs_.size() << ")."; - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } if (!source_ptxs_[idx].has_value()) { std::stringstream ss; @@ -236,7 +236,7 @@ std::string CompilationResult::SourcePtx(int idx) const { << "Some errors may have occurred during or before the nvrtc compile " "process.\n" << Message(); - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } return source_ptxs_[idx].value(); } @@ -253,7 +253,7 @@ CompilationResult::RuntimeInstructions() const { << "Some errors may have occurred during or before the build " "instruction process.\n" << Message(); - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } } return instructions_; @@ -268,7 +268,7 @@ const std::unique_ptr& CompilationResult::RuntimeInstruction( ss << "The index(" << idx << ") is expected to be less than the size of group(" << insts.size() << ")."; - CINN_THROW(ss.str()); + 
PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return insts[idx]; } @@ -279,7 +279,7 @@ std::unique_ptr CompilationResult::RuntimeProgram() { ss << "Runtime program is not generated.\n" << "Some errors may have occurred during the compilation process.\n" << Message(); - CINN_THROW(ss.str()); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } return std::move(runtime_program_); } diff --git a/paddle/cinn/hlir/framework/instruction_test.cc b/paddle/cinn/hlir/framework/instruction_test.cc index f665c628b5a0a..e7952a4ca160c 100644 --- a/paddle/cinn/hlir/framework/instruction_test.cc +++ b/paddle/cinn/hlir/framework/instruction_test.cc @@ -267,7 +267,7 @@ class TestInstruction : public Instruction { args_[18], stream_); } else { - LOG(FATAL) << "Unkown Conv Type!"; + PADDLE_THROW(phi::errors::InvalidArgument("Unkown Conv Type!")); } CUDA_CALL(cudaStreamSynchronize(stream_)); } diff --git a/paddle/cinn/hlir/framework/op.h b/paddle/cinn/hlir/framework/op.h old mode 100755 new mode 100644 index 78e408c5e9980..1d53902816642 --- a/paddle/cinn/hlir/framework/op.h +++ b/paddle/cinn/hlir/framework/op.h @@ -239,7 +239,7 @@ bool OpValueType::Find(const Operator* op) const { static ::cinn::hlir::framework::Operator& __make_##HlirOp##_##OpName /** - * @def CINNR_REGISTER_OP + * @def CINN_REGISTER_OP * \brief Register a new operator, or set attribute of the corresponding op. * * @param OpName The name of registry diff --git a/paddle/cinn/hlir/framework/op_lowering.h b/paddle/cinn/hlir/framework/op_lowering.h index f1f1554870663..6b259e5423c99 100644 --- a/paddle/cinn/hlir/framework/op_lowering.h +++ b/paddle/cinn/hlir/framework/op_lowering.h @@ -78,13 +78,14 @@ inline OpLowerer CreateOpLowerer( } #ifndef CINN_WITH_ONLY -template +template OpLowerer CreateOpLowerer(const Target&); template <> -inline OpLowerer CreateOpLowerer(const Target& target) { +inline OpLowerer CreateOpLowerer( + const Target& target) { auto* impl_base = new pir::OpLowererImpl(target); - return OpLowerer(impl_base); + return OpLowerer(impl_base); } #endif diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.cc b/paddle/cinn/hlir/framework/op_lowering_impl.cc index a9bb46c8a4f26..0629968a07ac3 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/op_lowering_impl.cc @@ -31,9 +31,6 @@ namespace cinn { namespace hlir { namespace framework { -using cinn::common::bfloat16; -using cinn::common::float16; - using framework::Node; using framework::NodeData; using framework::OpPatternKind; @@ -74,7 +71,8 @@ std::vector OpLowererImpl::Lower(const GroupPtr& group, apply_pass, &OpLowererImpl::ReduceScheduleDetermineFunction); case framework::kOutFusible: - LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; + PADDLE_THROW(phi::errors::Unimplemented( + "Group Pattern Kind kOutFusible Is Not Implemented!")); case framework::kNonFusible: return LowerGroup(group, apply_op_schedule, @@ -82,7 +80,8 @@ std::vector OpLowererImpl::Lower(const GroupPtr& group, apply_pass, &OpLowererImpl::NonFusibleScheduleDetermineFunction); default: - LOG(FATAL) << "Group Pattern Kind Is Unknown!"; + PADDLE_THROW( + phi::errors::InvalidArgument("Group Pattern Kind Is Unknown!")); } } diff --git a/paddle/cinn/hlir/framework/op_lowering_impl.h b/paddle/cinn/hlir/framework/op_lowering_impl.h index 80c79b3c64b8d..ef18def90affc 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/op_lowering_impl.h @@ -28,9 +28,9 @@ #include "paddle/cinn/lang/packed_func.h" // Fusion Op 
lowering, there are four kinds of lowering function: -// Elementwise/Broadcast/Injective,Reduce,OutEWiseFusable,NonFusible. +// Elementwise/Broadcast/Injective,Reduce,OutEWiseFusible,NonFusible. // Elementwise/Broadcast/Injective Ops is with same schedule. -// Reduce,OutEWiseFusable,NonFusible are using different schedule. +// Reduce,OutEWiseFusible,NonFusible are using different schedule. namespace cinn { namespace hlir { diff --git a/paddle/cinn/hlir/framework/op_lowering_impl_base.h b/paddle/cinn/hlir/framework/op_lowering_impl_base.h index edd5c6e8e627e..4d5284f22f6ed 100644 --- a/paddle/cinn/hlir/framework/op_lowering_impl_base.h +++ b/paddle/cinn/hlir/framework/op_lowering_impl_base.h @@ -19,9 +19,9 @@ #include "paddle/cinn/ir/lowered_func.h" // Fusion Op lowering, there are four kinds of lowering function: -// Elementwise/Broadcast/Injective,Reduce,OutEWiseFusable,NonFusible. +// Elementwise/Broadcast/Injective,Reduce,OutEWiseFusible,NonFusible. // Elementwise/Broadcast/Injective Ops is with same schedule. -// Reduce,OutEWiseFusable,NonFusible are using different schedule. +// Reduce,OutEWiseFusible,NonFusible are using different schedule. namespace cinn { namespace hlir { diff --git a/paddle/cinn/hlir/framework/op_lowering_util.cc b/paddle/cinn/hlir/framework/op_lowering_util.cc index 2366fd584aa0b..1948a5189b6f1 100644 --- a/paddle/cinn/hlir/framework/op_lowering_util.cc +++ b/paddle/cinn/hlir/framework/op_lowering_util.cc @@ -86,7 +86,9 @@ ir::Tensor GetTensor( return lang::Placeholder(node_data->id(), shape_dict.at(node_data->id())); } else { - LOG(FATAL) << "Unsupport dtype: " << dtype; + std::stringstream ss; + ss << "Unsupport dtype: " << dtype; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -319,13 +321,13 @@ std::unordered_map BuildVirtualConsumer( auto output_shape = GetOutputShape(t_node, shape_dict); if (!found && t_node != e_node && e_node) { - auto enode_output_shape = GetOutputShape(e_node, shape_dict); + auto e_node_output_shape = GetOutputShape(e_node, shape_dict); if (std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies()) == - std::accumulate(enode_output_shape.begin(), - enode_output_shape.end(), + std::accumulate(e_node_output_shape.begin(), + e_node_output_shape.end(), 1, std::multiplies())) { virtual_consumers[t_node] = e_node; @@ -739,8 +741,8 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, // NOLINT } lane *= inshape[axes[index]]; if (index == 0 && lane <= max_num_threads) { - LOG(FATAL) - << "Error! lane is less equal than max_num_threads, Please check!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Error! lane is less equal than max_num_threads, Please check!")); } if (lane >= max_num_threads / 2) { if (lane <= max_num_threads) { @@ -805,7 +807,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, // NOLINT ir_sch.Fuse(block_name, {axes[index + 1], axes[index + 1] + 1}); } LoopOrderAssignReduce(ir_sch, block_name, first_axes, target, true); - // fuse axis before reduce to bind blockidx. + // fuse axis before reduce to bind block idx. for (int idx = 0; idx < static_cast(inshape.size() - axes.size()) - 1; ++idx) { ir_sch.Fuse(block_name, {0, 1}); @@ -1181,7 +1183,7 @@ void LoopAssignReduce( // copy loop info form rloops. copy_loop_info(nloops, rloops); } else { - LOG(FATAL) << "Error! Unkown Reduce Type!"; + PADDLE_THROW(phi::errors::InvalidArgument("Error! 
Unkown Reduce Type!")); } } } @@ -1398,7 +1400,8 @@ void MergeReduceToReduce( n_loops.size() - 1); } } else { - LOG(FATAL) << "not support this type fusion!"; + PADDLE_THROW( + phi::errors::InvalidArgument("not support this type fusion!")); } } } else { @@ -1502,7 +1505,8 @@ void MergeReduceToReduce( ir_sch.SimpleComputeAt(block, loops.back()); } } else { - LOG(FATAL) << "Error! Unkown Reduce Type, Please Check!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Error! Unkown Reduce Type, Please Check!")); } } } diff --git a/paddle/cinn/hlir/framework/pir/CMakeLists.txt b/paddle/cinn/hlir/framework/pir/CMakeLists.txt index 6a9c87ff05ec6..a0930aea095d9 100755 --- a/paddle/cinn/hlir/framework/pir/CMakeLists.txt +++ b/paddle/cinn/hlir/framework/pir/CMakeLists.txt @@ -1,12 +1,14 @@ -if(NOT CINN_ONLY) - core_gather_headers() - gather_srcs( - cinnapi_src - SRCS - group.cc - utils.cc - op_lowering_impl.cc - op_mapper.cc - op_lowering_util.cc - compilation_task.cc) -endif() +core_gather_headers() +gather_srcs( + cinnapi_src + SRCS + group.cc + utils.cc + op_lowering_group.cc + op_lowering_impl.cc + op_mapper.cc + op_lowering_util.cc + trivial_op_impl.cc + trivial_op_util.cc + compilation_task.cc + compilation_cache.cc) diff --git a/paddle/cinn/hlir/framework/pir/compilation_cache.cc b/paddle/cinn/hlir/framework/pir/compilation_cache.cc new file mode 100644 index 0000000000000..47a38442b58a5 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/compilation_cache.cc @@ -0,0 +1,102 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
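//
// Usage sketch, based on how compilation_task.cc uses this cache later in this
// diff (variable names are illustrative): the cache is a process-wide singleton
// keyed by the group's FuncName(); a compilation task consults it before
// lowering and inserts its result afterwards, roughly:
//
//   auto& cache = CompilationCache::Instance();
//   if (!cache.Has(group)) {   // group: std::shared_ptr<pir::OpLoweringGroup>
//     auto result = std::make_shared<pir::CompilationResult>(target);
//     // ... lower, codegen and JIT into result->MutableBackendResource() ...
//     cache.Insert(group, result);
//   }
//   pir::CINNKernelInfo kernel_info = cache.GetKernelInfo(group);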
+ +#include "paddle/cinn/hlir/framework/pir/compilation_cache.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h" + +#include "paddle/common/enforce.h" + +namespace cinn::hlir::framework { + +namespace pir { +void* BackendResource::GetHostFuncPtr() const { + VLOG(4) << "Lookup kernel name: " << host_fn_name_; + void* ptr = backend_compiler_->Lookup(host_fn_name_); + PADDLE_ENFORCE_NOT_NULL(ptr, + phi::errors::InvalidArgument( + "Can't find kernel function %s", host_fn_name_)); + return ptr; +} + +void* BackendResource::GetInferFuncPtr() const { + VLOG(4) << "Lookup infer shape fn name: " << infer_fn_name_; + void* ptr = backend_compiler_->Lookup(infer_fn_name_); + PADDLE_ENFORCE_NOT_NULL( + ptr, + phi::errors::InvalidArgument("Can't find infer shape function %s", + infer_fn_name_)); + return ptr; +} + +std::shared_ptr& BackendResource::GetBackendCompiler() { + return backend_compiler_; +} + +const std::shared_ptr& BackendResource::GetBackendCompiler() + const { + return backend_compiler_; +} + +void BackendResource::SetHostFnName(const std::string& name) { + host_fn_name_ = name; +} + +void BackendResource::SetInferFnName(const std::string& name) { + infer_fn_name_ = name; +} + +pir::CINNKernelInfo BackendResource::GernerateKernelInfo( + const std::shared_ptr& group) const { + pir::CINNKernelInfo kernel_info; + kernel_info.fn_name = host_fn_name_; + kernel_info.fn_ptr = GetHostFuncPtr(); + kernel_info.infer_shape_fn_ptr = GetInferFuncPtr(); + kernel_info.int_args_map = group->int_args_map(); + return kernel_info; +} +} // namespace pir + +bool CompilationCache::Has(const CacheKey& key) const { + const bool has_existed = cache_.find(KeyHash(key)) != cache_.end(); + VLOG(6) << "Check IsExisted in CompilationCache: " << key->FuncName() << " " + << has_existed; + return has_existed; +} + +const CompilationCache::CacheValue& CompilationCache::Get( + const CacheKey& key) const { + PADDLE_ENFORCE_EQ( + Has(key), + true, + phi::errors::NotFound("%s is not in CompliatonCache.", key->FuncName())); + return cache_.at(KeyHash(key)); +} + +pir::CINNKernelInfo CompilationCache::GetKernelInfo(const CacheKey& key) const { + return Get(key)->GetKernelInfo(key); +} + +void CompilationCache::Insert(const CacheKey& key, const CacheValue& value) { + VLOG(6) << "Insert CompilationCache for: " << key->FuncName(); + cache_.insert({KeyHash(key), value}); +} + +void CompilationCache::Clear() { cache_.clear(); } + +size_t CompilationCache::KeyHash(const CacheKey& key) const { + // TODO(Aurelius84): use a better hash function in next pr. + return std::hash{}(key->FuncName()); +} + +} // namespace cinn::hlir::framework diff --git a/paddle/cinn/hlir/framework/pir/compilation_cache.h b/paddle/cinn/hlir/framework/pir/compilation_cache.h new file mode 100644 index 0000000000000..018bd6fd85572 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/compilation_cache.h @@ -0,0 +1,102 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include +#include "paddle/cinn/backends/compiler.h" +#include "paddle/cinn/common/macros.h" +#include "paddle/cinn/common/target.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" + +namespace cinn::hlir::framework { + +namespace pir { +class OpLoweringGroup; +class BackendResource final { + public: + BackendResource(const Target& target) { + backend_compiler_ = backends::Compiler::Create(target); + } + + BackendResource(const Target& target, + const std::string& host_fn_name, + const std::string& infer_fn_name) + : host_fn_name_(host_fn_name), infer_fn_name_(infer_fn_name) { + backend_compiler_ = backends::Compiler::Create(target); + } + + void* GetHostFuncPtr() const; + void* GetInferFuncPtr() const; + pir::CINNKernelInfo GernerateKernelInfo( + const std::shared_ptr& group) const; + std::shared_ptr& GetBackendCompiler(); + const std::shared_ptr& GetBackendCompiler() const; + void SetHostFnName(const std::string& name); + void SetInferFnName(const std::string& name); + + private: + std::string host_fn_name_; + std::string infer_fn_name_; + // std::string host_code_; + // std::vector device_code_; + std::shared_ptr backend_compiler_; +}; + +class CompilationResult final { + public: + explicit CompilationResult(const Target& target) + : target_(target), backend_resource_(target) {} + + BackendResource& MutableBackendResource() { return backend_resource_; } + const BackendResource& GetBackendResource() const { + return backend_resource_; + } + pir::CINNKernelInfo GetKernelInfo( + const std::shared_ptr& group) { + return backend_resource_.GernerateKernelInfo(group); + } + + private: + Target target_; + BackendResource backend_resource_; +}; +} // namespace pir + +class CompilationCache { + public: + using CacheKey = std::shared_ptr; + using CacheValue = std::shared_ptr; + + static CompilationCache& Instance() { + static CompilationCache instance; + return instance; + } + + bool Has(const CacheKey& key) const; + const CacheValue& Get(const CacheKey& key) const; + pir::CINNKernelInfo GetKernelInfo(const CacheKey& key) const; + void Insert(const CacheKey& key, const CacheValue& value); + void Clear(); + size_t KeyHash(const CacheKey& key) const; + + private: + CompilationCache() = default; + CINN_DISALLOW_COPY_AND_ASSIGN(CompilationCache); + + std::unordered_map cache_; +}; + +} // namespace cinn::hlir::framework diff --git a/paddle/cinn/hlir/framework/pir/compilation_task.cc b/paddle/cinn/hlir/framework/pir/compilation_task.cc index 4e84ef4428515..a93ac960d496a 100644 --- a/paddle/cinn/hlir/framework/pir/compilation_task.cc +++ b/paddle/cinn/hlir/framework/pir/compilation_task.cc @@ -17,7 +17,7 @@ #include "paddle/cinn/hlir/framework/pir/compilation_task.h" #include "paddle/cinn/common/target.h" #include "paddle/cinn/hlir/framework/op_lowering.h" -#include "paddle/cinn/ir/module.h" +#include "paddle/common/enforce.h" namespace cinn { namespace hlir { @@ -29,7 +29,6 @@ void GroupCompilationContext::SetLoweredFuncs( funcs.predicate2funcs) { predicates_.push_back(std::move(predicate2func.first)); lowered_funcs_.push_back(std::move(predicate2func.second)); - ++func_size_; } infer_shape_lowered_func_ = std::move(funcs.infer_shape_func); } @@ -43,21 +42,19 @@ std::string GroupCompilationContext::PrintPredicate2Funcs() const { return ss.str(); } -void* GroupCompilationContext::FuncPtr() { - return backend_compiler_->Lookup(host_func_name_); -} - -std::shared_ptr GroupCompilationContext::BackendCompiler() { - return backend_compiler_; -} - void 
CompilationTask::operator()() { + VLOG(4) << "Run Compilation Task for : " << context_->group_.get(); + if (CompilationCache::Instance().Has(context_->group_)) { + VLOG(4) << "Found cached kernel info for group: " + << context_->group_->FuncName(); + return; + } Lowering(); CodegenAndJit(); } void CompilationTask::Lowering() { - auto op_lowerer = CreateOpLowerer(context_->target_); + auto op_lowerer = CreateOpLowerer(context_->target_); context_->SetLoweredFuncs( op_lowerer.BucketLower(context_->group_, /* apply op schedule = */ false, @@ -77,43 +74,27 @@ void CompilationTask::CodegenAndJit() { } builder.SetInferShapeFunc(context_->infer_shape_lowered_func_); ir::Module ir_module = builder.Build(); - - context_->backend_compiler_ = backends::Compiler::Create(context_->target_); - context_->backend_compiler_->Build(ir_module, ""); + BuildPirCINNKernelInfo(ir_module); } -std::unique_ptr CompilationTask::BuildInstruction() { - std::string fn_name = context_->group_->FuncName(); - std::unique_ptr instr = - std::make_unique(context_->target_, - context_->scope_.get(), - context_->group_->input_names, - context_->group_->output_names, - fn_name); - VLOG(4) << "Lookup kernel name: " << fn_name; - auto* fn_ptr = context_->backend_compiler_->Lookup(fn_name); - CHECK(fn_ptr); - auto* infer_shape_fn_ptr = - context_->backend_compiler_->Lookup(fn_name + "_infer_shape" + fn_name); - CHECK(infer_shape_fn_ptr); - instr->SetLoweredFunc(reinterpret_cast(fn_ptr), fn_name); - instr->Finalize(); - return instr; +pir::CINNKernelInfo CompilationTask::GetCINNKernelInfo() { + if (!CompilationCache::Instance().Has(context_->group_)) { + PADDLE_THROW(phi::errors::NotFound( + "Kernel info has been cached for current group.")); + } + return CompilationCache::Instance().GetKernelInfo(context_->group_); } -pir::CINNKernelInfo CompilationTask::BuildPirCINNKernelInfo() { - std::string fn_name = context_->group_->FuncName(); - VLOG(4) << "Lookup kernel name: " << fn_name; - auto* fn_ptr = context_->backend_compiler_->Lookup(fn_name); - CHECK(fn_ptr); - auto* infer_shape_fn_ptr = - context_->backend_compiler_->Lookup(fn_name + "_infer_shape"); - CHECK(infer_shape_fn_ptr); - pir::CINNKernelInfo cinn_kernel_info; - cinn_kernel_info.fn_ptr = fn_ptr; - cinn_kernel_info.infer_shape_fn_ptr = infer_shape_fn_ptr; - cinn_kernel_info.int_args_map = context_->group_->int_args_map; - return cinn_kernel_info; +void CompilationTask::BuildPirCINNKernelInfo(const ir::Module& module) { + auto compilation_result = + std::make_shared(context_->target_); + pir::BackendResource& backend_resource = + compilation_result->MutableBackendResource(); + backend_resource.GetBackendCompiler()->Build(module, ""); + backend_resource.SetHostFnName(context_->group_->FuncName()); + backend_resource.SetInferFnName(context_->group_->FuncName() + + "_infer_shape"); + CompilationCache::Instance().Insert(context_->group_, compilation_result); } } // namespace framework diff --git a/paddle/cinn/hlir/framework/pir/compilation_task.h b/paddle/cinn/hlir/framework/pir/compilation_task.h index e76f93d206096..69e985afd7869 100644 --- a/paddle/cinn/hlir/framework/pir/compilation_task.h +++ b/paddle/cinn/hlir/framework/pir/compilation_task.h @@ -16,41 +16,33 @@ #include "paddle/cinn/backends/compiler.h" #include "paddle/cinn/common/target.h" #include "paddle/cinn/hlir/framework/instruction.h" +#include "paddle/cinn/hlir/framework/pir/compilation_cache.h" #include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h" #include "paddle/cinn/hlir/framework/pir/utils.h" 
#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" +#include "paddle/cinn/ir/module.h" namespace cinn { namespace hlir { namespace framework { +class CompilationTask; class GroupCompilationContext { public: GroupCompilationContext(const Target& target, - const pir::GroupPtr& group, - std::shared_ptr scope) - : target_(target), group_(group), scope_(scope) {} + const pir::OpLoweringGroupPtr& group) + : target_(target), group_(group) {} void SetLoweredFuncs(BucketLoweredFuncsWrapper&& funcs); std::string PrintPredicate2Funcs() const; - void* FuncPtr(); - std::shared_ptr BackendCompiler(); private: friend class CompilationTask; - const Target& target_; - const pir::GroupPtr& group_; - std::shared_ptr scope_; - - size_t func_size_ = 0; + const pir::OpLoweringGroupPtr& group_; std::vector predicates_; std::vector lowered_funcs_; ir::LoweredFunc infer_shape_lowered_func_; - std::string host_func_name_; - std::string host_code_; - std::vector device_code_; - std::shared_ptr backend_compiler_; }; class CompilationTask { @@ -59,13 +51,14 @@ class CompilationTask { : context_(context) {} void operator()(); + pir::CINNKernelInfo GetCINNKernelInfo(); + private: void Lowering(); void CodegenAndJit(); std::unique_ptr BuildInstruction(); - pir::CINNKernelInfo BuildPirCINNKernelInfo(); + void BuildPirCINNKernelInfo(const ir::Module& module); - private: GroupCompilationContext* context_; }; diff --git a/paddle/cinn/hlir/framework/pir/group.cc b/paddle/cinn/hlir/framework/pir/group.cc index 706dfcafd6819..befa2e5b12908 100644 --- a/paddle/cinn/hlir/framework/pir/group.cc +++ b/paddle/cinn/hlir/framework/pir/group.cc @@ -46,10 +46,6 @@ std::shared_ptr Group::Clone(::pir::Block* target_block, for (auto* op : this->output_ops) { new_group->output_ops.insert(ops_mapper.at(op)); } - for (const auto& output_value : this->output_values) { - new_group->output_values.push_back(ir_mapping.Lookup(output_value)); - } - return new_group; } diff --git a/paddle/cinn/hlir/framework/pir/group.h b/paddle/cinn/hlir/framework/pir/group.h index 29ff85d099220..8332a3fc82a5a 100644 --- a/paddle/cinn/hlir/framework/pir/group.h +++ b/paddle/cinn/hlir/framework/pir/group.h @@ -63,29 +63,6 @@ struct Group { ::pir::IrMapping& ir_mapping, const Options& option = Options()) const; - const symbol::ShapeOrDataDimExprs& GetShapeOrDataExprs( - const ::pir::Value& value) const { - CHECK(value_to_shape_or_data_exprs_.count(value)) - << "value not found in value_to_shape_or_data_exprs_"; - return value_to_shape_or_data_exprs_.at(value); - } - - void SetShapeOrDataExprs(const ::pir::Value& value, - const symbol::ShapeOrDataDimExprs& shape_or_data) { - auto iter = value_to_shape_or_data_exprs_.find(value); - if (iter == value_to_shape_or_data_exprs_.end()) { - value_to_shape_or_data_exprs_.emplace(value, shape_or_data); - } else { - iter->second = shape_or_data; - } - } - - void set_value_to_shape_or_data_exprs( - const std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs>& - value_to_shape_or_data_exprs) { - value_to_shape_or_data_exprs_ = value_to_shape_or_data_exprs; - } - // distance to last group. int depth{0}; int max_depth{0}; @@ -114,13 +91,6 @@ struct Group { // if as sub-group, used for belong groups. std::unordered_set> belong_groups; - // for op lowering. 
- std::vector input_names; - std::vector output_names; - std::vector<::pir::Value> output_values; - std::string fn_name{""}; - std::map int_args_map; - struct SharedGroupHasher { size_t operator()(const std::shared_ptr& group) const noexcept { return std::hash()(reinterpret_cast(group.get())); @@ -203,10 +173,6 @@ struct Group { return group_outputs; } - const std::vector<::pir::Value>& GetGroupOutputValues() const { - return this->output_values; - } - std::string GetFuncName() { return "fn_" + group_id + unique_id; } std::vector<::pir::Value> GenerateGroupOutputValues() const { @@ -233,19 +199,6 @@ struct Group { return output_values; } - std::shared_ptr mut_map_expr_ctx() { - CHECK_NOTNULL(map_expr_ctx_); - return map_expr_ctx_; - } - - const adt::MapExprCtx& map_expr_ctx() const { - return *CHECK_NOTNULL(map_expr_ctx_); - } - - void set_map_expr_ctx(const std::shared_ptr& map_expr_ctx) { - map_expr_ctx_ = map_expr_ctx; - } - public: const std::unordered_set, SharedGroupHasher, @@ -277,29 +230,17 @@ struct Group { OpPatternKind kind() const { return op_pattern_kind; } - std::string FuncName() const { - if (fn_name == "") { - // TODO(Aurelius84): Polish this implementation. - const_cast(this)->fn_name = CompatibleInfo::GroupOpsName(ops); - } - return this->fn_name; - } - private: // input groups std::unordered_set, SharedGroupHasher, SharedGroupComparator> producer_groups_; - // output grous + // output groups std::unordered_set, SharedGroupHasher, SharedGroupComparator> consumer_groups_; - std::shared_ptr map_expr_ctx_; - - std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> - value_to_shape_or_data_exprs_; }; std::ostream& operator<<(std::ostream& os, const Group& group); diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_group.cc b/paddle/cinn/hlir/framework/pir/op_lowering_group.cc new file mode 100644 index 0000000000000..8799c84969a04 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/op_lowering_group.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { + +std::shared_ptr OpLoweringGroup::Clone( + ::pir::Block* target_block, ::pir::IrMapping* ir_mapping) const { + std::vector<::pir::Operation*> new_ops; + // Mapper from original to new ops. + std::unordered_map<::pir::Operation*, ::pir::Operation*> ops_mapper; + auto clone_options = ::pir::CloneOptions(false, true, false); + for (auto* op : ops_) { + VLOG(4) << "clone op :" << op->name(); + auto* new_op = op->Clone(*ir_mapping, clone_options); + // NOTE(dev): Must call block.insert to deal with ownership, otherwise it + // will lead memory-leak. 
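      // In other words, the target block is assumed to take ownership of the
      // cloned operation; the group itself only keeps raw pointers in ops_,
      // so the cloned group is only valid while target_block (and the ops it
      // owns) is alive.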
+ target_block->insert(target_block->end(), new_op); + new_ops.push_back(new_op); + ops_mapper[op] = new_op; + } + + // Construct Base information for new Group + auto new_group = std::make_shared(new_ops); + for (auto* op : this->output_ops_) { + new_group->output_ops_.insert(ops_mapper.at(op)); + } + for (const auto& output_value : this->output_values_) { + new_group->output_values_.push_back(ir_mapping->Lookup(output_value)); + } + + new_group->input_names_ = this->input_names_; + new_group->output_names_ = this->output_names_; + new_group->fn_name_ = this->fn_name_; + new_group->int_args_map_ = this->int_args_map_; + new_group->alignment_schedule_info_ = this->alignment_schedule_info_; + new_group->reduce_axis_ = this->reduce_axis_; + new_group->loop_ranges_ = this->loop_ranges_; + return new_group; +} + +std::ostream& operator<<(std::ostream& os, const OpLoweringGroup& group) { + auto PrintSymbolDims = [&](const ::pir::Operation& op) { + if (group.value_to_shape_or_data_exprs_.empty()) return; + os << " {"; + for (uint32_t i = 0; i < op.num_operands(); ++i) { + if (i > 0) os << ","; + if (group.HasShapeOrDataExprs(op.operand_source(i))) { + os << "<" << group.GetShapeOrDataExprs(op.operand_source(i)) << ">"; + } + } + os << "} -> {"; + for (uint32_t i = 0; i < op.num_results(); ++i) { + if (i > 0) os << ","; + if (group.HasShapeOrDataExprs(op.result(i))) { + os << "<" << group.GetShapeOrDataExprs(op.result(i)) << ">"; + } + } + os << "}"; + }; + ::pir::IrPrinter printer(os); + os << "Group " << group.group_id() << " :\n"; + for (auto* op : group.ops()) { + printer.PrintOperation(op); + PrintSymbolDims(*op); + os << "\n"; + } + return os; +} + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_group.h b/paddle/cinn/hlir/framework/pir/op_lowering_group.h new file mode 100644 index 0000000000000..aaa2f31f0a60c --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/op_lowering_group.h @@ -0,0 +1,313 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
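//
// Usage sketch, assuming the nested SharedGroupHasher/SharedGroupComparator
// declared below (which hash and compare by group_id() rather than by pointer
// identity): groups can serve directly as keys of hash containers, e.g.
//
//   using OpLoweringGroupPtr =
//       std::shared_ptr<cinn::hlir::framework::pir::OpLoweringGroup>;
//   std::unordered_set<
//       OpLoweringGroupPtr,
//       cinn::hlir::framework::pir::OpLoweringGroup::SharedGroupHasher,
//       cinn::hlir::framework::pir::OpLoweringGroup::SharedGroupComparator>
//       visited_groups;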
+ +#pragma once +#include +#include +#include +#include +#include "glog/logging.h" + +#include "paddle/cinn/common/context.h" +#include "paddle/cinn/hlir/framework/op.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/pir/include/core/builtin_type_interfaces.h" +#include "paddle/pir/include/core/operation.h" +#include "paddle/pir/include/core/value.h" +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace cinn { + +namespace adt { +class MapExprCtx; +} // namespace adt + +namespace hlir { +namespace framework { +namespace pir { +class OpLoweringGroup { + public: + OpLoweringGroup() = default; + OpLoweringGroup(const OpLoweringGroup&) = delete; + OpLoweringGroup(OpLoweringGroup&&) = delete; + + explicit OpLoweringGroup(const std::vector<::pir::Operation*>& group_ops) + : ops_(group_ops) {} + + explicit OpLoweringGroup(std::initializer_list<::pir::Operation*> group_ops) + : ops_(group_ops) {} + + struct SharedGroupHasher { + size_t operator()( + const std::shared_ptr& group) const noexcept { + return std::hash()(group->group_id()); + } + }; + struct SharedGroupComparator { + bool operator()( + const std::shared_ptr& first, + const std::shared_ptr& second) const noexcept { + return first->group_id() == second->group_id(); + } + }; + + std::vector<::pir::Value> GetGroupOutputValues() const { + std::unordered_set<::pir::Operation*> group_ops_set(this->ops_.begin(), + this->ops_.end()); + + std::vector<::pir::Value> output_values; + for (auto* op : this->ops_) { + for (size_t i = 0; i < op->num_results(); ++i) { + auto result = op->result(i); + if (!result) { + continue; + } + for (auto use_iter = result.use_begin(); use_iter != result.use_end(); + ++use_iter) { + auto* use_op = use_iter->owner(); + if (group_ops_set.find(use_op) == group_ops_set.end()) { + output_values.push_back(result); + break; + } + } + } + } + return output_values; + } + + std::unordered_set<::pir::Value> GetInputOpValues() const { + std::unordered_set<::pir::Value> group_inputs; + + std::unordered_set<::pir::Operation*> ops_set; + for (auto op : this->ops_) { + ops_set.insert(op); + } + + // count all op's input Value + for (auto op : this->ops_) { + for (auto& value : op->operands_source()) { + if (!value || !value.type()) { + continue; + } + + if (!ops_set.count(value.defining_op())) { + // if the input value owner op is not in OpSet, it's the group's input + group_inputs.insert(value); + continue; + } + } + } + + return group_inputs; + } + + std::unordered_set<::pir::Value> GetOutputOpValues() const { + std::unordered_set<::pir::Value> group_outputs; + + for (auto op : this->output_ops_) { + for (auto& result : op->results()) { + if (!result || result.type()) { + continue; + } + + group_outputs.insert(result); + } + } + return group_outputs; + } + + std::string FuncName() const { + if (fn_name_ == "") { + // TODO(Aurelius84): Polish this implementation. 
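      // (i.e. the name is derived lazily from the group's ops on first call
      // and cached in fn_name_; the const_cast is what allows this caching
      // from a const accessor.)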
+ const_cast(this)->fn_name_ = + CompatibleInfo::GroupOpsName(ops_); + } + return this->fn_name_; + } + + const symbol::ShapeOrDataDimExprs& GetShapeOrDataExprs( + const ::pir::Value& value) const { + CHECK(value_to_shape_or_data_exprs_.count(value)) + << "value not found in value_to_shape_or_data_exprs_"; + return value_to_shape_or_data_exprs_.at(value); + } + + bool HasShapeOrDataExprs(const ::pir::Value& value) const { + return value_to_shape_or_data_exprs_.count(value); + } + + void SetShapeOrDataExprs(const ::pir::Value& value, + const symbol::ShapeOrDataDimExprs& shape_or_data) { + auto iter = value_to_shape_or_data_exprs_.find(value); + if (iter == value_to_shape_or_data_exprs_.end()) { + value_to_shape_or_data_exprs_.emplace(value, shape_or_data); + } else { + iter->second = shape_or_data; + } + } + + void WalkOps(const std::function& VisitOp) const { + for (const auto& op : ops_) { + VisitOp(op); + } + } + + const std::vector<::pir::Operation*>& ops() const { return ops_; } + + std::vector<::pir::Operation*>& mut_ops() { return ops_; } + + void SetOps(const std::vector<::pir::Operation*>& new_ops) { ops_ = new_ops; } + + const std::vector& input_names() const { + return this->input_names_; + } + + std::vector& mut_input_names() { return this->input_names_; } + + const std::vector& output_names() const { + return this->output_names_; + } + + std::vector& mut_output_names() { return this->output_names_; } + + const std::vector<::pir::Value>& output_values() const { + return this->output_values_; + } + + std::vector<::pir::Value>& mut_output_values() { + return this->output_values_; + } + + const std::unordered_set<::pir::Operation*>& output_ops() const { + return this->output_ops_; + } + + std::unordered_set<::pir::Operation*>& mut_output_ops() { + return this->output_ops_; + } + + std::shared_ptr mut_map_expr_ctx() { + CHECK_NOTNULL(map_expr_ctx_); + return map_expr_ctx_; + } + + const adt::MapExprCtx& map_expr_ctx() const { + return *CHECK_NOTNULL(map_expr_ctx_); + } + + void set_value_to_shape_or_data_exprs( + const std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs>& + value_to_shape_or_data_exprs) { + value_to_shape_or_data_exprs_ = value_to_shape_or_data_exprs; + } + + void set_map_expr_ctx(const std::shared_ptr& map_expr_ctx) { + map_expr_ctx_ = map_expr_ctx; + } + + const std::string& group_id() const { return this->group_id_; } + + OpPatternKind op_pattern_kind() const { return this->op_pattern_kind_; } + + void set_op_pattern_kind(OpPatternKind pattern_kind) { + this->op_pattern_kind_ = pattern_kind; + } + + const std::vector& loop_ranges() const { return loop_ranges_; } + + void set_loop_ranges(const std::vector& loop_ranges) { + this->loop_ranges_ = loop_ranges; + } + + const std::vector& loop_ranges_expr() const { + return loop_ranges_expr_; + } + + void set_loop_ranges_expr( + const std::vector& loop_ranges_expr) { + this->loop_ranges_expr_ = loop_ranges_expr; + } + + const std::vector& reduce_axis() const { return reduce_axis_; } + + void set_reduce_axis(const std::vector& reduce_axis) { + this->reduce_axis_ = reduce_axis; + } + + const std::map& int_args_map() const { + return this->int_args_map_; + } + + std::map& mut_int_args_map() { + return this->int_args_map_; + } + + private: + using alignment_schedule_info_t = std::unordered_map< + ::pir::Operation*, + std::vector>; + + public: + const alignment_schedule_info_t& alignment_schedule_info() const { + return alignment_schedule_info_; + } + + alignment_schedule_info_t& mut_alignment_schedule_info() { + 
return alignment_schedule_info_; + } + + void set_alignment_schedule_info( + const std::unordered_map< + ::pir::Operation*, + std::vector>& + alignment_schedule_info) { + this->alignment_schedule_info_ = alignment_schedule_info; + } + + std::shared_ptr Clone(::pir::Block* target_block, + ::pir::IrMapping* ir_mapping) const; + + private: + friend std::ostream& operator<<(std::ostream&, const OpLoweringGroup&); + + // group id, consisted of op's id. + std::string group_id_{common::UniqName("group_")}; + // op in this group + std::vector<::pir::Operation*> ops_; + // output ops of the group. + std::unordered_set<::pir::Operation*> output_ops_; + // op pattern kind. + OpPatternKind op_pattern_kind_{kElementWise}; + + std::vector input_names_; + std::vector output_names_; + std::vector<::pir::Value> output_values_; + std::string fn_name_{""}; + std::map int_args_map_; + + alignment_schedule_info_t alignment_schedule_info_; + std::vector reduce_axis_; + std::vector loop_ranges_; + std::vector loop_ranges_expr_; + + std::shared_ptr map_expr_ctx_; + std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> + value_to_shape_or_data_exprs_; +}; + +std::ostream& operator<<(std::ostream& os, const OpLoweringGroup& group); +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc index 828437f0f4abe..bab37b959ddfc 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.cc @@ -19,8 +19,10 @@ #include "paddle/cinn/adt/map_expr_ctx.h" #include "paddle/cinn/ast_gen_ius/tensor_group.h" #include "paddle/cinn/backends/codegen_cuda_util.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/framework/compile_error.h" #include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/trivial_op_impl.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/hlir/op/external_api_registry.h" #include "paddle/cinn/hlir/pe/map_expr_to_ir.h" @@ -29,10 +31,17 @@ #include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/eliminate_common_global_memory_read.h" +#include "paddle/cinn/optim/if_fusion.h" #include "paddle/cinn/optim/schedule_block_dce.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" #include "paddle/common/ddim.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h" +#include "paddle/cinn/ir/group_schedule/config/group_tile_config.h" +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" PD_DECLARE_bool(cinn_use_cuda_vectorize); PD_DECLARE_bool(cinn_enable_map_expr); @@ -64,19 +73,101 @@ NodeAttr CollectAttrs(const ::pir::Operation& op) { } // namespace details +std::shared_ptr OpLowererImpl::GetGroupInfo( + const FusionGroupInfo& fusion_group_info, + const OpLoweringGroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map) { + std::shared_ptr group_info = std::make_shared(); + group_info->data_space = fusion_group_info.loop_ranges; + group_info->reduce_axis = fusion_group_info.reduce_axis; + group_info->reduce_var_names = + std::set(fusion_group_info.reduce_var_name.begin(), + 
fusion_group_info.reduce_var_name.end()); + + for (auto& op : group->output_ops()) { + group_info->direct_output_var_names.insert(ValueName(op->result(0))); + // collect all output tensor. + if (op->name() == "cinn_op.yield_store") { + auto input_var_name = ValueName(op->operand_source(0)); + if (group_info->broadcast_info.count(input_var_name)) { + auto base_info = group_info->broadcast_info[input_var_name]; + base_info.with_constrain = true; + group_info->broadcast_info[ValueName(op->result(0))] = base_info; + } + } + for (auto opresult : op->results()) { + if (tensor_map.count(opresult) == 0) { + continue; + } + group_info->direct_output_var_names.insert(ValueName(opresult)); + } + } + + for (auto& val : group->output_values()) { + group_info->direct_output_var_names.insert(ValueName(val)); + } + return group_info; +} + +std::shared_ptr OpLowererImpl::GetGroupInfo( + const OpLoweringGroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map) { + std::shared_ptr group_info = std::make_shared(); + group_info->data_space = group->loop_ranges(); + group_info->reduce_axis = group->reduce_axis(); + for (auto op : group->ops()) { + if (CompatibleInfo::OpKind(*op) == OpPatternKind::kReduction) { + group_info->reduce_var_names.insert(ValueName(op->result(0))); + } + } + + BuildBroadcastInfo(group, group_info); + + for (auto& op : group->output_ops()) { + group_info->direct_output_var_names.insert(ValueName(op->result(0))); + // collect all output tensor. + if (op->name() == "cinn_op.yield_store") { + auto input_var_name = ValueName(op->operand_source(0)); + if (group_info->broadcast_info.count(input_var_name)) { + auto base_info = group_info->broadcast_info[input_var_name]; + base_info.with_constrain = true; + group_info->broadcast_info[ValueName(op->result(0))] = base_info; + } + } + for (auto opresult : op->results()) { + if (tensor_map.count(opresult) == 0) { + continue; + } + group_info->direct_output_var_names.insert(ValueName(opresult)); + } + } + + for (const auto& val : group->output_values()) { + if (val.defining_op()->name() == "cinn_op.reshape" && + erase_reshape.count(val.defining_op())) { + group_info->direct_output_var_names.insert( + ValueName(val.defining_op()->operand_source(0))); + } else { + group_info->direct_output_var_names.insert(ValueName(val)); + } + } + return group_info; +} + OpLowererImpl::OpLowererImpl(const Target& target) : target_(target) { name_gene_ = new PrettyNamer(); } -std::vector OpLowererImpl::Lower(const GroupPtr& group, - bool apply_op_schedule, - bool apply_group_schedule, - bool apply_pass) { - VLOG(3) << "Lowering Group : " << group->group_id - << " , Op Pattern : " << group->op_pattern_kind; - group->input_names.clear(); - group->output_names.clear(); - switch (group->op_pattern_kind) { +std::vector OpLowererImpl::Lower( + const OpLoweringGroupPtr& group, + bool apply_op_schedule, + bool apply_group_schedule, + bool apply_pass) { + VLOG(3) << "Lowering Group : " << group->group_id() + << " , Op Pattern : " << group->op_pattern_kind(); + group->mut_input_names().clear(); + group->mut_output_names().clear(); + switch (group->op_pattern_kind()) { case framework::kElementWise: case framework::kBroadcast: case framework::kInjective: @@ -90,26 +181,30 @@ std::vector OpLowererImpl::Lower(const GroupPtr& group, apply_group_schedule, &OpLowererImpl::ReduceScheduleDetermineFunction); case framework::kOutFusible: - LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!"; + PADDLE_THROW(phi::errors::Unimplemented( + "Group 
Pattern Kind kOutFusible Is Not Implemented!")); case framework::kNonFusible: return LowerGroup(group, apply_op_schedule, apply_group_schedule, &OpLowererImpl::NonFusibleScheduleDetermineFunction); default: - LOG(FATAL) << "Group Pattern Kind Is Unknown!"; + PADDLE_THROW( + phi::errors::InvalidArgument("Group Pattern Kind Is Unknown!")); } } -BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(const GroupPtr& group, - bool apply_op_schedule, - bool apply_group_schedule, - bool apply_pass) { +BucketLoweredFuncsWrapper OpLowererImpl::BucketLower( + const OpLoweringGroupPtr& group, + bool apply_op_schedule, + bool apply_group_schedule, + bool apply_pass) { VLOG(4) << "BucketLower Group : \n" << *group; // 1.Do compute, lower and schedule for each op. - auto& ops = group->ops; + const auto& ops = group->ops(); if (ops.size() == 1 && ops[0]->name() == "custom_call") { return {{{ir::Expr(1), LowerCustomCall(group)[0]}}, ir::LoweredFunc()}; } + std::vector group_func_arg_tensors; std::unordered_map<::pir::Value, ir::Tensor> tensor_map; // for some op, it will output more tmp value and regard as @@ -124,6 +219,13 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(const GroupPtr& group, &tensor_map, &tmp_tensor_info); + // =========== OpFusion ============ + + func_bodies = OperationFusion(ops, func_bodies); + const auto& fusion_group_info = GetFusionGroupInfo(func_bodies); + + // =========== CodeGen And Optimizer ================ + // 2.Do group schedule. ir::ModuleExpr mod_expr(func_bodies); ir::IRSchedule ir_sch( @@ -131,17 +233,36 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(const GroupPtr& group, ir_sch.MergeExprs(); std::vector> cond2func_bodies; VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); + + std::unordered_set<::pir::Value> inner_genevalue; + std::unordered_set<::pir::Operation*> ops_set(ops.begin(), ops.end()); + for (auto* op : ops) { + for (size_t i = 0; i < op->num_results(); ++i) { + inner_genevalue.insert(op->result(i)); + } + } + if (apply_group_schedule) { std::unordered_set output_tensor_names; for (auto value : group->GetGroupOutputValues()) { output_tensor_names.insert(ValueName(value)); } + std::shared_ptr group_info = + GetGroupInfo(fusion_group_info, group, tensor_map); std::unique_ptr group_scheduler = - ir::GroupScheduler::Make( - &ir_sch, output_tensor_names, target_, /* is_dy_shape = */ true); + ir::GroupScheduler::Make(&ir_sch, + output_tensor_names, + target_, + /* is_dy_shape = */ true, + group_info); + + VLOG(4) << "Start apply group_scheduler->Schedule()"; group_scheduler->Schedule(); + VLOG(4) << "End apply group_scheduler->Schedule()"; + cond2func_bodies = group_scheduler->GetIRs(); + VLOG(4) << "End group_scheduler->GetIRs"; } else { cond2func_bodies.emplace_back(ir::Expr(true), ir_sch.GetModule().GetExprs()[0]); @@ -157,21 +278,24 @@ BucketLoweredFuncsWrapper OpLowererImpl::BucketLower(const GroupPtr& group, } std::vector group_func_arg_tensors_copy = group_func_arg_tensors; std::vector group_func_args; + std::vector infer_shape_tensor_args; std::vector funcs = PostProcess(group, tensor_map, apply_group_schedule, {scheduled_func_bodies}, &group_func_arg_tensors_copy, - &group_func_args); + &group_func_args, + &infer_shape_tensor_args); CHECK_EQ(funcs.size(), cond2func_bodies.size()); BucketLoweredFuncsWrapper funcs_wrapper; for (int i = 0; i < funcs.size(); ++i) { funcs_wrapper.predicate2funcs.emplace_back(cond2func_bodies[i].first, funcs[i]); } - funcs_wrapper.infer_shape_func = GenerateInferShapeFunc( - group, 
group_func_arg_tensors_copy, group_func_args); + funcs_wrapper.infer_shape_func = + GenerateInferShapeFunc(group, infer_shape_tensor_args, group_func_args); + VLOG(4) << "End This function."; return funcs_wrapper; } @@ -215,7 +339,7 @@ bool OpLowererImpl::DyShapeScheduleDetermineFunction(::pir::Operation* op) { } void OpLowererImpl::LowerOpsForMapExpr( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::vector<::pir::Operation*>& ops, std::vector* group_func_arg_tensors, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) { @@ -250,7 +374,7 @@ void OpLowererImpl::LowerOpsForMapExpr( /* Most of below codes copies from `PostProcess` function */ std::vector OpLowererImpl::LowerMapExpr( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::vector<::pir::Operation*>& ops, bool apply_op_schedule, bool apply_group_schedule, @@ -280,8 +404,10 @@ std::vector OpLowererImpl::LowerMapExpr( for (auto value : group->GetGroupOutputValues()) { output_tensor_names.insert(ValueName(value)); } + + std::shared_ptr group_info; ir::StaticShapeGroupScheduler group_scheduler( - &ir_sch, output_tensor_names, target_); + &ir_sch, output_tensor_names, target_, group_info); group_scheduler.MapExprSchedule(); VLOG(3) << "After group schedule, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); @@ -291,21 +417,23 @@ std::vector OpLowererImpl::LowerMapExpr( // including preparing function args and temporary variables, // applying low-level optimization passes, etc. std::vector group_func_args; + std::vector infer_shape_tensor_args; return PostProcess(group, *tensor_map, apply_op_schedule, {ir_sch.GetModule().GetExprs()[0]}, group_func_arg_tensors, - &group_func_args); + &group_func_args, + &infer_shape_tensor_args); } std::vector OpLowererImpl::LowerGroup( - const GroupPtr& group, + const OpLoweringGroupPtr& group, bool apply_op_schedule, bool apply_group_schedule, ScheduleDetermineFunction schedule_determine_func) { // 1.Do compute, lower and schedule for each op. - auto& ops = group->ops; + const auto& ops = group->ops(); if (ops.size() == 1 && ops[0]->name() == "custom_call") { return LowerCustomCall(group); } @@ -323,40 +451,217 @@ std::vector OpLowererImpl::LowerGroup( &group_func_arg_tensors, &tensor_map); } - std::vector func_bodies = LowerOps(group, - ops, - do_op_schedule, - schedule_determine_func, - &group_func_arg_tensors, - &tensor_map, - &tmp_tensor_info); + std::vector func_bodies = + LowerOps(group, + ops, + do_op_schedule, + &OpLowererImpl::DyShapeScheduleDetermineFunction, + &group_func_arg_tensors, + &tensor_map, + &tmp_tensor_info); + + // func_bodies = TrivialOpFusion(ops, func_bodies); + std::unordered_set<::pir::Value> inner_genevalue; + std::unordered_set<::pir::Operation*> ops_set(ops.begin(), ops.end()); + for (auto* op : ops) { + for (size_t i = 0; i < op->num_results(); ++i) { + inner_genevalue.insert(op->result(i)); + } + } // 2.Do group schedule. 
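  // The block below chooses between the static and the dynamic IRSchedule:
  // any negative extent in group->loop_ranges() is treated as a symbolic
  // dimension. A condensed sketch of that decision (assuming plain integer
  // extents):
  //   bool have_dy_shape = std::any_of(
  //       group->loop_ranges().begin(), group->loop_ranges().end(),
  //       [](auto d) { return d < 0; });
  //   // have_dy_shape == true selects the IRSchedule constructed with
  //   // symbolic-dim support instead of the default static-shape schedule.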
ir::ModuleExpr mod_expr(func_bodies); - ir::IRSchedule ir_sch(mod_expr); - ir_sch.MergeExprs(); - VLOG(3) << "After lower, ir is: \n" << ir_sch.GetModule().GetExprs().at(0); - if (apply_group_schedule) { - DoGroupSchedule(ir_sch, group, tensor_map, tmp_tensor_info); - VLOG(3) << "After group schedule, ir is: \n" - << ir_sch.GetModule().GetExprs().at(0); + std::shared_ptr ir_sch = + std::make_shared(mod_expr); + + auto have_dy_shape = false; + for (auto d : group->loop_ranges()) { + if (d < 0) { + have_dy_shape = true; + } + } + if (have_dy_shape) { + ir_sch = std::make_shared( + mod_expr, -1, false, cinn::utils::ErrorMessageLevel::kGeneral, true); } + ir_sch->MergeExprs(); + VLOG(3) << "After lower, ir is: \n" << ir_sch->GetModule().GetExprs().at(0); + // if (apply_group_schedule) { + DoGroupSchedule(*(ir_sch.get()), group, tensor_map, tmp_tensor_info); + VLOG(3) << "After group schedule, ir is: \n" + << ir_sch->GetModule().GetExprs().at(0); + // } // 3.Do post-processing, // including preparing function args and temporary variables, // applying low-level optimization passes, etc. std::vector group_func_args; + std::vector infer_shape_args; return PostProcess(group, tensor_map, do_op_schedule, - {ir_sch.GetModule().GetExprs().at(0)}, + {ir_sch->GetModule().GetExprs().at(0)}, &group_func_arg_tensors, - &group_func_args); + &group_func_args, + &infer_shape_args); +} + +void OpLowererImpl::BuildBroadcastInfo(const OpLoweringGroupPtr& group, + std::shared_ptr group_info) { + // TODO(phlrain): this is primary verion for loop aligment + // will be update by a new method + auto& align_info = group->mut_alignment_schedule_info(); + + auto& ops = group->ops(); + for (auto op1 : ops) { + auto it = align_info.find(op1); + if (it == align_info.end()) { + continue; + } + if (op1->name() == "cinn_op.generate_shape") { + continue; + } + + if (it->second.size() > 1) { + for (size_t i = 0; i < it->second.size(); ++i) { + } + // TODO(phlran): merge to factor info here + it->second.front().factor_info = it->second.back().factor_info; + it->second.resize(1); + } + + PADDLE_ENFORCE_EQ( + it->second.size(), + 1, + phi::errors::Unimplemented("%s, only suppopt one transform yet", + it->first->name())); + + if (it->second[0].type == ScheduleAlignType::kBroadcast) { + // get broadcast op + auto broadcast_axes = it->second[0].axis_info; + auto output_shape = it->second[0].factor_info; + + phi::DDim in_dim; + + if (it->first->name() == "cinn_op.reshape") { + // TODO(phlrain): deal with reshape in a better way + if (it->first->result(0).use_count() == 1 && + it->first->result(0).first_use().owner()->isa<::pir::YieldOp>()) { + continue; + } + } + + if ((it->first->name() != "cinn_op.reshape") && + (it->first->name() != "cinn_op.broadcast") && + (it->first->num_operands() == 1)) { + in_dim = it->first->operand_source(0) + .type() + .dyn_cast() + .dims(); + } else { + in_dim = it->first->result(0) + .type() + .dyn_cast() + .dims(); + } + + cinn::ir::BroadcastInfo info; + if (in_dim.size() == 1u && in_dim[0] == 1u) { + info.full_broadcast = true; + for (size_t i = 0; i < output_shape.size(); ++i) { + info.broadcast_axes.push_back(i); + info.output_shape.push_back(-1); + info.output_dim_expr.push_back(group->loop_ranges_expr()[i]); + } + } else if (in_dim.size() == broadcast_axes.size()) { + if (in_dim.size() != output_shape.size()) { + info.split_first = true; + + if (broadcast_axes.size() == 1) { + std::vector temp_shape(output_shape.size(), 1); + temp_shape[broadcast_axes[0]] = output_shape[broadcast_axes[0]]; + 
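            // Worked example for this split_first branch (hypothetical
            // shapes): broadcast_axes = {1}, output_shape = {8, 16, 32},
            // a rank-1 input of extent 16:
            //   temp_shape          -> {1, 16, 1}  (only axis 1 keeps its extent)
            //   info.broadcast_axes -> {0, 2}      (axes still to be broadcast)
            //   info.output_shape   -> {8, 32}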
info.split_info.emplace_back(0, temp_shape); + + for (size_t i = 0; i < output_shape.size(); ++i) { + if (i != broadcast_axes[0]) { + info.broadcast_axes.push_back(i); + info.output_shape.push_back(output_shape[i]); + } + } + } else { + throw std::runtime_error("not support multi dim broadcast yet"); + } + } else { + for (size_t i = 0; i < broadcast_axes.size(); ++i) { + if (in_dim[i] < 0 || output_shape[broadcast_axes[i]] < 0) { + continue; + } + if (in_dim[i] != output_shape[broadcast_axes[i]]) { + if (in_dim[i] != 1) { + throw std::runtime_error("Only support 1 - D broadcast "); + } + info.broadcast_axes.push_back(i); + info.output_shape.push_back(output_shape[broadcast_axes[i]]); + } + } + } + } else { + // only deal with broadcast axes + std::set axes_set; + for (size_t i = 0; i < broadcast_axes.size(); ++i) { + axes_set.insert(broadcast_axes[i]); + if (in_dim[broadcast_axes[i]] != 1) { + throw std::runtime_error("Only support 1 - D broadcast "); + } + + info.broadcast_axes.push_back(broadcast_axes[i]); + info.output_shape.push_back(output_shape[broadcast_axes[i]]); + } + } + + for (size_t i = 0; i < it->first->num_operands(); ++i) { + if (!align_info.count(it->first->operand_source(i).defining_op())) { + info.first_broadcast = true; + break; + } + } + + auto op_out = it->first->result(0); + info.op_name = it->first->name(); + + if (op_out.use_count() == 1 && + op_out.first_use().owner()->name() == "cf.yield") { + info.with_constrain = true; + } + + if (erase_reshape.count(op_out.first_use().owner())) { + info.with_constrain = true; + } + + group_info->broadcast_info[ValueName(op_out)] = info; + + for (auto use_it = op_out.use_begin(); use_it != op_out.use_end(); + ++use_it) { + if (use_it->owner()->name() == "cf.yield") { + continue; + } + if (CompatibleInfo::OpKind(*(use_it->owner())) == + framework::kBroadcast) { + if (!info.full_broadcast) { + group_info->broadcast_to_elementwise[ValueName( + use_it->owner()->result(0))] = info; + } + } + } + } else { + throw std::runtime_error("only supportbroadcast type for now"); + } + } } std::vector OpLowererImpl::LowerCustomCall( - const GroupPtr& group) { - auto& ops = group->ops; + const OpLoweringGroupPtr& group) { + const auto& ops = group->ops(); CHECK_EQ(ops.size(), 1); ::pir::Operation* op = ops[0]; std::unordered_map<::pir::Value, ir::Tensor> tensor_map; @@ -401,31 +706,49 @@ std::vector OpLowererImpl::LowerCustomCall( } std::vector OpLowererImpl::PostProcess( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, bool done_op_schedule, std::vector func_bodies, std::vector* group_func_arg_tensors, - std::vector* group_func_args) { + std::vector* group_func_args, + std::vector* infer_shape_arg_tensor) { // 1.Prepare function args - group->input_names.clear(); + group->mut_input_names().clear(); std::unordered_set arg_name_set; for (auto& arg_tensor : *group_func_arg_tensors) { // input data name. - group->input_names.push_back(arg_tensor->name); + group->mut_input_names().push_back(arg_tensor->name); // input args (*group_func_args) .emplace_back(arg_tensor->buffer, ir::Argument::IO::kInput); arg_name_set.insert(arg_tensor->buffer->name); } - group->output_names.clear(); + group->mut_output_names().clear(); + // collect all output tensor. 
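  // The loop below records every group output as an infer-shape argument;
  // when the group carries ShapeOrDataDimExprs for a value, the tensor's
  // static dims are first replaced by symbol::DimExpr-backed ir::Dim entries
  // so the generated infer-shape function can report dynamic extents.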
for (auto op_result : group->GetGroupOutputValues()) { if (tensor_map.count(op_result) == 0) { continue; } auto tensor = tensor_map.at(op_result); + if (group->HasShapeOrDataExprs(op_result)) { + tensor->shape.clear(); + for (size_t i = 0; + i < group->GetShapeOrDataExprs(op_result).shape().size(); + ++i) { + ir::Dim t(tensor->name, + group->GetShapeOrDataExprs(op_result).shape()[i]); + tensor->shape.push_back(t->dim_expr); + } + } + infer_shape_arg_tensor->push_back(tensor); + if ((op_result.defining_op()->name() == "cinn_op.reshape") && + erase_reshape.count(op_result.defining_op())) { + tensor = tensor_map.at(op_result.defining_op()->operand_source(0)); + } + if (arg_name_set.count(tensor->buffer->name) != 0) { continue; } @@ -433,7 +756,7 @@ std::vector OpLowererImpl::PostProcess( // output arg tensors group_func_arg_tensors->push_back(tensor); // output args - group->output_names.push_back(tensor->name); + group->mut_output_names().push_back(tensor->name); (*group_func_args).emplace_back(tensor->buffer, ir::Argument::IO::kOutput); arg_name_set.insert(tensor->buffer->name); } @@ -443,7 +766,7 @@ std::vector OpLowererImpl::PostProcess( for (auto arg : (*group_func_args)) { args_set.insert(arg.name()); } - for (auto& op : group->ops) { + for (const auto& op : group->ops()) { // collect all output tensor. for (auto opresult : op->results()) { if (tensor_map.count(opresult) == 0) { @@ -453,9 +776,9 @@ std::vector OpLowererImpl::PostProcess( if (args_set.count("_" + tensor->name) != 0) { continue; } - group->output_values.push_back(opresult); + group->mut_output_values().push_back(opresult); group_func_arg_tensors->push_back(tensor); - group->output_names.push_back(tensor->name); + group->mut_output_names().push_back(tensor->name); group_func_args->emplace_back(tensor->buffer, ir::Argument::IO::kOutput); } @@ -482,18 +805,18 @@ std::vector OpLowererImpl::PostProcess( int_args_set.insert(symbol_name); group_func_args->emplace_back( ir::_Var_::Make(symbol_name, cinn::common::Int(64))); - group->int_args_map[non_tensor_arg_idx++] = {tensor_arg_idx, - tensor_arg_dim_idx}; - VLOG(4) << "device kernel func's " << non_tensor_arg_idx << " is from " + group->mut_int_args_map()[non_tensor_arg_idx++] = {tensor_arg_idx, + tensor_arg_dim_idx}; + VLOG(4) << "device kernel func's " << symbol_name << " is from " << tensor_arg_idx << ".shape(" << tensor_arg_dim_idx << ")"; } } } - std::vector lowered_funcs; for (ir::Expr func_body : func_bodies) { - optim::EliminateDeadScheduleBlock(&(func_body), group->output_names); + optim::EliminateDeadScheduleBlock(&(func_body), group->output_names()); #ifdef CINN_WITH_CUDA + optim::EliminateCommonGlobalMemoryRead(&(func_body)); optim::OptimizeExprGPU(&(func_body)); #endif @@ -515,7 +838,7 @@ std::vector OpLowererImpl::PostProcess( } std::vector OpLowererImpl::LowerOps( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::vector<::pir::Operation*>& ops, bool apply_op_schedule, ScheduleDetermineFunction schedule_determine_func, @@ -524,20 +847,46 @@ std::vector OpLowererImpl::LowerOps( std::unordered_map* tmp_tensor_info) { auto& strategy = Operator::GetAttrs("CINNStrategy"); std::vector func_bodies; + std::unordered_set<::pir::Value> inner_used_value; + for (auto* op : ops) { + for (size_t i = 0; i < op->num_operands(); ++i) { + inner_used_value.insert(op->operand_source(i)); + } + } + + std::unordered_set<::pir::Operation*> not_used_op; + for (auto* op : ops) { + bool used = false; + for (size_t i = 0; i < op->num_results(); ++i) { + if 
(inner_used_value.count(op->result(i))) { + used = true; + break; + } + } + + if (!used) { + not_used_op.insert(op); + } + } + for (auto* op : ops) { VLOG(4) << "start lowering op:" << op->name(); + std::string cinn_op_name = CompatibleInfo::OpName(*op); + + VLOG(4) << "cinn op name " << cinn_op_name << std::endl; + // 1.Select Op impl std::vector op_func_arg_tensors = CollectInputTensor(group, op, group_func_arg_tensors, tensor_map); VLOG(4) << "input size:" << op_func_arg_tensors.size(); - std::string cinn_op_name = CompatibleInfo::OpName(*op); const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name); std::shared_ptr op_impl = nullptr; if (FLAGS_cinn_bucket_compile) { std::vector out_types; std::vector> out_shapes; CollectOutputInfo(op, &out_types, &out_shapes, group); + CHECK_EQ(out_types.size(), out_shapes.size()); VLOG(4) << "out_types.size(): " << out_types.size(); NodeAttr node_attrs = details::CollectAttrs(*op); @@ -546,7 +895,7 @@ std::vector OpLowererImpl::LowerOps( StrategyFunctionSymbolic strategy = strategy_map[cinn_op]; CHECK(static_cast(strategy)) << " cinn_op_name: " << cinn_op_name - << "has no CINNStrategySymbolic registered."; + << " has no CINNStrategySymbolic registered."; op_impl = OpStrategy::SelectImpl(strategy(node_attrs, op_func_arg_tensors, out_types, @@ -568,13 +917,8 @@ std::vector OpLowererImpl::LowerOps( std::vector funcs = DoOpLower( op_impl, op, tensor_map, tmp_tensor_info, &op_func_arg_tensors); - if (apply_op_schedule && (this->*schedule_determine_func)(op)) { - // 3.Perform the schedule of Op - func_bodies.push_back(DoOpSchedule(op_impl, op_func_arg_tensors, funcs)); - } else { - for (const ir::LoweredFunc& func : funcs) { - func_bodies.push_back(func->body); - } + for (const ir::LoweredFunc& func : funcs) { + func_bodies.push_back(func->body); } } @@ -688,22 +1032,34 @@ ir::Expr OpLowererImpl::DoOpSchedule( ir::Expr OpLowererImpl::DoGroupSchedule( ir::IRSchedule& ir_sch, - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, const std::unordered_map& tmp_tensor_info) { VLOG(3) << "using StaticShapeGroupScheduler to schedule group."; + bool have_dy_shape = false; + for (auto d : group->loop_ranges()) { + if (d < 0) { + have_dy_shape = true; + } + } + + std::shared_ptr group_info = GetGroupInfo(group, tensor_map); + std::unordered_set output_tensor_names; for (auto value : group->GetGroupOutputValues()) { output_tensor_names.insert(ValueName(value)); } std::unique_ptr group_scheduler = - ir::GroupScheduler::Make( - &ir_sch, output_tensor_names, target_, /* is_dy_shape = */ false); + ir::GroupScheduler::Make(&ir_sch, + output_tensor_names, + target_, + /* is_dy_shape = */ true, + group_info); group_scheduler->Schedule(); return ir_sch.GetModule().GetExprs().at(0); } -ir::Tensor OpLowererImpl::GetTensor(const GroupPtr& group, +ir::Tensor OpLowererImpl::GetTensor(const OpLoweringGroupPtr& group, const ::pir::Value& value) { auto type_info = value.type().dyn_cast(); auto dtype = type_info.dtype(); @@ -722,21 +1078,28 @@ ir::Tensor OpLowererImpl::GetTensor(const GroupPtr& group, } } }; + if (FLAGS_cinn_bucket_compile) { std::vector sym_shape; ForEachDimExpr( [&](const auto& sym) { sym_shape.emplace_back(input_id, sym); }); + if (sym_shape.empty()) { + sym_shape.emplace_back(input_id, symbol::DimExpr{1}); + } return lang::CreatePlaceHolder( sym_shape, CompatibleInfo::ConvertIRType(dtype), input_id); } else { - return lang::CreatePlaceHolder(::common::vectorize(type_info.dims()), 
- CompatibleInfo::ConvertIRType(dtype), - input_id); + auto shape = ::common::vectorize(type_info.dims()); + if (shape.empty()) { + shape.push_back(1); + } + return lang::CreatePlaceHolder( + shape, CompatibleInfo::ConvertIRType(dtype), input_id); } } std::vector OpLowererImpl::CollectInputTensor( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const ::pir::Operation* op, std::vector* func_args, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map) { @@ -773,7 +1136,7 @@ std::vector OpLowererImpl::CollectInputTensor( void OpLowererImpl::CollectOutputInfo(::pir::Operation* op, std::vector* out_types, std::vector>* out_shapes, - const GroupPtr& group) { + const OpLoweringGroupPtr& group) { auto op_results = op->results(); for (auto& out_value : op_results) { std::string output_id = ValueName(out_value); @@ -783,6 +1146,9 @@ void OpLowererImpl::CollectOutputInfo(::pir::Operation* op, out_types->push_back(CompatibleInfo::ConvertIRType(type_info.dtype())); auto out_shape = ::common::vectorize(type_info.dims()); + if (out_shape.empty()) { + out_shape.push_back(1); + } out_shapes->push_back(std::move(out_shape)); } } @@ -791,7 +1157,7 @@ void OpLowererImpl::CollectOutputInfo( ::pir::Operation* op, std::vector* out_types, std::vector>* out_shapes, - const GroupPtr& group) { + const OpLoweringGroupPtr& group) { auto op_results = op->results(); for (auto& out_value : op_results) { std::string output_id = ValueName(out_value); @@ -819,6 +1185,9 @@ void OpLowererImpl::CollectOutputInfo( std::vector sym_shape; ForEachDimExpr( [&](const auto& sym) { sym_shape.emplace_back(output_id, sym); }); + if (sym_shape.empty()) { + sym_shape.emplace_back(output_id, symbol::DimExpr{1}); + } out_shapes->emplace_back(std::move(sym_shape)); } } @@ -860,7 +1229,7 @@ bool OpLowererImpl::IsInTensorMap( } ir::LoweredFunc OpLowererImpl::GenerateInferShapeFunc( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::vector group_func_arg_tensors, const std::vector group_func_args) { // CHECK_EQ(group_func_arg_tensors.size(), group_func_args.size()); @@ -868,9 +1237,6 @@ ir::LoweredFunc OpLowererImpl::GenerateInferShapeFunc( int output_tensor_idx = 0; for (int tensor_arg_idx = 0; tensor_arg_idx < group_func_arg_tensors.size(); ++tensor_arg_idx) { - if (group_func_args[tensor_arg_idx].is_input()) { - continue; - } auto tensor_dim = group_func_arg_tensors[tensor_arg_idx]->sym_shape; int tensor_dim_size = tensor_dim.size(); auto tensor_shape = group_func_arg_tensors[tensor_arg_idx]->shape; diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h index fff73071becb9..e8c2d468347af 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_impl.h +++ b/paddle/cinn/hlir/framework/pir/op_lowering_impl.h @@ -21,7 +21,8 @@ #include "paddle/cinn/hlir/framework/instruction.h" #include "paddle/cinn/hlir/framework/op_lowering_impl_base.h" #include "paddle/cinn/hlir/framework/op_strategy.h" -#include "paddle/cinn/hlir/framework/pir/group.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h" +#include "paddle/cinn/hlir/framework/pir/trivial_op_impl.h" #include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" #include "paddle/cinn/ir/lowered_func.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" @@ -30,9 +31,9 @@ #include "paddle/pir/include/core/operation.h" // Fusion Op lowering, there are four kinds of lowering function: -// Elementwise/Broadcast/Injective,Reduce,OutEWiseFusable,NonFusible. 
+// Elementwise/Broadcast/Injective,Reduce,OutEWiseFusible,NonFusible. // Elementwise/Broadcast/Injective Ops is with same schedule. -// Reduce,OutEWiseFusable,NonFusible are using different schedule. +// Reduce,OutEWiseFusible,NonFusible are using different schedule. namespace cinn { namespace hlir { @@ -40,14 +41,27 @@ namespace framework { namespace pir { class PrettyNamer; -using GroupPtr = std::shared_ptr; +using OpLoweringGroupPtr = std::shared_ptr; using cinn::common::Target; class OpLowererImpl; typedef bool (OpLowererImpl::*ScheduleDetermineFunction)(::pir::Operation*); -class OpLowererImpl : public OpLowererImplBase { +struct GroupInfo { + std::vector data_space; + std::vector reduce_axis; + std::set reduce_var_names; + std::set shared_var_names; + std::set direct_output_var_names; + std::vector broadcast_output_names; + + std::unordered_map broadcast_info; + std::unordered_map + broadcast_to_elementwise; +}; + +class OpLowererImpl : public OpLowererImplBase { public: explicit OpLowererImpl(const Target&); @@ -58,7 +72,7 @@ class OpLowererImpl : public OpLowererImplBase { * @param apply_group_schedule Whether to schedule at group level. * @return The lowered funcs. */ - std::vector Lower(const GroupPtr& group, + std::vector Lower(const OpLoweringGroupPtr& group, bool apply_op_schedule = true, bool apply_group_schedule = true, bool apply_pass = true); @@ -70,7 +84,7 @@ class OpLowererImpl : public OpLowererImplBase { * @param apply_group_schedule Whether to schedule at group level. * @return The lowered funcs. */ - BucketLoweredFuncsWrapper BucketLower(const GroupPtr& group, + BucketLoweredFuncsWrapper BucketLower(const OpLoweringGroupPtr& group, bool apply_op_schedule = false, bool apply_group_schedule = true, bool apply_pass = true); @@ -88,7 +102,7 @@ class OpLowererImpl : public OpLowererImplBase { * @return The lowered funcs. */ std::vector LowerGroup( - const GroupPtr& group, + const OpLoweringGroupPtr& group, bool apply_op_schedule, bool apply_group_schedule, ScheduleDetermineFunction schedule_determine_func); @@ -98,7 +112,7 @@ class OpLowererImpl : public OpLowererImplBase { * @param group The group to be lowered. * @return The lowered funcs. */ - std::vector LowerCustomCall(const GroupPtr& group); + std::vector LowerCustomCall(const OpLoweringGroupPtr& group); /** * @brief Post processing, including preparing function args and temporary @@ -113,12 +127,13 @@ class OpLowererImpl : public OpLowererImplBase { * @return The lowered funcs after the post processing. */ std::vector PostProcess( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, bool done_op_schedule, std::vector func_bodies, std::vector* group_func_arg_tensors, - std::vector* group_func_args); + std::vector* group_func_args, + std::vector* infer_shape_arg_tensor); /** * @brief Lower an Op set to CINN IR. @@ -130,7 +145,7 @@ class OpLowererImpl : public OpLowererImplBase { * @return The lowered func bodies of Op set. */ void LowerOpsForMapExpr( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::vector<::pir::Operation*>& ops, std::vector* group_func_arg_tensors, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map); @@ -146,7 +161,7 @@ class OpLowererImpl : public OpLowererImplBase { * @return The lowered funcs after the post processing. 
*/ std::vector LowerMapExpr( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::vector<::pir::Operation*>& ops, bool apply_op_schedule, bool apply_group_schedule, @@ -166,7 +181,7 @@ class OpLowererImpl : public OpLowererImplBase { * @return The lowered func bodies of Op set. */ std::vector LowerOps( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::vector<::pir::Operation*>& ops, bool apply_op_schedule, ScheduleDetermineFunction schedule_determine_func, @@ -211,7 +226,7 @@ class OpLowererImpl : public OpLowererImplBase { */ ir::Expr DoGroupSchedule( ir::IRSchedule& ir_sch, // NOLINT - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map, const std::unordered_map& tmp_tensor_info); @@ -223,7 +238,7 @@ class OpLowererImpl : public OpLowererImplBase { * @return The lowered func to infer output tensor's shape. */ ir::LoweredFunc GenerateInferShapeFunc( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const std::vector group_func_arg_tensors, const std::vector group_func_args); @@ -236,24 +251,34 @@ class OpLowererImpl : public OpLowererImplBase { private: std::vector CollectInputTensor( - const GroupPtr& group, + const OpLoweringGroupPtr& group, const ::pir::Operation* op, std::vector* func_args, std::unordered_map<::pir::Value, ir::Tensor>* tensor_map); - ir::Tensor GetTensor(const GroupPtr& group, const ::pir::Value& value); - ir::Tensor GetTensorSymbolic(const GroupPtr& group, + ir::Tensor GetTensor(const OpLoweringGroupPtr& group, + const ::pir::Value& value); + ir::Tensor GetTensorSymbolic(const OpLoweringGroupPtr& group, const ::pir::Value& value); + std::shared_ptr GetGroupInfo( + const OpLoweringGroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map); + + std::shared_ptr GetGroupInfo( + const FusionGroupInfo& fusion_group_info, + const OpLoweringGroupPtr& group, + const std::unordered_map<::pir::Value, ir::Tensor>& tensor_map); + void CollectOutputInfo(::pir::Operation* op, std::vector* out_types, std::vector>* out_shapes, - const GroupPtr& group); + const OpLoweringGroupPtr& group); void CollectOutputInfo(::pir::Operation* op, std::vector* out_types, std::vector>* out_shapes, - const GroupPtr& group); + const OpLoweringGroupPtr& group); std::string ValueName(::pir::Value value); @@ -267,9 +292,14 @@ class OpLowererImpl : public OpLowererImplBase { common::Type GetTensorDtype(const ::pir::Value& value); + void BuildBroadcastInfo(const OpLoweringGroupPtr& group, + std::shared_ptr group_info); + Target target_; PrettyNamer* name_gene_; + + std::unordered_set<::pir::Operation*> erase_reshape; }; } // namespace pir diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_util.cc b/paddle/cinn/hlir/framework/pir/op_lowering_util.cc index 038908ff1ab99..56c335f6b63ca 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_util.cc +++ b/paddle/cinn/hlir/framework/pir/op_lowering_util.cc @@ -601,8 +601,8 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, // NOLINT } lane *= inshape[axes[index]]; if (index == 0 && lane <= max_num_threads) { - LOG(FATAL) - << "Error! lane is less equal than max_num_threads, Please check!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Error! 
lane is less equal than max_num_threads, Please check!")); } if (lane >= max_num_threads / 2) { if (lane <= max_num_threads) { @@ -667,7 +667,7 @@ void LoopAssignReduceWithLast(ir::IRSchedule& ir_sch, // NOLINT ir_sch.Fuse(block_name, {axes[index + 1], axes[index + 1] + 1}); } LoopOrderAssignReduce(ir_sch, block_name, first_axes, target, true); - // fuse axis before reduce to bind blockidx. + // fuse axis before reduce to bind block idx. for (int idx = 0; idx < static_cast(inshape.size() - axes.size()) - 1; ++idx) { ir_sch.Fuse(block_name, {0, 1}); @@ -713,7 +713,7 @@ void LoopAssignReduceWithoutLast(ir::IRSchedule& ir_sch, // NOLINT return left + std::to_string(right) + " "; }); - VLOG(4) << "LoopAssignReduceWithoutLast: THe input shape=[" + VLOG(4) << "LoopAssignReduceWithoutLast: The input shape=[" << cinn::utils::Join(inshape, ", ") << "], first step reduce shape=[" << cinn::utils::Join(shape, ", ") << "]" << ", axes=[" << cinn::utils::Join(axes, ", ") << "], tail=" << tail; @@ -727,7 +727,7 @@ void LoopAssignReduceWithoutLast(ir::IRSchedule& ir_sch, // NOLINT // the loop size at axis is 1, need remove axes_shift_num[j] = -1; } else if (axes[j] > idx) { - // the axies value need left shift + // the axes value need left shift axes_shift_num[j]++; } } @@ -1008,7 +1008,8 @@ void MergeReduceToReduce( n_loops.size() - 1); } } else { - LOG(FATAL) << "not support this type fusion!"; + PADDLE_THROW( + phi::errors::InvalidArgument("not support this type fusion!")); } } } else { @@ -1112,7 +1113,8 @@ void MergeReduceToReduce( ir_sch.SimpleComputeAt(block, loops.back()); } } else { - LOG(FATAL) << "Error! Unkown Reduce Type, Please Check!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Error! Unkown Reduce Type, Please Check!")); } } } @@ -1506,7 +1508,7 @@ void LoopAssignReduce( // copy loop info form rloops. copy_loop_info(nloops, rloops); } else { - LOG(FATAL) << "Error! Unkown Reduce Type!"; + PADDLE_THROW(phi::errors::InvalidArgument("Error! Unkown Reduce Type!")); } } } diff --git a/paddle/cinn/hlir/framework/pir/op_lowering_util.h b/paddle/cinn/hlir/framework/pir/op_lowering_util.h index 201cf7b556f2c..c242ec78fd9ab 100644 --- a/paddle/cinn/hlir/framework/pir/op_lowering_util.h +++ b/paddle/cinn/hlir/framework/pir/op_lowering_util.h @@ -18,6 +18,7 @@ #include #include "paddle/cinn/hlir/framework/pir/group.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_group.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/tensor.h" @@ -26,6 +27,7 @@ namespace hlir { namespace framework { namespace pir { using GroupPtr = std::shared_ptr; +using OpLoweringGroupPtr = std::shared_ptr; class PrettyNamer; diff --git a/paddle/cinn/hlir/framework/pir/op_mapper.h b/paddle/cinn/hlir/framework/pir/op_mapper.h index 73e8d9581e4b0..87053a8c02d53 100644 --- a/paddle/cinn/hlir/framework/pir/op_mapper.h +++ b/paddle/cinn/hlir/framework/pir/op_mapper.h @@ -13,9 +13,12 @@ // limitations under the License. #pragma once + +#include #include #include #include + #include "paddle/cinn/utils/type_defs.h" #include "paddle/pir/include/core/operation.h" diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc new file mode 100644 index 0000000000000..23cad86d604f5 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.cc @@ -0,0 +1,922 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/pir/trivial_op_impl.h" + +#include + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/compile_error.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" +#include "paddle/cinn/ir/dim.h" +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/schedule_block_dce.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +namespace trivial_fusion_detail { + +TrivialOp::TrivialOp(const ir::Expr& origin_func_body) { + func_body = ir::ir_utils::IRCopy(origin_func_body); +} + +TrivialOp::TrivialOp(const TrivialOp& trivial_op) { + func_body = trivial_op.GetFuncBody(); +} + +void TrivialOp::_SetFuncBody(ir::Expr new_body) { func_body = new_body; } + +ir::Expr* TrivialOp::_GetFuncBodyPointer() { return &func_body; } + +ir::Expr TrivialOp::GetFuncBody() const { return func_body; } + +ReduceOp::ReduceOp(const ir::Expr& origin_func_body) { + func_body = ir::ir_utils::IRCopy(origin_func_body); +} + +ReduceOp::ReduceOp(const ReduceOp& reduce_op) { + func_body = reduce_op.GetFuncBody(); +} + +void ReduceOp::_SetFuncBody(ir::Expr new_body) { func_body = new_body; } + +ir::Expr ReduceOp::GetFuncBody() const { return func_body; } + +ir::Expr* ReduceOp::_GetFuncBodyPointer() { return &func_body; } + +using FusibleOp = std::variant; + +ir::Expr _GetRootExpr(const FusibleOp& op) { + return std::visit([](auto&& arg) { return arg.GetFuncBody(); }, op); +} + +void _SetFuncBody(FusibleOp& op, ir::Expr new_body) { // NOLINT + std::visit([&](auto&& arg) { arg._SetFuncBody(new_body); }, op); +} + +ir::Expr GetComputeBody(const FusibleOp& op) { + struct Visitor { + ir::Expr operator()(const ReduceOp& op) { + const auto& compute_realize = + (ExprSetFinderUtils::ChildScheduleBlockRealizes * + ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit) + .GetSingle(_GetRootExpr(op)); + const auto& compute_body = + (ExprSetFinderUtils::ChildStores * ExprSetFinderUtils::Store2Value) + .GetSingle(compute_realize); + return ExprTransformerUtils::SubstitudeByScheduleBlockRealize( + compute_realize)(compute_body); + } + ir::Expr operator()(const TrivialOp& op) { + const auto& compute_realize = + (ExprSetFinderUtils::ChildScheduleBlockRealizes) + .GetSingle(_GetRootExpr(op)); + const auto& compute_body = + (ExprSetFinderUtils::ChildStores * 
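             // ExprSetFinderUtils finders compose with operator*: each stage
             // maps the expression set produced by the previous one (e.g.
             // ChildScheduleBlockRealizes * ChildStores * Store2Value walks
             // from a func body down to the RHS of its store), and
             // GetSingle() retrieves the single match.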
ExprSetFinderUtils::Store2Value) + .GetSingle(compute_realize); + return ExprTransformerUtils::SubstitudeByScheduleBlockRealize( + compute_realize)(compute_body); + } + }; + VLOG(4) << "GetComputeBody"; + return std::visit(Visitor(), op); +} + +ir::Tensor GetOutputTensor(const FusibleOp& op) { + struct Visitor { + ir::Tensor operator()(const ReduceOp& op) { + const auto& compute_body = + (ExprSetFinderUtils::ChildScheduleBlockRealizes * + ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit * + ExprSetFinderUtils::ChildStores) + .GetSingle(_GetRootExpr(op)); + return compute_body.As()->tensor.as_tensor_ref(); + } + ir::Tensor operator()(const TrivialOp& op) { + const auto& compute_body = + (ExprSetFinderUtils::ChildScheduleBlockRealizes * + ExprSetFinderUtils::ChildStores) + .GetSingle(_GetRootExpr(op)); + return compute_body.As()->tensor.as_tensor_ref(); + } + }; + VLOG(4) << "GetOutputTensor"; + return std::visit(Visitor(), op); +} + +std::vector AppendBound(const std::vector vars, + const ir::Expr& root) { + return ExprSetFinderUtils::MapVector( + vars, [&](const auto& v) -> ir::Var { + VLOG(4) << "AppendBound for " << v << ", lower: " + << (ExprSetFinderUtils::ChildFors * + ExprSetFinderUtils::IsForIterVar(v) * + ExprSetFinderUtils::For2Min) + .GetSingle(root) + << ", upper: " + << (ExprSetFinderUtils::ChildFors * + ExprSetFinderUtils::IsForIterVar(v) * + ExprSetFinderUtils::For2Max) + .GetSingle(root); + return ir::Var( + (ExprSetFinderUtils::ChildFors * + ExprSetFinderUtils::IsForIterVar(v) * ExprSetFinderUtils::For2Min) + .GetSingle(root), + (ExprSetFinderUtils::ChildFors * + ExprSetFinderUtils::IsForIterVar(v) * ExprSetFinderUtils::For2Max) + .GetSingle(root), + v->name, + v->is_reduce_axis); + }); +} + +std::vector GetOutputIters(const FusibleOp& op) { + struct Visitor { + std::vector operator()(const ReduceOp& op) { + ir::Expr init_block_realize = + (ExprSetFinderUtils::ChildScheduleBlockRealizes * + ExprSetFinderUtils::ScheduleBlockRealizeIsInit) + .GetSingle(_GetRootExpr(op)); + const std::vector& outer_iter_expr = + init_block_realize.As()->iter_values; + return trivial_fusion_detail::ComposeUtils::ExprVec2VarVec( + outer_iter_expr); + } + std::vector operator()(const TrivialOp& op) { + const auto& compute_realize = + (ExprSetFinderUtils::ChildScheduleBlockRealizes) + .GetSingle(_GetRootExpr(op)); + const std::vector& outer_iter_expr = + compute_realize.As()->iter_values; + return trivial_fusion_detail::ComposeUtils::ExprVec2VarVec( + outer_iter_expr); + } + }; + VLOG(4) << "GetOutputIters"; + return AppendBound(std::visit(Visitor(), op), _GetRootExpr(op)); +} + +std::vector GetReduceIters(const ReduceOp& op) { + auto GetUnorderedAllIterVars = [](const ReduceOp& op) { + ir::Expr compute_schedule_block_realize = + (ExprSetFinderUtils::ChildScheduleBlockRealizes * + ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit) + .GetSingle(_GetRootExpr(op)); + + const std::vector& all_iter_expr = + compute_schedule_block_realize.As() + ->iter_values; + return ComposeUtils::ExprVec2VarVec(all_iter_expr); + }; + + // Iter Vars not appearing in outer_iter_vars are pushed into + // reduce_iter_vars + std::vector all_iter_vars = GetUnorderedAllIterVars(op); + std::vector outer_iter_vars = GetOutputIters(op); + std::vector reduce_iter_vars; + + for (auto& iter_var : all_iter_vars) { + if (!(std::find(outer_iter_vars.begin(), outer_iter_vars.end(), iter_var) != + outer_iter_vars.end())) { + iter_var->is_reduce_axis = true; + reduce_iter_vars.push_back(iter_var); + } + } + VLOG(4) << 
"GetReduceIters"; + return AppendBound(reduce_iter_vars, _GetRootExpr(op)); +} + +ir::Expr GetInitExpr(const ReduceOp& op) { + const auto result = + (ExprSetFinderUtils::ChildScheduleBlockRealizes * + ExprSetFinderUtils::ScheduleBlockRealizeIsInit * + ExprSetFinderUtils::ChildStores * ExprSetFinderUtils::Store2Value) + .GetSingle(op.GetFuncBody()); + VLOG(4) << "GetInitExpr: " << result; + return result; +} + +ir::Expr* _GetFuncBodyPointer(FusibleOp op) { + return std::visit([&](auto&& arg) { return arg._GetFuncBodyPointer(); }, op); +} + +ir::Expr CopyReduceBody(const FusibleOp& downstream, const ReduceOp& upstream) { + struct Visitor { + ir::Expr operator()(const ReduceOp& op) { + return ir::ir_utils::IRCopy(op.GetFuncBody()); + } + ir::Expr operator()(const TrivialOp& op) { + PADDLE_THROW("TrivialOp cannot be copied."); + } + }; + return std::visit(Visitor(), downstream); +} + +ir::Expr CreateReduceExpr( + const std::vector& output_iters, + const std::vector& reduce_iters, + const ir::Expr& init_body, // relay on output_iters + const ir::Expr& reduce_body, // relay on output_iters + reduce_iters + const ir::Tensor& new_write_tensor, + const ir::Tensor& origin_write_tensor) { + VLOG(4) << "CreateReduceExpr Start."; + const std::vector indice_expr = + std::vector(output_iters.begin(), output_iters.end()); + auto new_init_tensor = ir::Tensor(new_write_tensor->name + "__reduce_init", + new_write_tensor->type(), + new_write_tensor->shape, + new_write_tensor->domain, + new_write_tensor->operation, + reduce_iters); + new_init_tensor->WithBuffer(); + + const auto& init_schedule_block = + (ExprTransformerUtils::WrapStoreTransformer(new_init_tensor, + indice_expr) * + ExprTransformerUtils::WrapScheduleRealizer( + output_iters, new_init_tensor->name))(init_body); + + const auto& reduce_schedule_block = + (ExprTransformerUtils::ChangeTensorLoadTransformer( + origin_write_tensor, new_write_tensor(indice_expr)) * + ExprTransformerUtils::WrapStoreTransformer(new_write_tensor, + indice_expr) * + ExprTransformerUtils::WrapScheduleRealizer( + ComposeUtils::ConcatVector(output_iters, reduce_iters), + new_write_tensor->name) * + ExprTransformerUtils::WrapForsTransformer(reduce_iters))(reduce_body); + + const auto& gather_body = ir::Block::Make( + std::vector({init_schedule_block, reduce_schedule_block})); + return ir::Block::Make( + {(ExprTransformerUtils::WrapForsTransformer(output_iters) * + ExprTransformerUtils::WrapScheduleRealizer({}, "root"))(gather_body)}); +} + +ir::Expr CreateTrivialExpr(const std::vector& output_iters, + const ir::Expr& function_body, + const ir::Tensor& new_write_tensor) { + const auto& RemoveReduceAxisFromVar = + [](const std::vector& vars) -> std::vector { + std::vector result; + for (auto& var : vars) { + auto new_var = ir::ir_utils::IRCopy(var).as_var_ref(); + new_var->is_reduce_axis = false; + result.push_back(new_var); + } + return result; + }; + auto trivial_iters = RemoveReduceAxisFromVar(output_iters); + const std::vector indice_expr = + std::vector(trivial_iters.begin(), trivial_iters.end()); + const auto& compute_body_schedule_block = + (ExprTransformerUtils::WrapStoreTransformer(new_write_tensor, + indice_expr) * + ExprTransformerUtils::WrapScheduleRealizer( + trivial_iters, new_write_tensor->name))(function_body); + return ir::Block::Make( + {(ExprTransformerUtils::WrapForsTransformer(trivial_iters) * + ExprTransformerUtils::WrapScheduleRealizer({}, "root"))( + ir::Block::Make({compute_body_schedule_block}))}); +} + +ir::Expr CreateExprWithNewComputeBody(const 
FusibleOp& fusible_op, + const ir::Expr& new_compute_body) { + struct Visitor { + ir::Expr operator()(const ReduceOp& op) { + return CreateReduceExpr(GetOutputIters(op), + GetReduceIters(op), + GetInitExpr(op), + compute_body_, + GetOutputTensor(op), + GetOutputTensor(op)); + } + ir::Expr operator()(const TrivialOp& op) { + return CreateTrivialExpr( + GetOutputIters(op), compute_body_, GetOutputTensor(op)); + } + + ir::Expr compute_body_; + explicit Visitor(ir::Expr compute_body) { compute_body_ = compute_body; } + }; + VLOG(4) << "CreateExprWithNewComputeBody"; + return std::visit(Visitor(new_compute_body), fusible_op); +} + +FusionNode::FusionNode(FusibleOp fusible_op) : fusible_op(fusible_op) {} + +std::string FusionNode::GetTensorCounter() { + static int i = 0; + return std::to_string(i++); +} + +void FusionNode::replace_topo_structure_of_fused_nodes( + FusionNode* fused_up_node, FusionNode* fused_down_node) { + upstream.insert(fused_up_node->upstream.begin(), + fused_up_node->upstream.end()); + upstream.insert(fused_down_node->upstream.begin(), + fused_down_node->upstream.end()); + upstream.erase(fused_up_node); + + downstream.insert(fused_up_node->downstream.begin(), + fused_up_node->downstream.end()); + downstream.insert(fused_down_node->downstream.begin(), + fused_down_node->downstream.end()); + downstream.erase(fused_down_node); + + expr_related_op = fused_down_node->expr_related_op; + + for (const auto& pair_data : upstream) { + FusionNode* upstream_node = pair_data.first; + ::pir::Value related_value = pair_data.second; + if (upstream_node->downstream.find(fused_up_node) != + upstream_node->downstream.end()) { + upstream_node->downstream.erase(fused_up_node); + } + if (upstream_node->downstream.find(fused_down_node) != + upstream_node->downstream.end()) { + upstream_node->downstream.erase(fused_down_node); + } + upstream_node->downstream[this] = related_value; + } + + for (const auto& pair_data : downstream) { + FusionNode* downstream_node = pair_data.first; + ::pir::Value related_value = pair_data.second; + if (downstream_node->upstream.find(fused_up_node) != + downstream_node->upstream.end()) { + downstream_node->upstream.erase(fused_up_node); + } + if (downstream_node->upstream.find(fused_down_node) != + downstream_node->upstream.end()) { + downstream_node->upstream.erase(fused_down_node); + } + downstream_node->upstream[this] = related_value; + } +} + +bool FusionNode::IsTrivial() const { + return std::holds_alternative(fusible_op); +} + +bool CheckAllLoopRangeEq(ReduceOp reduce_upper, TrivialOp trivial_down) {} + +std::vector FusionGraph::TransformReduceLoopRange( + const ReduceOp& upstream, FusibleOp* downstream) { + // downstream will be mutated by this transform. 
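  // What follows implements the loop-range transform sketched here
  // (pseudo-code, names as used below):
  //   for each Load of upstream's output inside downstream's compute body:
  //     new_tensor = fresh buffer shaped by downstream's (non-fake) output iters
  //     new_reduce = CreateReduceExpr(downstream output iters,
  //                                   upstream reduce iters, upstream init,
  //                                   upstream body with iters substituted,
  //                                   new_tensor, upstream output tensor)
  //     replace the Load with new_tensor(downstream output iters)
  //   finally rebuild downstream's func body around the modified compute body,
  //   so each upstream reduction is re-emitted under downstream's loop nest.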
+  VLOG(4) << "RRTransform begin";
+  VLOG(4) << "RRTransform Upstream is \n" << _GetRootExpr(upstream);
+  VLOG(4) << "RRTransform Downstream is \n" << _GetRootExpr(*downstream);
+  ir::Expr modified_downstream_compute_body = GetComputeBody(*downstream);
+  const auto& load_upstream_expr = ComposeUtils::GetEachTensorLoadExpr(
+      modified_downstream_compute_body, GetOutputTensor(upstream));
+  std::vector<FusibleOp> results;
+  ir::Tensor downstream_output_tensor = GetOutputTensor(*downstream);
+
+  bool is_trivial_downstream = std::holds_alternative<TrivialOp>(*downstream);
+
+  const auto create_new_tensor = [&](const ir::Tensor& downstream_load_tensor) {
+    VLOG(4) << "Create New Tensor Start";
+    ir::Tensor result = ir::Tensor(
+        downstream_load_tensor->name + "_" + FusionNode::GetTensorCounter(),
+        downstream_load_tensor->type(),
+        is_trivial_downstream
+            ? FilterWithFakeReduceIter(downstream_output_tensor->shape)
+            : downstream_output_tensor->shape,
+        is_trivial_downstream
+            ? FilterWithFakeReduceIter(downstream_output_tensor->domain)
+            : downstream_output_tensor->domain,
+        GetOutputTensor(upstream)->operation,
+        GetReduceIters(upstream));
+    result->WithBuffer();
+    VLOG(4) << "Create New Tensor Result: " << result;
+    return result;
+  };
+
+  for (const auto& load_tensor : load_upstream_expr) {
+    const auto& new_tensor =
+        create_new_tensor(load_tensor.As<ir::Load>()->tensor.as_tensor_ref());
+    ir::Expr new_reduce = CreateReduceExpr(
+        is_trivial_downstream
+            ? FilterWithFakeReduceIter(GetOutputIters(*downstream))
+            : GetOutputIters(*downstream),
+        GetReduceIters(upstream),
+        GetInitExpr(upstream),
+        ComposeUtils::CopyedReplaceExpr(GetComputeBody(upstream),
+                                        GetOutputIters(upstream),
+                                        load_tensor.As<ir::Load>()->indices),
+        new_tensor,
+        GetOutputTensor(upstream));
+    results.emplace_back(ReduceOp(new_reduce));
+    ExprTransformerUtils::ReplaceTarget(
+        &modified_downstream_compute_body,
+        load_tensor,
+        new_tensor(ComposeUtils::VarVec2ExprVec(
+            is_trivial_downstream
+                ?
FilterWithFakeReduceIter(GetOutputIters(*downstream)) + : GetOutputIters(*downstream)))); + } + _SetFuncBody(*downstream, + CreateExprWithNewComputeBody(*downstream, + modified_downstream_compute_body)); + VLOG(4) << "RRTransform After Replace Downstream Load: \n" + << _GetRootExpr(*downstream); + return results; +} + +FusibleOp FusionGraph::TrivialFusion(FusionNode* upstream, + FusionNode* downstream) { + CHECK(upstream->IsTrivial()); + if (downstream->IsTrivial()) { + return TrivalxOther_Fusion(std::get(upstream->fusible_op), + std::get(downstream->fusible_op)); + } else { + return TrivalxOther_Fusion(std::get(upstream->fusible_op), + std::get(downstream->fusible_op)); + } +} + +FusibleOp FusionGraph::SinkTrivialLoopAlign(TrivialOp trivial_op, + ReduceOp reduce_op) { + VLOG(4) << "SinkTrivialLoopAlign"; + ir::Expr new_trivial_body = ir::ir_utils::IRCopy(trivial_op.GetFuncBody()); + std::vector all_out_iter_vars = GetOutputIters(trivial_op); + std::vector non_reduce_iter_vars = + FilterWithFakeReduceIter(all_out_iter_vars); + std::vector fake_reduce_iter_vars; + for (const auto& idx : fake_reduce_iter_idx_) { + fake_reduce_iter_vars.emplace_back( + all_out_iter_vars.at(static_cast(idx))); + } + + VLOG(4) << "all_out_iter_vars: " + << cinn::utils::Join(all_out_iter_vars, ", "); + VLOG(4) << "non_reduce_iter_vars: " + << cinn::utils::Join(non_reduce_iter_vars, ", "); + VLOG(4) << "fake_reduce_iter_vars: " + << cinn::utils::Join(fake_reduce_iter_vars, ", "); + + ir::Expr trivial_last_for = + (ExprSetFinderUtils::ChildFors * + ExprSetFinderUtils::IsForIterVar(all_out_iter_vars.back())) + .GetSingle(new_trivial_body); + ir::Expr new_for_body = trivial_last_for.As()->body; + + const auto ExpandIterVars = [&]() { + std::vector result = + ComposeUtils::ConcatVector(non_reduce_iter_vars, fake_reduce_iter_vars); + auto upstream_reduce_iters = GetReduceIters(reduce_op); + if (fake_reduce_iter_vars.size() != upstream_reduce_iters.size()) { + result.insert(result.end(), + upstream_reduce_iters.begin(), + upstream_reduce_iters.end()); + } + VLOG(4) << "ExpandIterVars: " << cinn::utils::Join(result, ", "); + return result; + }; + + ir::Expr new_schedule_realizer = + (ExprTransformerUtils::WrapForsTransformer(ExpandIterVars()) * + ExprTransformerUtils::WrapScheduleRealizer({}, "root"))(new_for_body); + + VLOG(4) << "new_schedule_realizer\n" << new_schedule_realizer; + return TrivialOp(new_schedule_realizer); +} + +std::vector FusionGraph::ReduceTransformRecursive( + FusibleOp root_op, FusionNode* fusion_tree) { + VLOG(4) << "ReduceTransformRecursive: " << *_GetFuncBodyPointer(root_op); + std::vector result; + for (auto& pair : fusion_tree->upstream) { + auto transformed_nodes = TransformReduceLoopRange( + std::get(pair.first->fusible_op), &root_op); + for (auto& node : transformed_nodes) { + auto child_flatten = ReduceTransformRecursive(node, pair.first); + result.insert(result.end(), child_flatten.begin(), child_flatten.end()); + } + } + VLOG(4) << "Before push_back, is trivial_op: " + << std::holds_alternative(root_op); + result.push_back( + std::holds_alternative(root_op) + ? 
SinkTrivialLoopAlign( + std::get(root_op), + std::get( + fusion_tree->upstream.begin()->first->fusible_op)) + : root_op); + VLOG(4) << "After push_back."; + return result; +} + +std::vector FusionGraph::ReduceTransform(FusionNode* downstream) { + if (downstream->IsTrivial() && downstream->upstream.empty()) { + return {downstream->fusible_op}; + } + auto reduces = ReduceTransformRecursive(downstream->fusible_op, downstream); + return reduces; +} + +FusibleOp CreateFusibleOp(ir::Expr compute_body, OpPatternKind op_pattern) { + if (IsTrivialKind(op_pattern)) { + return TrivialOp(compute_body); + } else { + return ReduceOp(compute_body); + } +} + +template +std::vector FilterVector(const std::vector& ops, const F& f) { + std::vector res; + for (const auto& op : ops) { + if (f(op)) { + res.push_back(op); + } + } + return res; +} + +FusionGraph::FusionGraph( + const cinn::frontend::group_cluster::PatternNodePtr& pattern_node, + const std::unordered_map<::pir::Operation*, ir::Expr>& op_expr_map) { + VLOG(4) << "CreateFusionGraph"; + + std::vector<::pir::Operation*> ops = pattern_node->GetOps(); + std::vector op_compute_bodies = std::vector(); + std::transform(ops.begin(), + ops.end(), + std::back_inserter(op_compute_bodies), + [&](::pir::Operation* op) { return op_expr_map.at(op); }); + + if (pattern_node->IsReduceTrivial()) { + fake_reduce_iter_idx_ = + std::get( + pattern_node->stmt_pattern_) + .fake_reduce_iter_idx; + } + + const auto& op_patterns = GetOpPatternKindVector(ops); + CheckFusionInputValid(op_compute_bodies, op_patterns); + + std::unordered_map<::pir::Operation*, FusionNode*> op_to_node_map; + + for (int i = 0; i < ops.size(); ++i) { + FusionNode* node = + new FusionNode(CreateFusibleOp(op_compute_bodies[i], op_patterns[i])); + op_to_node_map[ops[i]] = node; + all_fusion_nodes_.emplace(node); + node->expr_related_op = ops[i]; + } + + for (::pir::Operation* op : ops) { + FusionNode* cur_node = op_to_node_map[op]; + + // add upstream nodes + for (int i = 0; i < op->num_operands(); ++i) { + ::pir::Value related_value = op->operand_source(i); + ::pir::Operation* input_op = related_value.defining_op(); + if (op_to_node_map.find(input_op) != op_to_node_map.end()) { + FusionNode* upstream_node = op_to_node_map[input_op]; + cur_node->upstream[upstream_node] = related_value; + upstream_node->downstream[cur_node] = related_value; + } + } + + // add downstream nodes + for (int i = 0; i < op->num_results(); ++i) { + ::pir::Value related_value = op->result(i); + for (auto consumer_it = related_value.use_begin(); + consumer_it != related_value.use_end(); + ++consumer_it) { + ::pir::Operation* output_op = consumer_it->owner(); + if (op_to_node_map.find(output_op) != op_to_node_map.end()) { + FusionNode* downstream_node = op_to_node_map[output_op]; + cur_node->downstream[downstream_node] = related_value; + downstream_node->upstream[cur_node] = related_value; + } + } + } + + if (cur_node->upstream.empty()) { + entrance_nodes_.emplace(cur_node); + } + + if (cur_node->downstream.empty()) { + exit_nodes_.emplace(cur_node); + } + } + + VLOG(4) << "FusionGraph Created, fusion node size: " + << all_fusion_nodes_.size(); +} + +FusionGraph::~FusionGraph() { + for (FusionNode* node : all_fusion_nodes_) { + delete node; + } +} + +std::vector GetShapeFromVars(const std::vector& vars) { + std::vector res; + for (const auto& v : vars) { + res.emplace_back(v->upper_bound); + } + return res; +} + +void DebugPrintReduceVar(const FusibleOp& op) { + VLOG(4) << "DebugPrint Op: " << GetOutputTensor(op); + VLOG(4) << 
"DebugPrint Op: " << GetComputeBody(op); + const auto& block = (ExprSetFinderUtils::ChildScheduleBlockRealizes * + ExprSetFinderUtils::ScheduleBlockRealizeIsNotInit * + ExprSetFinderUtils::Realizer2ScheduleBlock) + .GetSingle(_GetRootExpr(op)); + const std::vector& iter_vars = + block.As()->iter_vars; + for (const auto& v : iter_vars) { + VLOG(4) << "Var: " << v << " is_reduce_axis=" << v->is_reduce_axis; + } +} + +void FusionGraph::SplitReduceTransform() { + VLOG(4) << "SplitReduceTransform Start."; + std::vector result; + for (const auto& fop : fusion_results_) { + if (std::holds_alternative(fop)) { + VLOG(4) << "DebugPrint Op Origin: "; + ReduceOp reduce_op = std::get(fop); + ir::Tensor reduce_out_tensor = GetOutputTensor(reduce_op); + // substitude compute_body with a new init value. + ir::Expr trivial_compute_body = + ExprTransformerUtils::ChangeTensorLoadTransformer( + GetOutputTensor(fop), + GetInitExpr(reduce_op))(GetComputeBody(reduce_op)); + + const std::vector& all_iters = ComposeUtils::ConcatVector( + GetOutputIters(reduce_op), GetReduceIters(reduce_op)); + VLOG(4) << "Trivial Compute Body is " << trivial_compute_body; + ir::Tensor new_trivial_tensor = + ir::Tensor(reduce_out_tensor->name + "_split_transform", + reduce_out_tensor->type(), + GetShapeFromVars(all_iters), + GetShapeFromVars(all_iters), + ir::ComputeOp::Make( + reduce_out_tensor->name + "_split_transform", + [body = trivial_compute_body]( + const std::vector& indices) { return body; }, + GetShapeFromVars(all_iters), + GetShapeFromVars(all_iters), + {}), + {}); + new_trivial_tensor->WithBuffer(); + VLOG(4) << "Created Tensor is: " << new_trivial_tensor; + VLOG(4) << "Load Expr is: " + << new_trivial_tensor(ComposeUtils::VarVec2ExprVec(all_iters)); + + // push trivial op + VLOG(4) << "Splited TrivialOp is " + << CreateTrivialExpr( + all_iters, trivial_compute_body, new_trivial_tensor); + + result.emplace_back(TrivialOp(CreateTrivialExpr( + all_iters, trivial_compute_body, new_trivial_tensor))); + + // push reduce op, change compute_body to + VLOG(4) + << "WrapReduceOperation start: with reduce_type: " + << GetOutputTensor(reduce_op)->body().As()->reduce_type; + VLOG(4) << "WrapReduceOperation new_trivial_tensor: " + << new_trivial_tensor(ComposeUtils::VarVec2ExprVec(all_iters)); + const ir::Expr& new_reduce_body = + ExprTransformerUtils::WrapReduceOperation( + GetOutputTensor(reduce_op)->body().As()->reduce_type, + GetOutputTensor(reduce_op), + ComposeUtils::VarVec2ExprVec(GetOutputIters(reduce_op)))( + new_trivial_tensor(ComposeUtils::VarVec2ExprVec(all_iters))); + VLOG(4) << "Splited ReduceOp body is " << new_reduce_body; + VLOG(4) << "Splited ReduceOp is " + << CreateExprWithNewComputeBody( + fop, + ExprSetFinderUtils::Store2Value.GetSingle( + new_reduce_body)); + result.emplace_back(ReduceOp(CreateExprWithNewComputeBody( + fop, ExprSetFinderUtils::Store2Value.GetSingle(new_reduce_body)))); + } else { + result.emplace_back(fop); + } + } + fusion_results_ = result; + VLOG(4) << "SplitReduceTransform End~"; +} + +std::vector FusionGraph::DoFusion() { + VLOG(4) << "Start Trivial Fusion"; + DoTrivialFusion(); + VLOG(4) << "Start R + T and R + R Fusion"; + ReduceLoopTranform(); + // TODO(@xubin): remove this when backend support arbitrary reduce. 
+ VLOG(4) << "Split Reduce Transform into a tmp tensor to keep reduce clean."; + SplitReduceTransform(); + return GetExprResults(); +} + +FusionNode* FusionGraph::FindTrivialFusibleNode() { + for (FusionNode* node : all_fusion_nodes_) { + if (node->IsTrivial() && !node->downstream.empty()) { + return node; + } + } + return nullptr; +} + +void FusionGraph::DoTrivialFusion() { + FusionNode* upstream = nullptr; + // use funcion to get upstream and downstream is save here + // cause we might delete Nodes in this process + while ((upstream = FindTrivialFusibleNode()) != nullptr) { + std::unordered_map fusion_candidate = + upstream->downstream; + upstream->downstream.clear(); + for (const auto& pair_data : fusion_candidate) { + FusionNode* downstream = pair_data.first; + FusionNode* new_node = + new FusionNode(TrivialFusion(upstream, downstream)); + new_node->replace_topo_structure_of_fused_nodes(upstream, downstream); + AppendNode(new_node); + RemoveNode(downstream); + } + RemoveNode(upstream); + } +} + +void FusionGraph::ReduceLoopTranform() { + for (FusionNode* node : exit_nodes_) { + auto fusion_nodes = ReduceTransform(node); + fusion_results_.insert( + fusion_results_.end(), fusion_nodes.begin(), fusion_nodes.end()); + } +} + +std::vector FusionGraph::GetExprResults() { + std::vector output_exprs; + for (const auto& node : fusion_results_) { + output_exprs.emplace_back(_GetRootExpr(node)); + } + return output_exprs; +} + +void FusionGraph::RemoveNode(FusionNode* node) { + if (all_fusion_nodes_.find(node) != all_fusion_nodes_.end()) { + all_fusion_nodes_.erase(node); + } + if (entrance_nodes_.find(node) != entrance_nodes_.end()) { + entrance_nodes_.erase(node); + } + if (exit_nodes_.find(node) != exit_nodes_.end()) { + exit_nodes_.erase(node); + } + delete node; +} + +void FusionGraph::AppendNode(FusionNode* node) { + all_fusion_nodes_.emplace(node); + if (node->upstream.empty()) { + entrance_nodes_.emplace(node); + } + + if (node->downstream.empty()) { + exit_nodes_.emplace(node); + } +} + +FusionNode* FusionGraph::FindReduceUpstream(FusionNode* node) { + for (const auto& pair_data : node->upstream) { + FusionNode* upstream = pair_data.first; + if (!upstream->IsTrivial()) { + return upstream; + } + } + return nullptr; +} + +} // namespace trivial_fusion_detail + +std::vector OperationFusion( + const std::vector<::pir::Operation*>& original_ops, + const std::vector& op_compute_bodies) { + const auto& ops = trivial_fusion_detail::FilterVector( + original_ops, [](const ::pir::Operation* op) { + if (op->name() == "cinn_op.generate_shape") { + return false; + } + return true; + }); + + auto output = std::vector(); + auto op_expr_map = + trivial_fusion_detail::ComposeUtils::MakeMap(ops, op_compute_bodies); + + auto frontend_cluster_result = cinn::frontend::ClusterOps(ops); + for (const auto& frontend_node : frontend_cluster_result) { + trivial_fusion_detail::FusionGraph graph = + trivial_fusion_detail::FusionGraph(frontend_node, op_expr_map); + output = trivial_fusion_detail::ComposeUtils::ConcatVector( + output, graph.DoFusion()); + } + + VLOG(4) << "Fusion Result: output size is " << output.size(); + for (const auto& expr : output) { + VLOG(4) << expr; + } + return output; +} + +FusionGroupInfo GetFusionGroupInfo( + const std::vector& op_compute_bodies) { + using trivial_fusion_detail::ReduceOp; + using trivial_fusion_detail::ComposeUtils::ConcatVector; + using trivial_fusion_detail::ExprSetFinderUtils::ChildScheduleBlockRealizes; + using 
trivial_fusion_detail::ExprSetFinderUtils::ScheduleBlockRealizeIsInit; + + FusionGroupInfo group_info = FusionGroupInfo(); + + const auto IsReduceBody = [](const ir::Expr& expr_body) { + return !(ChildScheduleBlockRealizes * ScheduleBlockRealizeIsInit)(expr_body) + .empty(); + }; + + for (const auto& body : op_compute_bodies) { + if (IsReduceBody(body)) { + ReduceOp op = ReduceOp(body); + if (group_info.reduce_var_name.empty()) { + std::vector all_iters = + ConcatVector(GetOutputIters(op), GetReduceIters(op)); + std::transform(all_iters.begin(), + all_iters.end(), + std::back_inserter(group_info.loop_ranges), + [](const ir::Var var) { + VLOG(4) << "Var is : : " << var; + VLOG(4) << "Var->upper_bound: " << var->upper_bound; + if (var->upper_bound.is_constant()) { + return var->upper_bound.as_int64(); + } else { + return (int64_t)-1; + } + }); + std::vector reduce_iters = GetReduceIters(op); + for (int64_t i = all_iters.size() - reduce_iters.size(); + i < all_iters.size(); + i++) { + group_info.reduce_axis.emplace_back(i); + } + } + group_info.reduce_var_name.emplace_back(GetOutputTensor(op)->name); + } + } + + if (group_info.reduce_var_name.empty()) { + trivial_fusion_detail::TrivialOp op = + trivial_fusion_detail::TrivialOp(*(op_compute_bodies.begin())); + std::vector iters = GetOutputIters(op); + std::transform(iters.begin(), + iters.end(), + std::back_inserter(group_info.loop_ranges), + [](const ir::Var var) { + if (var->upper_bound.is_constant()) { + return var->upper_bound.as_int64(); + } else { + return (int64_t)-1; + } + }); + } + VLOG(4) << group_info.DebugPrint(); + return group_info; +} + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_impl.h b/paddle/cinn/hlir/framework/pir/trivial_op_impl.h new file mode 100644 index 0000000000000..27b8705db107b --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/trivial_op_impl.h @@ -0,0 +1,227 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
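+
+// This header declares the trivial-op fusion machinery used by PIR op
+// lowering: TrivialOp and ReduceOp (thin wrappers around an ir::Expr
+// function body), the FusibleOp variant over the two, FusionNode and
+// FusionGraph for fusing producer/consumer ops, and the entry points
+// OperationFusion() and GetFusionGroupInfo().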
+#pragma once + +#include +#include + +#include "paddle/cinn/frontend/group_cluster/group_cluster.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/compile_error.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/trivial_op_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" +#include "paddle/cinn/ir/dim.h" +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/schedule_block_dce.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +namespace trivial_fusion_detail { + +struct TrivialOp { + public: + explicit TrivialOp(const ir::Expr& origin_func_body); + + TrivialOp(const TrivialOp& trivial_op); + + void _SetFuncBody(ir::Expr new_body); + ir::Expr* _GetFuncBodyPointer(); + + ir::Expr GetFuncBody() const; + + private: + ir::Expr func_body; +}; + +struct ReduceOp { + public: + explicit ReduceOp(const ir::Expr& origin_func_body); + ReduceOp(const ReduceOp& reduce_op); + + void _SetFuncBody(ir::Expr new_body); + + ir::Expr GetFuncBody() const; + + ir::Expr* _GetFuncBodyPointer(); + + private: + ir::Expr func_body; +}; + +using FusibleOp = std::variant; + +ir::Expr _GetRootExpr(const FusibleOp& op); + +void _SetFuncBody(FusibleOp& op, ir::Expr new_body); // NOLINT +ir::Expr GetComputeBody(const FusibleOp& op); + +ir::Tensor GetOutputTensor(const FusibleOp& op); + +std::vector AppendBound(const std::vector vars, + const ir::Expr& root); + +std::vector GetOutputIters(const FusibleOp& op); + +std::vector GetReduceIters(const ReduceOp& op); + +ir::Expr GetInitExpr(const ReduceOp& op); + +ir::Expr* _GetFuncBodyPointer(FusibleOp op); + +ir::Expr CopyReduceBody(const FusibleOp& downstream, const ReduceOp& upstream); + +ir::Expr CreateReduceExpr( + const std::vector& output_iters, + const std::vector& reduce_iters, + const ir::Expr& init_body, // relay on output_iters + const ir::Expr& reduce_body, // relay on output_iters + reduce_iters + const ir::Tensor& new_write_tensor, + const ir::Tensor& origin_write_tensor); + +ir::Expr CreateTrivialExpr(const std::vector& output_iters, + const ir::Expr& function_body, + const ir::Tensor& new_write_tensor); +ir::Expr CreateExprWithNewComputeBody(const FusibleOp& fusible_op, + const ir::Expr& new_compute_body); +struct FusionNode { + FusibleOp fusible_op; + ::pir::Operation* expr_related_op; + + std::unordered_map upstream; + std::unordered_map downstream; + + explicit FusionNode(FusibleOp fusible_op); + + static std::string GetTensorCounter(); + void replace_topo_structure_of_fused_nodes(FusionNode* fused_up_node, + FusionNode* fused_down_node); + + bool IsTrivial() const; +}; + +bool CheckAllLoopRangeEq(ReduceOp reduce_upper, TrivialOp trivial_down); + +FusibleOp CreateFusibleOp(ir::Expr compute_body, OpPatternKind op_pattern); + +struct FusionGraph { + explicit FusionGraph( + const cinn::frontend::group_cluster::PatternNodePtr& pattern_node, + const 
std::unordered_map<::pir::Operation*, ir::Expr>& op_expr_map); + ~FusionGraph(); + + std::vector DoFusion(); + + private: + FusionNode* FindTrivialFusibleNode(); + void DoTrivialFusion(); + void ReduceLoopTranform(); + void SplitReduceTransform(); + std::vector GetExprResults(); + void RemoveNode(FusionNode* node); + void AppendNode(FusionNode* node); + FusionNode* FindReduceUpstream(FusionNode* node); + + private: + FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream); + + template + DownStreamOp TrivalxOther_Fusion(TrivialOp upstream, + DownStreamOp downstream) { + VLOG(4) << "Trivial x OtherFusion begin."; + + const auto& replaced_tensor = GetOutputTensor(upstream); + VLOG(4) << "upstream is " << upstream.GetFuncBody(); + VLOG(4) << "downstream is " << downstream.GetFuncBody(); + + ir::Expr modified_body = ir::ir_utils::IRCopy(downstream.GetFuncBody()); + SequenceMutator( + ComposeUtils::GetEachTensorLoadExpr(modified_body, replaced_tensor), + &modified_body, + [&](const ir::Expr& downstream_load_expr, ir::Expr* downstream_body) { + ComposeUtils::ReplaceDownstreamLoadExprWithUpstreamComputeBody( + upstream, downstream_load_expr, downstream_body); + }); + + VLOG(4) << "TTFusion end:\n" << modified_body; + return DownStreamOp(modified_body); + } + + std::vector ReduceTransform(FusionNode* downstream); + std::vector ReduceTransformRecursive(FusibleOp root_op, + FusionNode* fusion_tree); + std::vector TransformReduceLoopRange(const ReduceOp& upstream, + FusibleOp* downstream); + FusibleOp SinkTrivialLoopAlign(TrivialOp trivial_op, ReduceOp reduce_op); + + template + std::vector FilterWithFakeReduceIter(const std::vector& input) { + std::vector result; + for (size_t i = 0; i < input.size(); i++) { + if (std::find(fake_reduce_iter_idx_.begin(), + fake_reduce_iter_idx_.end(), + i) == fake_reduce_iter_idx_.end()) { + result.emplace_back(input.at(i)); + } + } + return result; + } + + private: + std::unordered_set all_fusion_nodes_; + std::vector fusion_results_; + std::unordered_set entrance_nodes_; + std::unordered_set exit_nodes_; + + std::vector fake_reduce_iter_idx_; + // std::unordered_map<::pir::Value, ShardableAxes> shardable_axes_; +}; + +} // namespace trivial_fusion_detail + +struct FusionGroupInfo { + std::vector loop_ranges; + std::vector reduce_axis; + std::vector reduce_var_name; + + std::string DebugPrint() { + return "GroupInfo\nloop_ranges: " + cinn::utils::Join(loop_ranges, " ") + + "\nreduce_axis: " + cinn::utils::Join(reduce_axis, " ") + + "\nreduce_var_name: " + cinn::utils::Join(reduce_var_name, " "); + } +}; + +FusionGroupInfo GetFusionGroupInfo( + const std::vector& op_compute_bodies); + +std::vector OperationFusion( + const std::vector<::pir::Operation*>& ops, + const std::vector& op_compute_bodies); + +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_util.cc b/paddle/cinn/hlir/framework/pir/trivial_op_util.cc new file mode 100644 index 0000000000000..c930aa8a8fd95 --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/trivial_op_util.cc @@ -0,0 +1,521 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/hlir/framework/pir/trivial_op_util.h" + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/compile_error.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" +#include "paddle/cinn/ir/dim.h" +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" +#include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/schedule_block_dce.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +namespace trivial_fusion_detail { + +namespace ComposeUtils { + +std::vector ExprVec2VarVec(const std::vector& in) { + std::vector out; + for (auto& expr : in) { + out.push_back(expr.as_var_ref()); + } + return out; +} + +std::vector VarVec2ExprVec(const std::vector& in) { + return std::vector(in.begin(), in.end()); +} + +std::vector GetEachTensorLoadExpr(const ir::Expr& body, + const ir::Tensor& tensor) { + VLOG(4) << "GetEachTensorLoadExpr: " << tensor; + std::set load_exprs = cinn::ir::ir_utils::CollectIRNodesWithoutTensor( + body, [&tensor](const Expr* expr) { + return expr->As() && expr->As()->is_addr_tensor() && + expr->As()->tensor.as_tensor_ref()->name == + tensor->name; + }); + for (auto& t : load_exprs) { + VLOG(4) << "GetEachTensorLoadExpr Found: " << t << " " << t.ptr(); + } + return std::vector(load_exprs.begin(), load_exprs.end()); +} + +MappingTargetExprToDestExprMutator::MappingTargetExprToDestExprMutator( + const ir::Expr& source, const ir::Expr& dest) + : source_(source), dest_(dest) {} + +void MappingTargetExprToDestExprMutator::operator()(Expr* expr) { + IRMutator::Visit(expr, expr); +} + +void MappingTargetExprToDestExprMutator::Visit(const ir::Load* load, Expr* op) { + if (load == source_.ptr()) { + *op = dest_; + } else { + IRMutator::Visit(load, op); + } +} +void MappingTargetExprToDestExprMutator::Visit(const ir::Store* store, + Expr* op) { + if (store == source_.ptr()) { + *op = dest_; + } else { + IRMutator::Visit(store, op); + } +} +void MappingTargetExprToDestExprMutator::Visit(const ir::Reduce* reduce, + Expr* op) { + if (reduce == source_.ptr()) { + *op = dest_; + } else { + IRMutator::Visit(reduce, op); + } +} + +bool CheckIterEq(const std::vector& up_iter, + const std::vector& down_iter) { + if (up_iter.size() != down_iter.size()) return false; + + for (int i = 0; i < up_iter.size(); ++i) { + const ir::Var& up_iter_var = up_iter[i]; + const ir::Var& down_iter_var = down_iter[i]; + + if (up_iter_var != down_iter_var) return false; + if (up_iter_var->lower_bound.as_int64() != + 
down_iter_var->lower_bound.as_int64()) + return false; + if (up_iter_var->upper_bound.as_int64() != + down_iter_var->upper_bound.as_int64()) + return false; + } + return true; +} + +ir::Expr CopyedReplaceExpr(const Expr& source, + const std::vector& replaced, + const std::vector& candidates) { + VLOG(4) << "CopyedReplaceExpr Start"; + VLOG(4) << "Replace Body : " << source; + VLOG(4) << "Replace From : " << cinn::utils::Join(replaced, " "); + VLOG(4) << "Replace To : " << cinn::utils::Join(candidates, " "); + + CHECK_EQ(replaced.size(), candidates.size()) + << "In ReplaceExpr, the size of Vars to be replaced must be equal to " + "the " + "size of cadidate Exprs! Please check."; + auto copyed_source = ir::ir_utils::IRCopy(source); + if (replaced.empty()) return copyed_source; + std::map replacing_map; + for (int i = 0; i < replaced.size(); ++i) { + // If the Var to be replaced is equal to the candidate, we skip it. + if (candidates[i].is_var() && candidates[i].as_var_ref() == replaced[i]) + continue; + replacing_map[replaced[i]] = candidates[i]; + } + ir::MappingVarToExprMutator mapper(replacing_map); + mapper(©ed_source); + VLOG(4) << "CopyedReplaceExpr Result: " << copyed_source; + return copyed_source; +} + +void SubstitudeTargetExprWithDestExpr(const ir::Expr& source, + const ir::Expr& dest, + ir::Expr* body) { + VLOG(4) << "SubstitideExpr Start"; + VLOG(4) << "Substitide Body : " << *body; + VLOG(4) << "Substitide From : " << source; + VLOG(4) << "Substitide To : " << dest; + MappingTargetExprToDestExprMutator mapper(source, dest); + mapper(body); + VLOG(4) << "SubstitideExpr Result: " << *body; +} + +ir::Expr SubstitudeIndexVector(const Expr& source, + const std::vector& load_vars, + const std::vector& indices) { + return CopyedReplaceExpr(source, load_vars, indices); +} +} // namespace ComposeUtils + +namespace ExprSetFinderUtils { + +using ExprSet = std::vector; +using Expr2ExprSet = std::function; +ExprSetFinder::ExprSetFinder(Expr2ExprSet f, std::string s) { + f_ = f; + name = s; +} +ExprSet ExprSetFinder::operator()(const ir::Expr& x) const { return f_(x); } +ir::Expr ExprSetFinder::GetSingle(const ir::Expr& x) const { + ExprSetFinder call = (*this) * ExprSetFinder::GetIdentity(); + const auto& o = call.operator()(x); + if (o.size() != 1) { + PADDLE_THROW("Try to get single result, but we get %d.", o.size()); + } + return *o.begin(); +} + +ExprSetFinder ExprSetFinder::operator*(ExprSetFinder x) const { + auto new_f = [self = *this, x = x](const ir::Expr& e) -> ExprSet { + const auto& rs = self.f_(e); + VLOG(6) << "ExprSetFinder Info : " << self.name; + VLOG(6) << " Inputs :" << e; + for (const auto& r : rs) { + VLOG(6) << " Outputs : \n" << r; + } + std::vector res; + for (const auto& r : rs) { + const auto& x_res = x.f_(r); + res.insert(res.begin(), x_res.begin(), x_res.end()); + } + return res; + }; + return ExprSetFinder(std::function(new_f), x.name + "*" + this->name); +} + +ExprSetFinder ExprSetFinder::GetIdentity() { + return ExprSetFinder( + [](const ir::Expr& e) { return std::vector{e}; }, "identity"); +} + +ExprSetFinder Identity = ExprSetFinder::GetIdentity(); + +ExprSetFinder Store2Value = ExprSetFinder( + [](const ir::Expr& e) -> ExprSet { + if (e.As()) { + return {e.As()->value}; + } + return {}; + }, + "Store2Value"); + +ExprSetFinder Realizer2ScheduleBlock = ExprSetFinder( + [](const ir::Expr& e) -> ExprSet { + if (e.As()) { + return {e.As()->schedule_block}; + } + return {}; + }, + "Realizer2ScheduleBlock"); + +ExprSetFinder ScheduleBlock2Body = ExprSetFinder( + 
[](const ir::Expr& e) -> ExprSet { + if (e.As()) { + return {e.As()->body}; + } + return {}; + }, + "ScheduleBlock2Body"); + +ExprSetFinder ScheduleBlockRealizeNotRoot = FilterMaker( + [](const ir::Expr& e) -> bool { + return (e.As() && + e.As() + ->schedule_block.As() + ->name.find("root") == std::string::npos); + }, + "ScheduleBlockRealizeNotRoot"); + +ExprSetFinder ScheduleBlockRealizeIsNotInit = FilterMaker( + [](const ir::Expr& e) -> bool { + return (e.As() && + e.As() + ->schedule_block.As() + ->name.find("__reduce_init") == std::string::npos); + }, + "ScheduleBlockRealizeIsNotInit"); + +ExprSetFinder ScheduleBlockRealizeIsInit = FilterMaker( + [](const ir::Expr& e) -> bool { + return (e.As() && + e.As() + ->schedule_block.As() + ->name.find("__reduce_init") != std::string::npos); + }, + "ScheduleBlockRealizeIsInit"); + +ExprSetFinder IsFor = FilterMaker( + [](const ir::Expr& e) -> bool { return e.As(); }, "IsFor"); + +ExprSetFinder ChildScheduleBlocks = + Collector([](const ir::Expr* e) { return e->As(); }, + "ChildScheduleBlocks"); + +ExprSetFinder ChildScheduleBlockRealizes = + Collector( + [](const ir::Expr* e) { return e->As(); }, + "ChildScheduleBlockRealizes") * + ScheduleBlockRealizeNotRoot; + +ExprSetFinder IsForIterVar(const ir::Var& var) { + return FilterMaker( + [var = var](const ir::Expr& e) -> bool { + return e.As() && e.As()->loop_var == var; + }, + "IsForIterVar"); +} + +ExprSetFinder For2Min = ExprSetFinder( + [](const ir::Expr& e) -> ExprSet { return {e.As()->min}; }, + "For2Min"); + +ExprSetFinder For2Max = ExprSetFinder( + [](const ir::Expr& e) -> ExprSet { return {e.As()->extent}; }, + "For2Max"); + +ExprSetFinder ChildStores = Collector( + [](const ir::Expr* e) { return e->As(); }, "ChildStores"); + +ExprSetFinder ChildTensorLoads = Collector( + [](const ir::Expr* e) { + return e->As() && e->As()->is_addr_tensor(); + }, + "ChildLoads"); + +ExprSetFinder ChildTensorStores = Collector( + [](const ir::Expr* e) { + return e->As() && e->As()->is_addr_tensor(); + }, + "ChildTensorStores"); + +ExprSetFinder FilterLoadByTensor(const ir::Tensor& tensor) { + return FilterMaker( + [tensor = tensor](const ir::Expr& e) -> bool { + return e.As() && + e.As()->tensor.as_tensor_ref()->name == tensor->name; + }, + "FilterLoadByTensor(" + tensor->name + ")"); +} + +ExprSetFinder ChildFors = + Collector([](const ir::Expr* e) { return e->As(); }, "ChildFors"); + +ExprSetFinder FindFather(const ir::Expr& root) { + const auto& f = [&](const auto& child) -> ExprSet { + ExprSetFinder find_child = + Collector([child](const ir::Expr* e) { return *e == child; }); + const auto& father_collector = Collector( + [&](const ir::Expr* current) { return !find_child(*current).empty(); }); + return father_collector(root); + }; + return ExprSetFinder(f, "FindFather"); +} +} // namespace ExprSetFinderUtils + +namespace ExprTransformerUtils { +using ExprTransformFunc = std::function; + +ExprTransformer::ExprTransformer(ExprTransformFunc f) { f_ = f; } +ir::Expr ExprTransformer::operator()(const ir::Expr& x) const { return f_(x); } +ExprTransformer ExprTransformer::operator*(const ExprTransformer& x) const { + auto new_f = [self = *this, x = x](const ir::Expr& e) -> ir::Expr { + const auto& rs = self.f_(e); + return x.f_(rs); + }; + return ExprTransformer(std::function(new_f)); +} + +ExprTransformer Identity = ExprTransformer([](const ir::Expr& e) { return e; }); +ExprTransformer WrapForTransformer(const ir::Var& v) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + auto block = e; + if 
(!block.As()) { + block = ir::Block::Make({e}); + } + return ir::For::Make(v, + v->lower_bound, + v->upper_bound, + ir::ForType::Serial, + ir::DeviceAPI::Host, + block); + }; + return ExprTransformer(f); +} + +ExprTransformer WrapForsTransformer(const std::vector& vs) { + const auto& f = [&](const ir::Expr& e) -> ir::Expr { + ExprTransformer t = Identity; + for (const auto& v : vs) { + t = WrapForTransformer(v) * t; + } + return t(e); + }; + return ExprTransformer(f); +} + +ExprTransformer ChangeTensorLoadTransformer(const ir::Tensor& tensor, + const ir::Expr& dst_load) { + const auto& f = [&](const ir::Expr& e) -> ir::Expr { + auto copied_e = ir::ir_utils::IRCopy(e); + const auto& load = (ExprSetFinderUtils::ChildTensorLoads * + ExprSetFinderUtils::FilterLoadByTensor(tensor)) + .GetSingle(copied_e); + ComposeUtils::MappingTargetExprToDestExprMutator(load, dst_load)(&copied_e); + return copied_e; + }; + return ExprTransformer(f); +} + +void ReplaceTarget(ir::Expr* e, const ir::Expr& t, const ir::Expr dst) { + ComposeUtils::MappingTargetExprToDestExprMutator(t, dst)(e); +} + +ExprTransformer WrapStoreTransformer(const ir::Tensor& tensor, + const std::vector& indices) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + return ir::Store::Make(tensor, e, indices); + }; + return ExprTransformer(f); +} + +std::vector CreateInnerBlockVars( + const std::vector& block_vars) { + int i = 0; + std::vector vars; + for (const auto& v : block_vars) { + vars.emplace_back("inner_block_" + std::to_string(i++)); + vars.back()->is_reduce_axis = v->is_reduce_axis; + } + return vars; +} + +ExprTransformer ChangeVarTransformer(const std::vector& target_vars, + const std::vector& dest_vars) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + return ComposeUtils::CopyedReplaceExpr( + e, + target_vars, + std::vector(dest_vars.begin(), dest_vars.end())); + }; + return ExprTransformer(f); +} + +ExprTransformer WrapReduceOperation(const ir::Reduce::ReduceType& reduce_type, + const ir::Tensor& tensor, + const std::vector& axis_exprs) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + switch (reduce_type) { + case ir::Reduce::kSum: + return ir::Store::Make(tensor, tensor(axis_exprs) + e, axis_exprs); + case ir::Reduce::kMul: + return ir::Store::Make(tensor, tensor(axis_exprs) * e, axis_exprs); + case ir::Reduce::kMax: + return ir::Store::Make( + tensor, ir::Max::Make(tensor(axis_exprs), e), axis_exprs); + case ir::Reduce::kMin: + return ir::Store::Make( + tensor, ir::Min::Make(tensor(axis_exprs), e), axis_exprs); + case ir::Reduce::kAll: + return ir::Store::Make(tensor, tensor(axis_exprs) && e, axis_exprs); + case ir::Reduce::kAny: + return ir::Store::Make(tensor, tensor(axis_exprs) || e, axis_exprs); + default: + CINN_NOT_IMPLEMENTED + } + }; + return ExprTransformer(f); +} + +ExprTransformer SubstitudeByScheduleBlockRealize(const ir::Expr& realize) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + const auto& iter_values = + realize.As()->iter_values; + const auto& iter_vars = realize.As() + ->schedule_block.As() + ->iter_vars; + return ExprTransformerUtils::ChangeVarTransformer( + iter_vars, ComposeUtils::ExprVec2VarVec(iter_values))(e); + }; + return ExprTransformer(f); +} + +ExprTransformer WrapScheduleRealizer(const std::vector& block_vars, + const std::string& tensor_name) { + const auto& f = [=](const ir::Expr& e) -> ir::Expr { + if (e.As()) { + PADDLE_THROW("please input a non-schedule block expr."); + } + const auto& inner_block_var = CreateInnerBlockVars(block_vars); + const 
auto& replaced_e = + ChangeVarTransformer(block_vars, inner_block_var)(e); + const auto& schedule_block = ir::ScheduleBlock::Make( + inner_block_var, {}, {}, tensor_name, replaced_e); + const auto& schedule_realizer = ir::ScheduleBlockRealize::Make( + std::vector(block_vars.begin(), block_vars.end()), + schedule_block); + return schedule_realizer; + }; + return ExprTransformer(f); +} +} // namespace ExprTransformerUtils + +std::vector GetOpPatternKindVector( + const std::vector<::pir::Operation*>& ops) { + const auto& op_pattern_map = + Operator::GetAttrs("OpPattern"); + std::vector op_patterns; + const auto ConvertToPattern = [&op_pattern_map](const ::pir::Operation* op) { + const std::string cinn_op_name = CompatibleInfo::OpName(*op); + const hlir::framework::Operator* cinn_op = Operator::Get(cinn_op_name); + return op_pattern_map[cinn_op]; + }; + std::transform(ops.begin(), + ops.end(), + std::back_inserter(op_patterns), + ConvertToPattern); + return op_patterns; +} + +bool IsTrivialKind(OpPatternKind kind) { + return kind == OpPatternKind::kElementWise || + kind == OpPatternKind::kBroadcast || kind == OpPatternKind::kInjective; +} + +void CheckFusionInputValid(const std::vector& op_compute_bodies, + const std::vector& op_patterns) { + if (VLOG_IS_ON(4)) { + for (const auto& func : op_compute_bodies) { + VLOG(4) << "FuncBody is :" << func; + } + for (const auto& op_ptn : op_patterns) { + VLOG(4) << "OpPattern is :" << op_ptn; + } + } + VLOG(4) << " op_patterns.size() = " << op_compute_bodies.size(); + VLOG(4) << "op_compute_bodies.size() = " << op_patterns.size(); + PADDLE_ENFORCE_EQ( + op_patterns.size(), op_compute_bodies.size(), "ops and size not equal"); +} + +} // namespace trivial_fusion_detail +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/trivial_op_util.h b/paddle/cinn/hlir/framework/pir/trivial_op_util.h new file mode 100644 index 0000000000000..9dbddc6ada18c --- /dev/null +++ b/paddle/cinn/hlir/framework/pir/trivial_op_util.h @@ -0,0 +1,256 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
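+
+// Utilities shared by the trivial-op fusion code:
+//   * ComposeUtils: vector/map helpers and expression substitution
+//     mutators (e.g. MappingTargetExprToDestExprMutator).
+//   * ExprSetFinderUtils: composable searchers over ir::Expr; (a * b)(e)
+//     applies a first and feeds each of its results into b, e.g.
+//     (ChildScheduleBlockRealizes * ScheduleBlockRealizeIsNotInit *
+//      Realizer2ScheduleBlock).GetSingle(expr).
+//   * ExprTransformerUtils: composable ir::Expr -> ir::Expr rewriters
+//     (wrap fors, wrap schedule realizers, replace loads), composed with
+//     the same operator*.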
+#pragma once + +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" +#include "paddle/cinn/hlir/framework/compile_error.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_util.h" +#include "paddle/cinn/hlir/framework/pir/utils.h" +#include "paddle/cinn/hlir/op/external_api_registry.h" +#include "paddle/cinn/hlir/pe/map_expr_to_ir.h" +#include "paddle/cinn/ir/dim.h" +#include "paddle/cinn/ir/group_schedule/base_group_scheduler.h" +#include "paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h" +#include "paddle/cinn/ir/schedule/ir_schedule.h" +#include "paddle/cinn/lang/placeholder.h" +#include "paddle/cinn/optim/schedule_block_dce.h" +#include "paddle/cinn/optim/transform_gpu_forloop.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" + +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +namespace trivial_fusion_detail { + +namespace ComposeUtils { + +template +std::vector ConcatVector(const std::vector& first, + const std::vector& second) { + std::vector result = first; + result.insert(result.end(), second.begin(), second.end()); + return result; +} + +template +std::unordered_map MakeMap(const std::vector& keys, + const std::vector& values) { + std::unordered_map result = std::unordered_map(); + + CHECK(keys.size() == values.size()); + for (int i = 0; i < keys.size(); i++) { + result[keys[i]] = values[i]; + } + return result; +} + +std::vector ExprVec2VarVec(const std::vector& in); +std::vector VarVec2ExprVec(const std::vector& in); + +std::vector GetEachTensorLoadExpr(const ir::Expr& body, + const ir::Tensor& tensor); + +struct MappingTargetExprToDestExprMutator : public ir::IRMutator<> { + explicit MappingTargetExprToDestExprMutator(const ir::Expr& source, + const ir::Expr& dest); + + void operator()(Expr* expr); + + private: + void Visit(const ir::Load* load, Expr* op) override; + void Visit(const ir::Store* store, Expr* op) override; + void Visit(const ir::Reduce* reduce, Expr* op) override; + + private: + ir::Expr source_; + ir::Expr dest_; +}; + +bool CheckIterEq(const std::vector& up_iter, + const std::vector& down_iter); + +ir::Expr CopyedReplaceExpr(const Expr& source, + const std::vector& replaced, + const std::vector& candidates); +void SubstitudeTargetExprWithDestExpr(const ir::Expr& source, + const ir::Expr& dest, + ir::Expr* body); + +ir::Expr SubstitudeIndexVector(const Expr& source, + const std::vector& load_vars, + const std::vector& indices); + +template +void ReplaceDownstreamLoadExprWithUpstreamComputeBody( + const FusionOp& upstream, + const ir::Expr& downstream_load_expr, + ir::Expr* downstream_body) { + ComposeUtils::SubstitudeTargetExprWithDestExpr( + downstream_load_expr, + ComposeUtils::SubstitudeIndexVector( + GetComputeBody(upstream), + GetOutputIters(upstream), + downstream_load_expr.As()->indices), + downstream_body); +} +} // namespace ComposeUtils + +namespace ExprSetFinderUtils { + +using ExprSet = std::vector; +using Expr2ExprSet = std::function; +struct ExprSetFinder { + Expr2ExprSet f_; + std::string name; + explicit ExprSetFinder(Expr2ExprSet f, std::string s = ""); + + ExprSet operator()(const ir::Expr& x) const; + ir::Expr GetSingle(const ir::Expr& x) const; + ExprSetFinder operator*(ExprSetFinder x) const; + static ExprSetFinder GetIdentity(); +}; + +template +ExprSetFinder Collector(Teller t, std::string name = "") { + return ExprSetFinder( + [=](const ir::Expr& x) -> ExprSet { + const auto& rs = 
cinn::ir::ir_utils::CollectIRNodesWithoutTensor(x, t); + return std::vector(rs.begin(), rs.end()); + }, + name); +} + +template +ExprSetFinder FilterMaker(FilterFunc t, std::string name) { + return ExprSetFinder( + [=](const ir::Expr& x) -> ExprSet { + if (t(x)) { + return {x}; + } + return {}; + }, + name); +} + +extern ExprSetFinder Identity; + +extern ExprSetFinder Store2Value; + +extern ExprSetFinder Realizer2ScheduleBlock; + +extern ExprSetFinder ScheduleBlock2Body; + +extern ExprSetFinder ScheduleBlockRealizeNotRoot; + +extern ExprSetFinder ScheduleBlockRealizeIsNotInit; + +extern ExprSetFinder ScheduleBlockRealizeIsInit; + +extern ExprSetFinder IsFor; + +extern ExprSetFinder ChildScheduleBlocks; + +extern ExprSetFinder ChildScheduleBlockRealizes; + +extern ExprSetFinder For2Min; + +extern ExprSetFinder For2Max; + +extern ExprSetFinder ChildStores; + +extern ExprSetFinder ChildTensorLoads; + +extern ExprSetFinder ChildTensorStores; + +extern ExprSetFinder ChildFors; + +ExprSetFinder IsForIterVar(const ir::Var& var); + +ExprSetFinder FilterLoadByTensor(const ir::Tensor& tensor); + +ExprSetFinder FindFather(const ir::Expr& root); + +template +std::vector MapVector(const std::vector& as, M func) { + std::vector res; + for (const auto& a : as) { + res.push_back(func(a)); + } + return res; +} +} // namespace ExprSetFinderUtils + +namespace ExprTransformerUtils { +using ExprTransformFunc = std::function; +struct ExprTransformer { + ExprTransformFunc f_; + explicit ExprTransformer(ExprTransformFunc f); + ir::Expr operator()(const ir::Expr& x) const; + ExprTransformer operator*(const ExprTransformer& x) const; +}; + +extern ExprTransformer Identity; + +ExprTransformer WrapForTransformer(const ir::Var& v); + +ExprTransformer WrapForsTransformer(const std::vector& vs); +ExprTransformer ChangeTensorLoadTransformer(const ir::Tensor& tensor, + const ir::Expr& dst_load); + +void ReplaceTarget(ir::Expr* e, const ir::Expr& t, const ir::Expr dst); + +ExprTransformer WrapStoreTransformer(const ir::Tensor& tensor, + const std::vector& indices); + +ExprTransformer WrapReduceOperation(const ir::Reduce::ReduceType& reduce_type, + const ir::Tensor& tensor, + const std::vector& axis_exprs); + +std::vector CreateInnerBlockVars( + const std::vector& block_vars); + +ExprTransformer ChangeVarTransformer(const std::vector& target_vars, + const std::vector& dest_vars); + +ExprTransformer SubstitudeByScheduleBlockRealize(const ir::Expr& realize); + +ExprTransformer WrapScheduleRealizer(const std::vector& block_vars, + const std::string& tensor_name); +} // namespace ExprTransformerUtils + +std::vector GetOpPatternKindVector( + const std::vector<::pir::Operation*>& ops); + +template +void SequenceMutator(const std::vector& as, C* acc, const Func& mutator) { + VLOG(4) << "SequenceTransform Init: " << acc; + for (int i = 0; i < as.size(); ++i) { + mutator(as[i], acc); + VLOG(4) << "SequenceTransform Iter: " << acc; + } +} + +bool IsTrivialKind(OpPatternKind kind); + +void CheckFusionInputValid(const std::vector& op_compute_bodies, + const std::vector& op_patterns); + +} // namespace trivial_fusion_detail +} // namespace pir +} // namespace framework +} // namespace hlir +} // namespace cinn diff --git a/paddle/cinn/hlir/framework/pir/utils.cc b/paddle/cinn/hlir/framework/pir/utils.cc index 83fe4ed5ef16c..942bf35f3f8eb 100644 --- a/paddle/cinn/hlir/framework/pir/utils.cc +++ b/paddle/cinn/hlir/framework/pir/utils.cc @@ -23,6 +23,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include 
"paddle/cinn/hlir/framework/op.h" #include "paddle/cinn/hlir/framework/pir/op_mapper.h" +#include "paddle/common/enforce.h" #include "paddle/common/flags.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -32,6 +33,7 @@ #include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" #include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" PD_DECLARE_string(allow_cinn_ops); PD_DECLARE_string(deny_cinn_ops); @@ -48,6 +50,8 @@ const std::unordered_map CompatibleInfo::OP_NAMES = { {"pd_op.full", "fill_constant"}, {"pd_op.sum", "reduce_sum"}, {"pd_op.max", "reduce_max"}, + {"pd_op.min", "reduce_min"}, + {"pd_op.prod", "reduce_prod"}, {"pd_op.add", "elementwise_add"}, {"pd_op.elementwise_pow", "pow"}, {"pd_op.multiply", "elementwise_mul"}, @@ -67,6 +71,26 @@ using GroupOpsVec = std::vector<::pir::Operation*>; // & FLAGS_deny_cinn_ops. constexpr char kDelim[] = ";"; +std::unordered_set StringSplit(const std::string& str, + const std::string& delim) { + std::regex reg(delim); + std::unordered_set elems{ + std::sregex_token_iterator(str.begin(), str.end(), reg, -1), + std::sregex_token_iterator()}; + elems.erase(""); + return elems; +} + +std::string GetDebugInfo(const std::unordered_set& names) { + std::string debug_info = "["; + for (auto& name : names) { + debug_info.append(name); + debug_info.append(", "); + } + debug_info.append("]"); + return debug_info; +} + // OpTransInfo contains informations used to detect subgraphs // supported by the CINN compiler. class OpTransInfo { @@ -77,8 +101,24 @@ class OpTransInfo { OpTransInfo() {} const DeParamCondT& deny_param_cond() const { return deny_param_cond_; } - const std::unordered_set& default_deny_ops() const { - return default_deny_ops_; + bool IsDeniedByDefault(const std::string& op_name) const { + return default_deny_ops_.count(op_name) || IsDeniedInFLAGS(op_name); + } + + bool IsDeniedInFLAGS(const std::string& op_name) const { + auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); + auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); + if (VLOG_IS_ON(4)) { + LOG_FIRST_N(INFO, 1) << "The allowed Cinn Ops: " + << GetDebugInfo(allow_ops); + LOG_FIRST_N(INFO, 1) << "The denied Cinn Ops: " << GetDebugInfo(deny_ops); + } + if (!allow_ops.empty()) { + return allow_ops.count(op_name) == 0U; + } else if (!deny_ops.empty()) { + return deny_ops.count(op_name); + } + return false; } private: @@ -86,30 +126,37 @@ class OpTransInfo { {"batch_norm_grad", {"ReserveSpace"}}}; std::unordered_set default_deny_ops_{ - "feed", "fetch", "conv2d", "conv2d_grad", "dropout", "matmul"}; + "feed", + "fetch", + "conv2d", + "conv2d_grad", + "depthwise_conv2d", + "depthwise_conv2d_grad", + "dropout", + "pool2d", + "pool2d_grad", + "split", + "matmul", + "matmul_grad", + "embedding_grad", + "embedding", + "arange", + }; }; -std::unordered_set StringSplit(const std::string& str, - const std::string& delim) { - std::regex reg(delim); - std::unordered_set elems{ - std::sregex_token_iterator(str.begin(), str.end(), reg, -1), - std::sregex_token_iterator()}; - elems.erase(""); - return elems; -} - -std::string GetDebugInfo(const std::unordered_set& names) { - std::string debug_info = "["; - for (auto& name : names) { - debug_info.append(name); - debug_info.append(", "); +std::string OpNameAfterStripDialect(const ::pir::Operation& op) { + std::string name 
= op.name(); + auto pos = name.find("."); + if (pos == std::string::npos) { + return name; } - debug_info.append("]"); - return debug_info; + auto op_name = name.substr(pos + 1); + VLOG(7) << "GetOpName: " << name << " -> " << op_name; + CHECK(op_name != "") << "Not Allow op name is empty"; + return op_name; } -bool IsSupportForCinn(const ::pir::Operation& op); +bool IsSupportInCinn(const ::pir::Operation& op); // In case of op has some attributes generated by FullOp, it need // implement OpPattern in pd_to_cinn_pass. Otherwise, we mark them @@ -120,7 +167,7 @@ bool UnimplementOps(const ::pir::Operation& op) { if (op.isa()) { auto out = op.result(0); if (out.use_count() > 0) { - return !IsSupportForCinn(*(out.first_use().owner())); + return !IsSupportInCinn(*(out.first_use().owner())); } } return false; @@ -131,6 +178,21 @@ bool HaveZeroDimInput(const ::pir::Operation& op) { auto tensor_type = type.dyn_cast<::pir::DenseTensorType>(); return tensor_type && tensor_type.dims().size() == 0U; }; + + auto HasNegDim = [](const ::pir::Type& type) { + auto tensor_type = type.dyn_cast<::pir::DenseTensorType>(); + + if (tensor_type) { + for (size_t i = 0; i < tensor_type.dims().size(); ++i) { + if (tensor_type.dims()[i] < 0) { + return true; + } + } + } + + return false; + }; + // Judge for vector auto HasZeroDimInVT = [&](const std::vector<::pir::Type>& types) { for (auto& type : types) { @@ -144,7 +206,7 @@ bool HaveZeroDimInput(const ::pir::Operation& op) { if (!value || !value.type()) continue; if (auto vector_type = value.type().dyn_cast<::pir::VectorType>()) { if (HasZeroDimInVT(vector_type.data())) return true; - } else if (HasZeroDim(value.type())) { + } else if (HasZeroDim(value.type()) || HasNegDim(value.type())) { return true; } } @@ -152,12 +214,13 @@ bool HaveZeroDimInput(const ::pir::Operation& op) { } bool AllInputDenseTensor(const ::pir::Operation& op) { - auto IsDenseTensor = [](const ::pir::Type& type) { + const auto& IsDenseTensor = [](const ::pir::Type& type) -> bool { return type.isa<::pir::DenseTensorType>(); }; // Judge for vector - auto IsAllDenseTensor = [&](const std::vector<::pir::Type>& types) { + const auto& IsAllDenseTensor = + [&](const std::vector<::pir::Type>& types) -> bool { for (auto& type : types) { if (!IsDenseTensor(type)) return false; } @@ -177,58 +240,164 @@ bool AllInputDenseTensor(const ::pir::Operation& op) { return true; } -bool IsRegisteredInCINN(const ::pir::Operation& op) { - if (CompatibleInfo::OP_NAMES.find(op.name()) != - CompatibleInfo::OP_NAMES.end()) { - return true; - } - return OpRegistry::Global()->Find(CompatibleInfo::OpName(op)) != nullptr; +bool IsSmallNumelOp(const ::pir::Operation& op) { + const auto& GetNumElementsFromDim = [](const ::pir::DDim& dim) -> int64_t { + if (::common::contain_unknown_dim(dim)) { + return std::numeric_limits::max(); + } else { + return ::common::product(dim); + } + }; + + const auto& GetNumElementsFromValue = + [&](const ::pir::Value& value) -> int64_t { + int64_t numel = -1; + if (value && value.type()) { + auto type = value.type().dyn_cast<::pir::DenseTensorType>(); + if (type) { + numel = GetNumElementsFromDim(type.dims()); + } + } + return numel; + }; + const int64_t max_value_numel = [&] { + int64_t max_value_numel = -1; + if (op.num_operands() == 0) { // no input + return max_value_numel; + } + + for (uint32_t i = 0; i < op.num_operands(); ++i) { + max_value_numel = std::max(GetNumElementsFromValue(op.operand_source(i)), + max_value_numel); + } + for (uint32_t i = 0; i < op.num_results(); ++i) { + 
max_value_numel = + std::max(GetNumElementsFromValue(op.result(i)), max_value_numel); + } + return max_value_numel; + }(); + + // max value check + return (0 <= max_value_numel && max_value_numel < 32); } -bool IsSupportForCinn(const ::pir::Operation& op) { - if (!AllInputDenseTensor(op) || HaveZeroDimInput(op) || UnimplementOps(op)) { - VLOG(4) << "Found " << op.name() - << " HaveZeroDimInput or UnimplementOps or NotAllInputDenseTensor. " - << "So mark IsSupportForCinn: " << false; +bool IsShapeComputeOp(const ::pir::Operation& op) { + const auto& shape_analysis = ::pir::ShapeAnalysisManager::Instance().Get( + op.GetParent()->parent_program()); + if (op.num_operands() == 0) { return false; } - auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); - auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); - LOG_FIRST_N(INFO, 1) << "The allowed Cinn Ops: " << GetDebugInfo(allow_ops); - LOG_FIRST_N(INFO, 1) << "The denied Cinn Ops: " << GetDebugInfo(deny_ops); - // Strip the dialect, like pd_op.abs -> abs - const auto op_name = CompatibleInfo::OpName(op); - - OpTransInfo trans_info; - bool is_support = - IsRegisteredInCINN(op) && !trans_info.default_deny_ops().count(op_name); - VLOG(4) << op_name << " is_support: " << is_support - << " IsRegisteredInCINN: " << IsRegisteredInCINN(op); - // if the op type is registered in CINN and allow_ops is not empty, return - // true only when it is in allow_ops - if (!allow_ops.empty()) { - return is_support && allow_ops.count(op_name); + bool all_input_has_shape_data = true; + for (uint32_t i = 0; i < op.num_operands(); ++i) { + if (shape_analysis.HasShapeOrDataForValue(op.operand_source(i))) { + const auto& shape_expr = + shape_analysis.GetShapeOrDataForValue(op.operand_source(i)); + if (shape_expr.isa() && + shape_expr.data()) { // has shape data + continue; + } + } + all_input_has_shape_data = false; + break; } - // if the op type is registered in CINN and deny_ops is not empty, return - // true only when it is not in deny_ops - if (!deny_ops.empty()) { - return is_support && !deny_ops.count(op_name); + + for (uint32_t i = 0; i < op.num_results(); ++i) { + if (shape_analysis.HasShapeOrDataForValue(op.result(i))) { + const auto& shape_expr = + shape_analysis.GetShapeOrDataForValue(op.result(i)); + if (shape_expr.isa() && + shape_expr.data()) { // has shape data + continue; + } + } + all_input_has_shape_data = false; + break; } - // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, - // return true only when it is registered in CINN - return is_support; + return all_input_has_shape_data; +} + +// TODO(zyfncg): This function is a temporary solution, we need to remove it in +// the future. +bool IsTempDenySpecialOp(const ::pir::Operation& op) { + if (op.name() == "cinn_op.generate_shape") { + return false; + } + return IsShapeComputeOp(op); +} + +// Mainly used for pd_to_cinn_pass and reused in IsSupportInCinn function. +bool IsDeniedInCinn(const ::pir::Operation& op) { + if (!AllInputDenseTensor(op) || UnimplementOps(op)) { + VLOG(5) << "Found " << op.name() + << " UnimplementOps or NotAllInputDenseTensor. 
" + << "So mark IsDeniedForCinn: " << true; + return true; + } + + // Strip the dialect, like pd_op.abs -> abs + const auto op_name = OpNameAfterStripDialect(op); + const bool is_denied = OpTransInfo().IsDeniedByDefault(op_name); + VLOG(5) << op_name << " is denied in FLAGS or defaultly: " << is_denied; + return is_denied; +} + +bool IsRegisteredInCINN(const ::pir::Operation& op) { + return OpRegistry::Global()->Find(CompatibleInfo::OpName(op)) != nullptr; +} + +#define PD_OP_NAME(op) paddle::dialect::op::name() +// For op supports AttributeTensor but has handled in +// pd_to_cinn_pass. Such as cinn_op.reshape, except pd_op.reshape; +const std::unordered_set TOCINN_OPS = { + PD_OP_NAME(SumOp), + PD_OP_NAME(MaxOp), + PD_OP_NAME(MinOp), + PD_OP_NAME(ProdOp), + PD_OP_NAME(PowOp), + PD_OP_NAME(ScaleOp), + PD_OP_NAME(Pool2dOp), + PD_OP_NAME(IscloseOp), + PD_OP_NAME(SliceOp), + PD_OP_NAME(ConcatOp), + PD_OP_NAME(SplitOp), + PD_OP_NAME(SplitWithNumOp), + PD_OP_NAME(AddNOp), + PD_OP_NAME(UniformOp), +}; +#undef PD_OP_NAME + +bool HasHandledInPass(const ::pir::Operation& op) { + return TOCINN_OPS.count(op.name()) == 0U; } -} // namespace // In following cases, the op is marked SupportCinn: -// 1. its name is in OP_NAMES, like pd_op.sum; -// 2. it supports AttributeTensor but has Pattern to process it. -// Such as cinn_op.reshape, except pd_op.reshape; -// 3. otherwise, it should be registered in OpRegistry; -bool CompatibleInfo::IsSupportCinn(const ::pir::Operation& op) { - bool flag = IsSupportForCinn(op); - VLOG(4) << "CompatibleInfo::IsSupportCinn of " << op.name() +// 1. it is NOT denied in IsDeniedInCinn(op) +// 2. it should be registered in OpRegistry; +// 3. it should be handled in pd_to_cinn_pass; +bool IsSupportInCinn(const ::pir::Operation& op) { + const bool is_denied = IsDeniedInCinn(op); + const bool is_registered = IsRegisteredInCINN(op); + const bool is_handled = HasHandledInPass(op); + VLOG(5) << op.name() << ": IsDeniedInCinn = " << is_denied + << ", IsRegisteredInCINN = " << is_registered + << ", HasHandledInPass = " << is_handled; + return !is_denied && is_registered && is_handled; +} +} // namespace + +bool CompatibleInfo::IsDeniedForCinn(const ::pir::Operation& op) { + bool flag = IsDeniedInCinn(op); + VLOG(4) << "CompatibleInfo::IsDeniedForCinn of " << op.name() + << " is: " << flag; + return flag; +} + +bool CompatibleInfo::IsSupportForCinn(const ::pir::Operation& op) { + const bool not_builtin_op = op.dialect()->name() != "builtin"; + const bool flag = IsSupportInCinn(op) && not_builtin_op; + + VLOG(4) << "CompatibleInfo::IsSupportForCinn of " << op.name() << " is: " << flag; return flag; } @@ -238,16 +407,7 @@ std::string CompatibleInfo::OpName(const ::pir::Operation& op) { if (OP_NAMES.count(name)) { return OP_NAMES.at(name); } - auto pos = name.find("."); - if (pos == std::string::npos) { - return name; - } - auto cinn_op_name = name.substr(pos + 1); - VLOG(7) << "GetOpName: " << name << " -> " << cinn_op_name; - CHECK(cinn_op_name != "") - << "Found empty cinn_op_name, maybe you should implement OpPattern for " - << name; - return cinn_op_name; + return OpNameAfterStripDialect(op); } std::string CompatibleInfo::OpFuncName(const ::pir::Operation& op) { @@ -314,13 +474,24 @@ static utils::Attribute ConvertArrayAttribute( CASE_ATTRIBUTE(float, FloatAttribute) } else if (attr_vec[0].isa<::pir::DoubleAttribute>()) { CASE_ATTRIBUTE(double, DoubleAttribute) + } else if (attr_vec[0].isa<::pir::StrAttribute>()) { + std::vector dst_attr; + for (auto element : attr_vec) { + 
dst_attr.push_back( + element.dyn_cast<::pir::StrAttribute>().AsString()); + } } else { - LOG(FATAL) << "only support bool/int32/int64/float/double attribute in " - "ArrayAttribute"; + PADDLE_THROW(phi::errors::InvalidArgument( + "only support bool/int32/int64/float/double/string attribute in " + "ArrayAttribute")); } } + } else if (src_attr.isa<::pir::shape::SymbolAttribute>()) { + // do nothing for now } else { - LOG(FATAL) << "unknown Attribute: " << src_attr; + std::stringstream ss; + ss << "unknown Attribute: " << src_attr; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return dst_attr; } @@ -352,7 +523,9 @@ utils::AttributeMap CompatibleInfo::ConvertAttributes( utils::AttributeMap dst_attrs; for (auto& item : src_attrs) { VLOG(4) << "deal with " << item.first; - if (item.first == ::pir::kStopGradientAttrName) { + if (item.first == ::pir::kStopGradientAttrName || + item.first == ::pir::kOutputDimExprs || + item.first == ::pir::kSymbolBindings) { continue; } else if (item.second.isa()) { auto is_cpu = @@ -387,7 +560,9 @@ cinn::common::Type CompatibleInfo::ConvertIRType(::pir::Type type) { CASE_TYPE(IndexType, I32) CASE_TYPE(BoolType, UI1) - LOG(FATAL) << "unknown ir::Type " << type; + std::stringstream ss; + ss << "unknown ir::Type " << type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } #undef CASE_TYPE @@ -399,7 +574,7 @@ OpPatternKind CompatibleInfo::OpKind(const ::pir::Operation& op) { auto& op_pattern_dict = Operator::GetAttrs("OpPattern"); auto op_name = CompatibleInfo::OpName(op); if (op_name == "generate_shape") { - return hlir::framework::kNonFusible; + return hlir::framework::kElementWise; } const hlir::framework::Operator* cinn_op = Operator::Get(op_name); CHECK(op_pattern_dict.Find(cinn_op)); diff --git a/paddle/cinn/hlir/framework/pir/utils.h b/paddle/cinn/hlir/framework/pir/utils.h index 225f16f5caad2..c489e1847f26f 100644 --- a/paddle/cinn/hlir/framework/pir/utils.h +++ b/paddle/cinn/hlir/framework/pir/utils.h @@ -30,6 +30,7 @@ namespace framework { namespace pir { struct CINNKernelInfo { + std::string fn_name; void* fn_ptr; void* infer_shape_fn_ptr; @@ -54,16 +55,17 @@ struct CINNKernelInfo { struct CompatibleInfo { static constexpr char* kNamePrefix = "var"; - // TODO(Aurelius): Need add name mapping logic in REGISTER_CINN_OP - // macros or attempt to unify Op name with Paddle and CINN. - static const std::unordered_map OP_NAMES; // NOTE(Aurelius): Some ops in CINN register different // name between OpMapper and Compute/Schedule, such as // 'subtract': 1. OpMapper: 'elementwise_sub'; 2. Compute/Schedule: // 'subtract'. 
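The renaming scheme referenced above (an explicit OP_NAMES override first, otherwise stripping the dialect prefix) is easy to check in isolation. The standalone sketch below mirrors that rule, so that "pd_op.abs" becomes "abs" while a mapped name such as the 'elementwise_sub'/'subtract' pair from the note wins over the generic rule. This is illustrative plain C++ only; the mapping entry is hypothetical and the real logic lives in OpNameAfterStripDialect and CompatibleInfo::OpName in this patch.

    #include <cassert>
    #include <string>
    #include <unordered_map>

    // Illustrative stand-in for CompatibleInfo::OpName: explicit overrides first,
    // otherwise drop everything up to and including the first '.'.
    std::string NormalizeOpName(const std::string& name) {
      static const std::unordered_map<std::string, std::string> kOverrides = {
          {"pd_op.subtract", "elementwise_sub"},  // hypothetical mapping entry
      };
      auto it = kOverrides.find(name);
      if (it != kOverrides.end()) return it->second;
      const auto pos = name.find('.');
      return pos == std::string::npos ? name : name.substr(pos + 1);
    }

    int main() {
      assert(NormalizeOpName("pd_op.abs") == "abs");
      assert(NormalizeOpName("cinn_op.reshape") == "reshape");
      assert(NormalizeOpName("relu") == "relu");  // no dialect prefix, unchanged
      return 0;
    }
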
- static const std::unordered_set CINN_WHITE_OPS; + static const std::unordered_map OP_NAMES; + + static const std::unordered_set TOCINN_OPS; + + static bool IsDeniedForCinn(const ::pir::Operation& op); - static bool IsSupportCinn(const ::pir::Operation& op); + static bool IsSupportForCinn(const ::pir::Operation& op); static std::string OpName(const ::pir::Operation& op); @@ -122,10 +124,12 @@ struct ScheduleInfoNode { // TOOD(phlrain): update align type by new loop alignment ScheduleAlignType type{ScheduleAlignType::kNone}; + // reduction or broadcast axis locations std::vector axis_info; + // representing the iteration space std::vector factor_info; - std::string DebugStr() { + std::string DebugStr() const { std::stringstream ss; ss << "type " << static_cast(type) << "| axis info "; diff --git a/paddle/cinn/hlir/framework/pir_compiler.cc b/paddle/cinn/hlir/framework/pir_compiler.cc index 1cd7b0220b496..2db39508ce1e1 100644 --- a/paddle/cinn/hlir/framework/pir_compiler.cc +++ b/paddle/cinn/hlir/framework/pir_compiler.cc @@ -14,216 +14,25 @@ #include "paddle/cinn/hlir/framework/pir_compiler.h" -#include -#include "paddle/cinn/hlir/framework/pir/compilation_task.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/cinn/utils/multi_threading.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/pir/include/core/builtin_type.h" -#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" -PD_DECLARE_bool(cinn_bucket_compile); +namespace cinn::hlir::framework { -namespace cinn { -namespace hlir { -namespace framework { - -// TODO(Aurelius84): Clear useless Build Interface. -std::unique_ptr PirCompiler::Build() { - m_builder_.Clear(); - // NOTE(Aurelius84): Currently only support each op for one group - std::vector groups; - for (auto& op : *program_.block()) { - if (op.isa<::pir::YieldOp>()) { - continue; - } - std::vector<::pir::Operation*> ops = {&op}; - auto group = std::make_shared(ops); - group->output_ops.insert(&op); - groups.push_back(group); - } - VLOG(4) << "Groups size: " << groups.size(); - return std::move(Build(groups)); -} - -std::vector PirCompiler::BuildCUDAJITInfo( - const std::vector& groups) { - std::vector cinn_kernel_info_vecs(groups.size()); - - if (FLAGS_cinn_bucket_compile) { - for (int i = 0; i < groups.size(); ++i) { - group_compilation_contexts_.emplace_back(target_, groups[i], scope_); - } - auto worker_fn = [&](int index) { - CompilationTask task(&group_compilation_contexts_[index]); - task(); - cinn_kernel_info_vecs[index] = task.BuildPirCINNKernelInfo(); - }; - utils::parallel_run( - worker_fn, utils::SequenceDispatcher(0, groups.size()), -1); - } else { - auto op_lowerer = CreateOpLowerer(target_); - - std::vector> lowered_funcs; - for (int i = 0; i < groups.size(); ++i) { - lowered_funcs.emplace_back(op_lowerer.Lower(groups[i])); - } - - for (auto&& lowered_func : lowered_funcs) { - ProcessFunction(lowered_func); - } - compiler_ = backends::Compiler::Create(target_); - auto build_module = m_builder_.Build(); - compiler_->Build(build_module, ""); - - auto fn_ptrs = compiler_->GetFnPtr(); - - for (int idx = 0; idx < groups.size(); ++idx) { - pir::CINNKernelInfo cinn_kernel_info; - auto fn_name = groups[idx]->FuncName(); - auto fn_ptr = compiler_->Lookup(fn_name); - cinn_kernel_info.fn_ptr = fn_ptr; - cinn_kernel_info.int_args_map = groups[idx]->int_args_map; - - cinn_kernel_info_vecs[idx] = cinn_kernel_info; - } - } - return cinn_kernel_info_vecs; -} - -std::unique_ptr PirCompiler::Build( - const std::vector& groups) 
{ - std::vector> instructions(groups.size()); - if (FLAGS_cinn_bucket_compile) { - for (int i = 0; i < groups.size(); ++i) { - group_compilation_contexts_.emplace_back(target_, groups[i], scope_); - } - auto worker_fn = [&](int index) { - CompilationTask task(&group_compilation_contexts_[index]); - task(); - instructions[index] = task.BuildInstruction(); - }; - utils::parallel_run( - worker_fn, utils::SequenceDispatcher(0, groups.size()), -1); - } else { - auto op_lowerer = CreateOpLowerer(target_); - - std::vector> lowered_funcs; - for (int i = 0; i < groups.size(); ++i) { - lowered_funcs.emplace_back(op_lowerer.Lower(groups[i])); - } - - for (auto&& lowered_func : lowered_funcs) { - ProcessFunction(lowered_func); - } - - compiler_ = backends::Compiler::Create(target_); - auto build_module = m_builder_.Build(); - compiler_->Build(build_module, ""); - - instructions = BuildInstructions(groups); +std::vector PirCompiler::Build( + const std::vector& groups) { + std::vector kernel_infos(groups.size()); + for (int i = 0; i < groups.size(); ++i) { + group_compilation_contexts_.emplace_back(target_, groups[i]); } - - // TODO(Aurelius84): Instantiate all tensors on compile-time, which is - // controlled by 'options.with_instantiate_variables' in GraphCompiler. - // Moreover, it's better to implement InsertBufferHandlers() logic - // to automatically insert Malloc and Free instructions. - for (auto& name : scope_->var_names()) { - std::string var_name({name.data(), name.size()}); - VLOG(4) << "Instantiate " << var_name << " on compile-time"; - auto* var = scope_->Var(var_name); - auto& tensor = absl::get(*var); - tensor->mutable_data(target_, tensor->type()); - } - return std::make_unique(scope_, std::move(instructions)); -} - -void PirCompiler::ProcessFunction( - const std::vector& lowered_funcs) { - for (auto&& func : lowered_funcs) { - for (auto&& arg : func->args) { - std::string arg_name = arg.name(); - if (arg_name[0] == '_') arg_name = arg_name.substr(1); - - auto* var = scope_->FindVar(arg_name); - // For argument buffer not in scope, create it. - if (!var && arg.is_buffer()) { - auto* new_var = scope_->Var(arg_name); - auto& tensor = absl::get(*new_var); - std::vector shape; - for (auto& shape_dim : arg.buffer_arg()->shape) { - CHECK(shape_dim.is_constant()); - shape.push_back(static_cast(shape_dim.get_constant())); - } - tensor->Resize(Shape{shape}); - tensor->set_type(arg.buffer_arg()->dtype); - } - } - m_builder_.AddFunction(func); - } -} - -std::vector> PirCompiler::BuildInstructions( - const std::vector& groups) { - std::vector> instructions; - for (int idx = 0; idx < groups.size(); ++idx) { - auto fn_name = groups[idx]->FuncName(); - auto instr = - std::unique_ptr(new Instruction(target_, - scope_.get(), - groups[idx]->input_names, - groups[idx]->output_names, - fn_name)); - VLOG(4) << "Lookup kernel name: " << fn_name; - auto* fn_ptr = compiler_->Lookup(fn_name); - CHECK(fn_ptr); - instr->SetLoweredFunc(reinterpret_cast(fn_ptr), fn_name); - // As some instruction like reduce, will generate more than one kernel. - // So try to find the rest kernel, if it exists. 
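Both the removed Instruction path above and the new CINNKernelInfo (which now also carries fn_name next to fn_ptr) pass a compiled kernel around as a name plus a raw function pointer that is later looked up and invoked. The plain-C++ sketch below shows that lookup-by-name idiom with an ordinary map; Lookup, CINNKernelInfo, and the JIT backend are CINN internals, so the registry and kernel signature here are stand-ins rather than the real API.

    #include <cstdio>
    #include <stdexcept>
    #include <string>
    #include <unordered_map>

    using KernelFn = void (*)(const float* in, float* out, int n);

    void scale_by_two(const float* in, float* out, int n) {
      for (int i = 0; i < n; ++i) out[i] = 2.0f * in[i];
    }

    // Stand-in for a JIT symbol table: generated function name -> function pointer.
    std::unordered_map<std::string, KernelFn> symbol_table = {
        {"fn_group_0", &scale_by_two},
    };

    KernelFn Lookup(const std::string& fn_name) {
      auto it = symbol_table.find(fn_name);
      if (it == symbol_table.end())
        throw std::runtime_error("unknown kernel: " + fn_name);
      return it->second;
    }

    int main() {
      const float in[3] = {1.f, 2.f, 3.f};
      float out[3] = {};
      Lookup("fn_group_0")(in, out, 3);
      std::printf("%g %g %g\n", out[0], out[1], out[2]);  // prints: 2 4 6
      return 0;
    }
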
- // SetSubKernels(instr.get(), fn_name); - instr->Finalize(); - instructions.push_back(std::move(instr)); - } - return instructions; -} - -std::shared_ptr BuildScope(const Target& target, - const ::pir::Program& program) { - std::unordered_set<::pir::Value> visited; - auto scope = std::make_shared(); - - auto create_var = [&](::pir::Value value) { - if (!(value) || !(value.type())) { - return; - } - if (visited.count(value) > 0) return; - visited.emplace(value); - - std::string name = pir::CompatibleInfo::ValueName(value); - auto type_info = value.type().dyn_cast(); - auto* var = scope->Var(name); - auto& tensor = absl::get(*var); - - std::vector shape; - for (auto i = 0; i < type_info.dims().size(); ++i) { - shape.push_back(Shape::dim_t(type_info.dims()[i])); - } - tensor->Resize(Shape{shape}); - tensor->set_type(pir::CompatibleInfo::ConvertIRType(type_info.dtype())); + auto worker_fn = [&](int index) { + CompilationTask task(&group_compilation_contexts_[index]); + task(); + kernel_infos[index] = task.GetCINNKernelInfo(); }; - - for (auto& op : *program.block()) { - for (auto operand : op.operands()) { - create_var(operand.source()); - } - - for (auto result : op.results()) { - create_var(result); - } - } - return scope; + utils::parallel_run( + worker_fn, utils::SequenceDispatcher(0, groups.size()), -1); + return kernel_infos; } -} // namespace framework -} // namespace hlir -} // namespace cinn +} // namespace cinn::hlir::framework diff --git a/paddle/cinn/hlir/framework/pir_compiler.h b/paddle/cinn/hlir/framework/pir_compiler.h index 5edf5e25bf46b..d9429b76a6fa8 100644 --- a/paddle/cinn/hlir/framework/pir_compiler.h +++ b/paddle/cinn/hlir/framework/pir_compiler.h @@ -15,86 +15,23 @@ #pragma once #include -#include #include "paddle/cinn/common/macros.h" -#include "paddle/pir/include/core/program.h" - -#include "paddle/cinn/hlir/framework/graph_compiler.h" -#include "paddle/cinn/hlir/framework/op_lowering.h" #include "paddle/cinn/hlir/framework/pir/compilation_task.h" -namespace cinn { -namespace hlir { -namespace framework { +namespace cinn::hlir::framework { -// TODO(Aurelius84): Need abstract this logic to implement Proxy for -// the co-existence with GraphCompiler. 
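The rewritten Build above is essentially an index-parallel map: one compilation context per group, one worker invocation per index, and results written into a pre-sized vector. The sketch below reproduces that shape with std::async over plain integers; utils::parallel_run, SequenceDispatcher, and CompilationTask are CINN internals, so the Compile function here is only a placeholder for whatever per-group work the real task performs.

    #include <cstdio>
    #include <future>
    #include <string>
    #include <vector>

    struct KernelInfo {                    // stand-in for pir::CINNKernelInfo
      std::string fn_name;
    };

    KernelInfo Compile(int group_index) {  // placeholder for a CompilationTask
      return KernelInfo{"fn_group_" + std::to_string(group_index)};
    }

    int main() {
      const int num_groups = 4;
      std::vector<KernelInfo> kernel_infos(num_groups);

      // One worker per index, mirroring the worker_fn/parallel_run pairing above.
      std::vector<std::future<void>> workers;
      for (int i = 0; i < num_groups; ++i) {
        workers.push_back(std::async(std::launch::async, [&kernel_infos, i] {
          kernel_infos[i] = Compile(i);
        }));
      }
      for (auto& w : workers) w.get();

      for (const auto& info : kernel_infos) std::printf("%s\n", info.fn_name.c_str());
      return 0;
    }
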
class PirCompiler final { public: - PirCompiler(const ::pir::Program& prog, - const Target& target, - const std::shared_ptr& scope) - : program_(prog), - m_builder_("Pir", target), - target_(target), - scope_(scope) {} - - std::unique_ptr Build(); + PirCompiler(const Target& target) : target_(target) {} - std::vector BuildCUDAJITInfo( - const std::vector& groups); - - std::unique_ptr Build(const std::vector& groups); + std::vector Build( + const std::vector& groups); private: CINN_DISALLOW_COPY_AND_ASSIGN(PirCompiler); - std::vector GetOpFunc(const ::pir::Operation& op, int idx); - - void ProcessFunction(const std::vector& lowered_funcs); - - std::vector> BuildInstructions( - const std::vector& groups); - - const ::pir::Program& program_; - ir::Module::Builder m_builder_; - std::unique_ptr compiler_{nullptr}; Target target_; - std::shared_ptr scope_; - std::unordered_map func_names_; std::vector group_compilation_contexts_; }; -// TODO(phlrain): pir compiler don't need Scope, need to remove this -std::shared_ptr BuildScope(const Target&, const ::pir::Program&); - -class PirCompilerManager { - public: - static PirCompilerManager& Instance() { - static PirCompilerManager instance; - return instance; - } - - static std::shared_ptr Create( - const ::pir::Program& prog, - const Target& target, - const std::shared_ptr& scope) { - std::shared_ptr compiler = - std::make_shared(prog, target, scope); - PirCompilerManager::Instance().insert(compiler); - return compiler; - } - - void insert(const std::shared_ptr& compiler) { - compilers_.push_back(compiler); - } - - void clear() { compilers_.clear(); } - - private: - std::vector> compilers_; -}; - -} // namespace framework -} // namespace hlir -} // namespace cinn +} // namespace cinn::hlir::framework diff --git a/paddle/cinn/hlir/framework/program.cc b/paddle/cinn/hlir/framework/program.cc index eadbfdf4d7d2c..0e00795ae775d 100644 --- a/paddle/cinn/hlir/framework/program.cc +++ b/paddle/cinn/hlir/framework/program.cc @@ -44,22 +44,22 @@ void Program::PreRun( void Program::Export(const std::vector& persistent_vars, const std::string& filename) { - auto writeplaceholder = [=](int s, int n, FILE* f) -> int { + auto write_placeholder = [=](int s, int n, FILE* f) -> int { int pos = ftell(f); for (int i = 0; i < s * n; i++) { fwrite("\0", 1, 1, f); } return pos; }; - auto setplaceholder = [=](int p, void* b, int s, int n, FILE* f) { + auto set_placeholder = [=](int p, void* b, int s, int n, FILE* f) { int cur = ftell(f); fseek(f, p, SEEK_SET); fwrite(b, s, n, f); fseek(f, cur, SEEK_SET); }; - auto tellplaceholder = [=](int p, FILE* f) { + auto tell_placeholder = [=](int p, FILE* f) { int cur = ftell(f); - setplaceholder(p, &cur, 4, 1, f); + set_placeholder(p, &cur, 4, 1, f); }; auto padding = [=](int alignment, uint8_t value, FILE* f) { int cur = ftell(f); @@ -69,9 +69,9 @@ void Program::Export(const std::vector& persistent_vars, } }; auto varnames = scope_->var_names(); - std::unordered_map varindex; + std::unordered_map var_index; for (int i = 0; i < varnames.size(); i++) { - varindex[(std::string)varnames[i]] = i; + var_index[(std::string)varnames[i]] = i; } FILE* f = fopen(filename.c_str(), "w+"); @@ -85,25 +85,25 @@ void Program::Export(const std::vector& persistent_vars, fwrite(&unused_v, 4, 1, f); // varname list - int varnamesec = writeplaceholder(4, 1, f); - int namesnum = varnames.size(); - fwrite(&namesnum, 4, 1, f); - int nameoffset = writeplaceholder(4, namesnum, f); - for (int i = 0; i < namesnum; i++) { + int varname_sec = 
write_placeholder(4, 1, f); + int names_num = varnames.size(); + fwrite(&names_num, 4, 1, f); + int name_offset = write_placeholder(4, names_num, f); + for (int i = 0; i < names_num; i++) { int namelen = varnames[i].size(); fwrite(&namelen, 4, 1, f); - tellplaceholder(nameoffset + i * 4, f); + tell_placeholder(name_offset + i * 4, f); fwrite(varnames[i].data(), namelen, 1, f); fwrite("\0", 1, 1, f); } padding(16, 0, f); - tellplaceholder(varnamesec, f); + tell_placeholder(varname_sec, f); // pod_values - int buffersec = writeplaceholder(4, 1, f); - int bufoffset = writeplaceholder(4, 1, f); + int buffer_sec = write_placeholder(4, 1, f); + int buf_offset = write_placeholder(4, 1, f); padding(alignof(cinn_buffer_t), 0, f); - tellplaceholder(bufoffset, f); - std::vector> pvars; + tell_placeholder(buf_offset, f); + std::vector> p_vars; for (auto& varname : varnames) { std::string name = (std::string)varname; auto t = scope_->GetTensor(name); @@ -111,61 +111,61 @@ void Program::Export(const std::vector& persistent_vars, buffer.memory = reinterpret_cast(0); if (std::find(persistent_vars.begin(), persistent_vars.end(), name) != persistent_vars.end()) { - pvars.emplace_back(t->buffer(), - ftell(f) + offsetof(cinn_buffer_t, memory)); + p_vars.emplace_back(t->buffer(), + ftell(f) + offsetof(cinn_buffer_t, memory)); } fwrite(&buffer, sizeof(cinn_buffer_t), 1, f); } padding(16, 0, f); - tellplaceholder(buffersec, f); + tell_placeholder(buffer_sec, f); // persistent_buffers - int pbuffer = writeplaceholder(4, 1, f); - for (auto& p : pvars) { + int p_buffer = write_placeholder(4, 1, f); + for (auto& p : p_vars) { if (p.first->align) { padding(p.first->align, 0, f); } - tellplaceholder(p.second, f); + tell_placeholder(p.second, f); fwrite(p.first->memory, p.first->memory_size, 1, f); } padding(16, 0, f); - tellplaceholder(pbuffer, f); + tell_placeholder(p_buffer, f); // instructions - int instsec = writeplaceholder(4, 1, f); - int insnum = 0; + int inst_sec = write_placeholder(4, 1, f); + int ins_num = 0; for (auto& ins : instrs_) { ins->Run(nullptr, true); - insnum += ins->GetFnNames().size(); + ins_num += ins->GetFnNames().size(); } - fwrite(&insnum, 4, 1, f); - int instplaceholder = writeplaceholder(4 * 3, insnum, f); - int findex = 0; + fwrite(&ins_num, 4, 1, f); + int inst_placeholder = write_placeholder(4 * 3, ins_num, f); + int f_index = 0; for (auto& ins : instrs_) { auto& in_args = ins->GetInArgs(); auto& out_args = ins->GetOutArgs(); auto& fn_names = ins->GetFnNames(); - for (int i = 0; i < fn_names.size(); i++, findex++) { + for (int i = 0; i < fn_names.size(); i++, f_index++) { std::vector all_args(in_args[i].begin(), in_args[i].end()); all_args.insert( std::end(all_args), out_args[i].begin(), out_args[i].end()); - auto fname = fn_names[i]; - int fnamesize = fname.size(); - fwrite(&fnamesize, 4, 1, f); - tellplaceholder(instplaceholder + findex * 12, f); - fwrite(fname.c_str(), fname.size(), 1, f); + auto f_name = fn_names[i]; + int f_name_size = f_name.size(); + fwrite(&f_name_size, 4, 1, f); + tell_placeholder(inst_placeholder + f_index * 12, f); + fwrite(f_name.c_str(), f_name.size(), 1, f); fwrite("\0", 1, 1, f); int argsize = all_args.size(); - setplaceholder(instplaceholder + findex * 12 + 4, &argsize, 4, 1, f); + set_placeholder(inst_placeholder + f_index * 12 + 4, &argsize, 4, 1, f); padding(alignof(cinn_pod_value_t), 0, f); - tellplaceholder(instplaceholder + findex * 12 + 8, f); + tell_placeholder(inst_placeholder + f_index * 12 + 8, f); for (auto& arg : all_args) { - uintptr_t 
bufindex = varindex[arg]; - cinn_pod_value_t v(reinterpret_cast(bufindex)); + uintptr_t buf_index = var_index[arg]; + cinn_pod_value_t v(reinterpret_cast(buf_index)); fwrite(&v, sizeof(cinn_pod_value_t), 1, f); } } } padding(16, 0, f); - tellplaceholder(instsec, f); + tell_placeholder(inst_sec, f); fclose(f); } diff --git a/paddle/cinn/hlir/op/broadcast.cc b/paddle/cinn/hlir/op/broadcast.cc index bf71267b2c618..28cc2da723af5 100644 --- a/paddle/cinn/hlir/op/broadcast.cc +++ b/paddle/cinn/hlir/op/broadcast.cc @@ -307,12 +307,7 @@ std::shared_ptr StrategyForBroadcastToSymbolic( output_shapes[0].end(), out_shape.begin(), [](const ir::Dim &dim) { return dim->dim_expr; }); - std::vector broadcast_axes; - CHECK_GT(attrs.attr_store.count("broadcast_axes"), 0); - broadcast_axes = - absl::get>(attrs.attr_store.at("broadcast_axes")); VLOG(3) << "broadcast out shape: " << utils::Join(out_shape, ", "); - VLOG(3) << "broadcast_axes shape: " << utils::Join(broadcast_axes, ", "); framework::CINNCompute broadcast_to_compute([=](lang::Args args, lang::RetValue *ret) { @@ -321,14 +316,24 @@ std::shared_ptr StrategyForBroadcastToSymbolic( CINNValuePack pack_args = args[0]; CHECK(!pack_args.empty()) << "The input tensors of broadcast_to compute is empty! Please check."; - CHECK_GE(pack_args.size(), 2U); - CHECK(pack_args[1].is_string()); - std::string tensor_name = pack_args[1].operator std::string(); + std::string tensor_name = [&] { + if (pack_args.size() == 2) { + return pack_args[1].operator std::string(); + } else { + PADDLE_ENFORCE_EQ(pack_args.size(), + 3, + ::common::errors::InvalidArgument( + "The number of input tensors is wrong. " + "The expected inputs is 3, but now is %d.", + pack_args.size())); + return pack_args[2].operator std::string(); + } + }(); Expr A_expr = pack_args[0]; CHECK(A_expr.as_tensor()); ir::Tensor A = A_expr.as_tensor_ref(); - auto out = pe::BroadcastTo(A, out_shape, broadcast_axes, tensor_name); + auto out = pe::BroadcastTo(A, out_shape, tensor_name); auto stages = CreateStages({A, out}); *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; }); @@ -426,8 +431,9 @@ std::shared_ptr StrategyForBroadcastGrad( const std::vector &out_type, const std::vector> &output_shapes, const Target &target) { - LOG(FATAL) << "Gradient operator will be decomposed into several primitive " - "operators. Please Use Decomposer Program Pass."; + PADDLE_THROW(phi::errors::Fatal( + "Gradient operator will be decomposed into several primitive " + "operators. 
Please Use Decomposer Program Pass.")); } std::shared_ptr StrategyForIsClose( @@ -545,16 +551,16 @@ StrategyForBinary(logical_right_shift, LogicalRightShift); } // namespace cinn CINN_REGISTER_HELPER(broadcast_ops) { -#define CINN_REGISTER_BINARY(op__, op_stragegy__) \ +#define CINN_REGISTER_BINARY(op__, op_strategy__) \ CINN_REGISTER_OP(op__) \ .describe(#op__ " function") \ .set_num_inputs(1) \ .set_num_outputs(1) \ .set_attr( \ - "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + "CINNStrategy", cinn::hlir::op::StrategyFor##op_strategy__) \ .set_attr( \ "CINNStrategySymbolic", \ - cinn::hlir::op::StrategyFor##op_stragegy__##Symbolic) \ + cinn::hlir::op::StrategyFor##op_strategy__##Symbolic) \ .set_attr("infershape", \ MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \ .set_attr("inferdtype", \ @@ -567,13 +573,16 @@ CINN_REGISTER_HELPER(broadcast_ops) { "OpPattern", cinn::hlir::framework::OpPatternKind::kBroadcast) \ .set_support_level(4); -#define CINN_REGISTER_BINARY_CMP(op__, op_stragegy__) \ +#define CINN_REGISTER_BINARY_CMP(op__, op_strategy__) \ CINN_REGISTER_OP(op__) \ .describe(#op__ " function") \ .set_num_inputs(1) \ .set_num_outputs(1) \ .set_attr( \ - "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + "CINNStrategy", cinn::hlir::op::StrategyFor##op_strategy__) \ + .set_attr( \ + "CINNStrategySymbolic", \ + cinn::hlir::op::StrategyFor##op_strategy__##Symbolic) \ .set_attr("infershape", \ MakeOpFunction(cinn::hlir::op::InferShapeForBroadcast)) \ .set_attr("inferdtype", \ diff --git a/paddle/cinn/hlir/op/contrib/argmax.cc b/paddle/cinn/hlir/op/contrib/argmax.cc index 7de32179b52a0..b3c6a647c4bc3 100644 --- a/paddle/cinn/hlir/op/contrib/argmax.cc +++ b/paddle/cinn/hlir/op/contrib/argmax.cc @@ -106,7 +106,7 @@ std::shared_ptr StrategyForArgmax( if (attrs.attr_store.count("axis")) { axis = absl::get(attrs.attr_store.at("axis")); } else { - LOG(FATAL) << "reduce dimension is not set!"; + PADDLE_THROW(phi::errors::Fatal("reduce dimension is not set!")); } if (attrs.attr_store.count("keep_dim")) { keep_dims = absl::get(attrs.attr_store.at("keep_dim")); diff --git a/paddle/cinn/hlir/op/contrib/argmin.cc b/paddle/cinn/hlir/op/contrib/argmin.cc index 8f9d2ec9f45fd..dff137f0d9952 100644 --- a/paddle/cinn/hlir/op/contrib/argmin.cc +++ b/paddle/cinn/hlir/op/contrib/argmin.cc @@ -105,7 +105,7 @@ std::shared_ptr StrategyForArgmin( if (attrs.attr_store.count("axis")) { axis = absl::get(attrs.attr_store.at("axis")); } else { - LOG(FATAL) << "reduce dimension is not set!"; + PADDLE_THROW(phi::errors::Fatal("reduce dimension is not set!")); } if (attrs.attr_store.count("keep_dim")) { keep_dims = absl::get(attrs.attr_store.at("keep_dim")); diff --git a/paddle/cinn/hlir/op/contrib/bitcast_convert.cc b/paddle/cinn/hlir/op/contrib/bitcast_convert.cc index dc8516b160bd2..4ddcb52f44922 100644 --- a/paddle/cinn/hlir/op/contrib/bitcast_convert.cc +++ b/paddle/cinn/hlir/op/contrib/bitcast_convert.cc @@ -111,9 +111,10 @@ std::vector InferShapeForBitcastConvert( } else { if (output_shape.back().back() != (output_data_type.bits() / input_data_type.bits())) { - LOG(FATAL) << "The rightmost dimension of input must be equal to " - "sizeof(output_data_type)/sizeof(input_data_type) when " - "sizeof(output_data_type) > sizeof(input_data_type)"; + PADDLE_THROW(phi::errors::InvalidArgument( + "The rightmost dimension of input must be equal to " + "sizeof(output_data_type)/sizeof(input_data_type) when " + "sizeof(output_data_type) > sizeof(input_data_type)")); } 
output_shape.back().pop_back(); } diff --git a/paddle/cinn/hlir/op/contrib/resize.cc b/paddle/cinn/hlir/op/contrib/resize.cc index d74f4647878b0..91319ef7e5ac1 100644 --- a/paddle/cinn/hlir/op/contrib/resize.cc +++ b/paddle/cinn/hlir/op/contrib/resize.cc @@ -61,7 +61,8 @@ ir::Tensor Resize(const ir::Tensor &input, } else if (target.arch == cinn::common::Target::Arch::X86) { func_name.assign("cinn_host_resize_"); } else { - LOG(FATAL) << "Resize only supports X86 and NVGPU ! Please Check.\n"; + PADDLE_THROW(phi::errors::Fatal( + "Resize only supports X86 and NVGPU ! Please Check.\n")); } if (mode == "bilinear") { diff --git a/paddle/cinn/hlir/op/contrib/sort.cc b/paddle/cinn/hlir/op/contrib/sort.cc index 8adc618e352e6..49f50a13ab6c9 100644 --- a/paddle/cinn/hlir/op/contrib/sort.cc +++ b/paddle/cinn/hlir/op/contrib/sort.cc @@ -56,7 +56,8 @@ std::vector ArgSort(const ir::Tensor &A, } else if (target.arch == cinn::common::Target::Arch::X86) { find_func_name.assign("cinn_host_next_smallest_int32"); } else { - LOG(FATAL) << "ArgSort only supports X86 and NVGPU ! Please Check.\n"; + PADDLE_THROW(phi::errors::Fatal( + "ArgSort only supports X86 and NVGPU ! Please Check.\n")); } if (is_ascend) { index_func_name = diff --git a/paddle/cinn/hlir/op/custom_call.cc b/paddle/cinn/hlir/op/custom_call.cc index 91c3ee6db0898..fc84e4cc9eb1a 100644 --- a/paddle/cinn/hlir/op/custom_call.cc +++ b/paddle/cinn/hlir/op/custom_call.cc @@ -231,14 +231,14 @@ std::vector CustomCallArgsForCublas( if (is_infer) { CHECK_EQ(a_width, b_width) - << "The K dimension of mul shold be equal! Please check."; + << "The K dimension of mul should be equal! Please check."; trans_b = true; } else { CHECK_EQ(a_width, b_height) - << "The K dimension of mul shold be equal! Please check."; + << "The K dimension of mul should be equal! Please check."; } } else { - LOG(FATAL) << "Unkown Matmul Setting!"; + PADDLE_THROW(phi::errors::InvalidArgument("Unkown Matmul Setting!")); } CHECK_EQ(a_shape.size(), 4); @@ -365,14 +365,14 @@ std::vector CustomCallArgsForBatchedCublas( if (is_infer) { CHECK_EQ(a_width, b_width) - << "The K dimension of mul shold be equal! Please check."; + << "The K dimension of mul should be equal! Please check."; trans_b = true; } else { CHECK_EQ(a_width, b_height) - << "The K dimension of mul shold be equal! Please check."; + << "The K dimension of mul should be equal! Please check."; } } else { - LOG(FATAL) << "Unkown Matmul Setting!"; + PADDLE_THROW(phi::errors::InvalidArgument("Unkown Matmul Setting!")); } CHECK_EQ(a_shape.size(), 4); @@ -878,10 +878,12 @@ std::vector CustomCallArgsForMemset( void operator()(int64_t v) { *scalar_ = static_cast(v); } void operator()(bool v) { *scalar_ = v ? 
0xFFFFFFFF : 0; } -#define EXPAND_MEMSET_TYPE_UNSUPPORT(TYPE) \ - void operator()(const TYPE &) { \ - LOG(FATAL) << "The type of \"value\" of memset custom_call not support: " \ - << #TYPE; \ +#define EXPAND_MEMSET_TYPE_UNSUPPORT(TYPE) \ + void operator()(const TYPE &) { \ + std::stringstream ss; \ + ss << "The type of \"value\" of memset custom_call not support: " \ + << #TYPE; \ + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); \ } EXPAND_MEMSET_TYPE_UNSUPPORT(std::string) @@ -937,7 +939,7 @@ std::vector CustomCallArgsForMemcpy( return {Expr(count)}; } -bool RegisteryCustomCallArgsFunc() { +bool RegisterCustomCallArgsFunc() { #ifdef CINN_WITH_CUDA CustomCallArgsFuncRegistry::Global().Register( "cinn_call_cublas", @@ -1025,7 +1027,7 @@ bool RegisteryCustomCallArgsFunc() { return true; } -static bool registry_custom_call_list_func = RegisteryCustomCallArgsFunc(); +static bool registry_custom_call_list_func = RegisterCustomCallArgsFunc(); } // namespace op } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/op/elementwise.cc b/paddle/cinn/hlir/op/elementwise.cc index b215e0dd85952..d32c2c0af8b2f 100644 --- a/paddle/cinn/hlir/op/elementwise.cc +++ b/paddle/cinn/hlir/op/elementwise.cc @@ -18,6 +18,7 @@ #include "absl/types/optional.h" #include "paddle/cinn/adt/op_equation_context.h" +#include "paddle/cinn/common/type.h" #include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/hlir/framework/op.h" #include "paddle/cinn/hlir/framework/op_strategy.h" @@ -25,8 +26,11 @@ #include "paddle/cinn/hlir/pe/ir_schedule_pe.h" #include "paddle/cinn/hlir/pe/nn.h" #include "paddle/cinn/hlir/pe/schedule.h" +#include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/utils/functional.h" +#include "paddle/common/enforce.h" +#include "paddle/phi/core/enforce.h" namespace cinn { namespace hlir { @@ -73,6 +77,7 @@ std::shared_ptr StrategyForElementwise( CHECK(!args.empty()) << "The input argument of " << op_name << " compute is empty! 
Please check."; CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 1U) << "1 input tensor for " << op_name << " compute"; CHECK_EQ(pack_args.size(), 2U); @@ -332,22 +337,27 @@ Expr GetScalarExpr(const framework::NodeAttr::attr_t &attr) { void operator()(bool v) { scalar_ = Expr(v); } void operator()(const std::string &v) { scalar_ = Expr(v); } void operator()(const std::vector &) { - LOG(FATAL) << "wrong type std::vector"; + PADDLE_THROW(phi::errors::InvalidArgument("wrong type std::vector")); } void operator()(const std::vector &) { - LOG(FATAL) << "wrong type std::vector"; + PADDLE_THROW( + phi::errors::InvalidArgument("wrong type std::vector")); } void operator()(const std::vector &) { - LOG(FATAL) << "wrong type std::vector"; + PADDLE_THROW( + phi::errors::InvalidArgument("wrong type std::vector")); } void operator()(const std::vector &) { - LOG(FATAL) << "wrong type std::vector"; + PADDLE_THROW( + phi::errors::InvalidArgument("wrong type std::vector")); } void operator()(const std::vector &) { - LOG(FATAL) << "wrong type std::vector"; + PADDLE_THROW( + phi::errors::InvalidArgument("wrong type std::vector")); } void operator()(const std::vector &) { - LOG(FATAL) << "wrong type std::vector"; + PADDLE_THROW( + phi::errors::InvalidArgument("wrong type std::vector")); } }; absl::visit(Visitor{scalar}, attr); @@ -431,8 +441,9 @@ std::shared_ptr StrategyForSum( const std::vector &out_type, const std::vector> &output_shapes, const Target &target) { - LOG(FATAL) << "The operator will be decomposed into several primitive " - "operators. Please Use Decomposer Program Pass."; + PADDLE_THROW(phi::errors::Fatal( + "The operator will be decomposed into several primitive " + "operators. Please Use Decomposer Program Pass.")); } std::vector InferShapeForSum(const std::vector &inputs_shape, @@ -441,10 +452,11 @@ std::vector InferShapeForSum(const std::vector &inputs_shape, auto shape = inputs_shape[0]; for (size_t i = 1; i < inputs_shape.size(); ++i) { if (inputs_shape[i] != shape) { - LOG(FATAL) << "The input shapes must be the same. But received: the i-th(" - << i << ") input shape is " - << utils::Join(inputs_shape[i], ",") - << " and the first input shape is " << utils::Join(shape, ","); + std::stringstream ss; + ss << "The input shapes must be the same. But received: the i-th(" << i + << ") input shape is " << utils::Join(inputs_shape[i], ",") + << " and the first input shape is " << utils::Join(shape, ","); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } std::vector out_shape{shape}; @@ -458,9 +470,11 @@ std::vector InferDtypeForSum(const std::vector &inputs_type, auto type = inputs_type[0]; for (size_t i = 1; i < inputs_type.size(); ++i) { if (inputs_type[i] != type) { - LOG(FATAL) << "The input types must be the same. But received: the i-th(" - << i << ") input type is " << inputs_type[i] - << " and the first input type is " << type; + std::stringstream ss; + ss << "The input types must be the same. But received: the i-th(" << i + << ") input type is " << inputs_type[i] + << " and the first input type is " << type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } std::vector res{type}; @@ -530,8 +544,7 @@ std::shared_ptr StrategyForFillConstantSymbolic( CHECK(!args.empty()) << "The input argument of fill_constant compute " "is empty! 
Please check."; bool force_cpu = false; - CHECK(attrs.attr_store.count("shape")); - auto shape = absl::get>(attrs.attr_store.at("shape")); + auto shape = output_shapes[0]; CHECK(attrs.attr_store.count("value")); auto value = GetScalarExpr(attrs.attr_store.at("value")); CHECK(attrs.attr_store.count("force_cpu")); @@ -652,7 +665,9 @@ std::shared_ptr StrategyForAssignValue( } EXPAND_ATTR_TYPE(EXPAND_VALUE_TO_TENSOR) else { // NOLINT - LOG(FATAL) << "Assign value not support the type " << out_type[0]; + std::stringstream ss; + ss << "Assign value not support the type " << out_type[0]; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } #undef EXPAND_VALUE_TO_TENSOR @@ -693,7 +708,8 @@ std::vector InferShapeForAssignValue( } EXPAND_ATTR_TYPE(EXPAND_ATTR_TO_GET_SHAPE) else { // NOLINT - LOG(FATAL) << "assign_value not support the type!"; + PADDLE_THROW( + phi::errors::InvalidArgument("assign_value not support the type!")); } #undef EXPAND_ATTR_TO_GET_SHAPE @@ -734,7 +750,8 @@ std::vector InferDtypeForAssignValue( } EXPAND_ATTR_TYPE(EXPAND_ATTR_TO_GET_DTYPE) else { // NOLINT - LOG(FATAL) << "assign_value not support the type!"; + PADDLE_THROW( + phi::errors::InvalidArgument("assign_value not support the type!")); } #undef EXPAND_ATTR_TO_GET_DTYPE } @@ -1014,16 +1031,19 @@ std::shared_ptr StrategyForReshapeSymbolic( Expr A = pack_args[0]; CHECK(A.as_tensor()); CHECK(!output_shapes.empty()); - auto attr_store = attrs.attr_store; - CHECK(attr_store.count("shape")) << "find no attr of shape"; auto tensor_A = A.as_tensor_ref(); - auto stages = CreateStages({tensor_A}); + auto stages = CreateStages({}); VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") << ", output_shapes: " << utils::Join(output_shapes[0], ", "); - CHECK_EQ(pack_args.size(), 2); - CHECK(pack_args[1].is_string()); - std::string tensor_name = pack_args[1].operator std::string(); + std::string tensor_name; + if (pack_args.size() == 4) { + CHECK(pack_args[2].is_string()); + tensor_name = pack_args[2].operator std::string(); + } else { + CHECK(pack_args[1].is_string()); + tensor_name = pack_args[1].operator std::string(); + } ir::Tensor out = pe::Reshape(tensor_A, output_shapes[0], tensor_name); std::vector res; @@ -1078,9 +1098,12 @@ std::vector> InferShapeForReshape( } else if (output_shape[i] == -1 && flag_index == -1) { flag_index = i; } else if (output_shape[i] == -1) { - LOG(FATAL) << "More than one -1 in output_shape of op reshape."; + PADDLE_THROW(phi::errors::InvalidArgument( + "More than one -1 in output_shape of op reshape.")); } else { - LOG(FATAL) << "Unsupported output_shape " << output_shape[i]; + std::stringstream ss; + ss << "Unsupported output_shape " << output_shape[i]; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } if (flag_index >= 0) output_shape[flag_index] = tensor_size; @@ -1128,6 +1151,170 @@ std::shared_ptr StrategyForCast( return strategy; } +std::shared_ptr StrategyForCastSymbolic( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute cast_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Cast compute is empty! 
Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 1U) + << "at least 1 input tensors for Cast compute\n"; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto stages = CreateStages({tensor_A}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + CHECK_EQ(pack_args.size(), 2U); + std::string tensor_name = pack_args[1].operator std::string(); + ir::Tensor out = pe::Cast(tensor_A, out_type[0], tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) + << "Output type of Cast is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(cast_compute, lang::PackedFunc(), "strategy.cast.x86", 1); + return strategy; +} + +std::shared_ptr StrategyForYieldStore( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute cast_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Cast compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 1U) + << "at least 1 input tensors for Cast compute\n"; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto stages = CreateStages({tensor_A}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + CHECK_EQ(pack_args.size(), 2U); + std::string tensor_name = pack_args[1].operator std::string(); + ir::Tensor out = pe::Store(tensor_A, tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) + << "Output type of Cast is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(cast_compute, + GetElementwiseScheduleFunc(output_shapes, target), + "strategy.reshape.x86", + 1); + return strategy; +} + +std::shared_ptr StrategyForYieldStoreSymbolic( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute cast_compute( + [=](lang::Args args, lang::RetValue *ret) { + CHECK(!args.empty()) + << "The input arguments of Cast compute is empty! Please check.\n"; + CINNValuePack pack_args = args[0]; + CHECK_GE(pack_args.size(), 1U) + << "at least 1 input tensors for Cast compute\n"; + Expr A = pack_args[0]; + CHECK(A.as_tensor()); + CHECK(!output_shapes.empty()); + auto tensor_A = A.as_tensor_ref(); + auto stages = CreateStages({tensor_A}); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + CHECK_EQ(pack_args.size(), 2U); + std::string tensor_name = pack_args[1].operator std::string(); + ir::Tensor out = pe::Store(tensor_A, tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + CHECK(!out_type.empty()) + << "Output type of Cast is empty! 
Please check.\n"; + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl(cast_compute, lang::PackedFunc(), "strategy.store.x86", 1); + return strategy; +} + +std::shared_ptr StrategyForGenerateShapeSymbolic( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute generate_shape_compute( + [=](lang::Args args, lang::RetValue *ret) { + PADDLE_ENFORCE(!args.empty(), + ::common::errors::InvalidArgument( + "Invalid argument. The input arguments of " + "generate_shape compute is empty! Please check.")); + CINNValuePack pack_args = args[0]; + PADDLE_ENFORCE_GE(pack_args->size(), + 1U, + ::common::errors::InvalidArgument( + "At least 1 input tensors for generate_shape " + "compute, but now get %d.", + pack_args->size())); + auto stages = CreateStages({}); + + std::string tensor_name = pack_args.back().operator std::string(); + ir::Tensor out(ir::_Tensor_::Make(/*name=*/tensor_name, + /*dtype=*/common::type_of(), + /*shape=*/ + { + Expr(1), + }, + /*domain=*/ + { + Expr(1), + })); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + PADDLE_ENFORCE(!out_type.empty(), + ::common::errors::InvalidArgument( + "Invalid argument. The output type of " + "generate_shape is empty! Please check.")); + + res.push_back(CINNValue(stages)); + *ret = CINNValuePack{res}; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl( + generate_shape_compute, lang::PackedFunc(), "strategy.store.x86", 1); + return strategy; +} + std::vector InferDtypeForCast(const std::vector &inputs_type, const framework::AttrMapType &attrs) { CHECK(attrs.count("dtype")); @@ -1206,21 +1393,81 @@ std::vector InferDtypeForLogicalNot(const std::vector &inputs_type, return {cinn::common::Bool()}; } +std::shared_ptr StrategyForTril( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + framework::CINNCompute tril_compute([=](lang::Args args, + lang::RetValue *ret) { + PADDLE_ENFORCE_EQ(args.size(), + size_t(1), + phi::errors::InvalidArgument( + "The input arguments of tril compute is empty")); + CINNValuePack pack_args = args[0]; + PADDLE_ENFORCE_GE( + pack_args.size(), + size_t(1), + phi::errors::InvalidArgument("only 1 input tensor for tril compute")); + Expr A = pack_args[0]; + PADDLE_ENFORCE_NOT_NULL( + A.as_tensor(), + phi::errors::InvalidArgument( + "first input argument in tril should be tensor")); + int diagonal = absl::get(attrs.attr_store.at("diagonal")); + auto tensor_A = A.as_tensor_ref(); + auto stages = CreateStages({tensor_A}); + + PADDLE_ENFORCE_NE(output_shapes.size(), + size_t(0), + phi::errors::InvalidArgument( + "output shape of tril should not be empty.")); + VLOG(3) << "A shape: " << utils::Join(tensor_A->shape, ", ") + << ", output_shapes: " << utils::Join(output_shapes[0], ", "); + + PADDLE_ENFORCE_EQ(pack_args.size(), + size_t(2), + phi::errors::InvalidArgument( + "args of tril compute should be equal to 2")); + PADDLE_ENFORCE_EQ(pack_args[1].is_string(), + true, + phi::errors::InvalidArgument( + "The second argument of tril should be string")); + std::string tensor_name = pack_args[1].operator std::string(); + + ir::Tensor out = + pe::Tril(tensor_A, diagonal, output_shapes[0], tensor_name); + std::vector res; + stages->InsertLazily(out); + res.push_back(CINNValue(out)); + 
CHECK(!out_type.empty()) + << "Output type of Reshape is empty! Please check.\n"; + res.push_back(CINNValue(stages)); + + *ret = CINNValuePack{res}; + }); + auto strategy = std::make_shared(); + strategy->AddImpl(tril_compute, lang::PackedFunc(), "strategy.tril.x86", 1); + + return strategy; +} + } // namespace op } // namespace hlir } // namespace cinn CINN_REGISTER_HELPER(elementwise_ops) { -#define CINN_REGISTER_UNARY(op__, op_stragegy__) \ +#define CINN_REGISTER_UNARY(op__, op_strategy__) \ CINN_REGISTER_OP(op__) \ .describe(#op__ " function") \ .set_num_inputs(1) \ .set_num_outputs(1) \ .set_attr( \ - "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + "CINNStrategy", cinn::hlir::op::StrategyFor##op_strategy__) \ .set_attr( \ "CINNStrategySymbolic", \ - cinn::hlir::op::StrategyFor##op_stragegy__##Symbolic) \ + cinn::hlir::op::StrategyFor##op_strategy__##Symbolic) \ .set_attr("infershape", \ MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \ .set_attr("inferdtype", \ @@ -1270,13 +1517,13 @@ CINN_REGISTER_HELPER(elementwise_ops) { #undef CINN_REGISTER_UNARY -#define CINN_REGISTER_COMPARE(op__, op_stragegy__) \ +#define CINN_REGISTER_COMPARE(op__, op_strategy__) \ CINN_REGISTER_OP(op__) \ .describe(#op__ " function") \ .set_num_inputs(1) \ .set_num_outputs(1) \ .set_attr( \ - "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__) \ + "CINNStrategy", cinn::hlir::op::StrategyFor##op_strategy__) \ .set_attr("infershape", \ MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) \ .set_attr("inferdtype", \ @@ -1441,6 +1688,25 @@ CINN_REGISTER_HELPER(elementwise_ops) { .set_num_outputs(1) .set_attr( "CINNStrategy", cinn::hlir::op::StrategyForCast) + .set_attr( + "CINNStrategySymbolic", cinn::hlir::op::StrategyForCastSymbolic) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForCast)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) + .set_support_level(4); + + CINN_REGISTER_OP(yield_store) + .describe("This operator is used to cast input tensor's type to target.") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr( + "CINNStrategy", cinn::hlir::op::StrategyForYieldStore) + .set_attr( + "CINNStrategySymbolic", cinn::hlir::op::StrategyForYieldStoreSymbolic) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForCast)) @@ -1450,6 +1716,22 @@ CINN_REGISTER_HELPER(elementwise_ops) { "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); + CINN_REGISTER_OP(generate_shape) + .describe("This operator is used to cast input tensor's type to target.") + .set_num_inputs(1) + .set_num_outputs(1) + .set_attr( + "CINNStrategySymbolic", + cinn::hlir::op::StrategyForGenerateShapeSymbolic) + .set_attr("infershape", + MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) + .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForCast)) + .set_attr("inferlayout", + MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise)) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) + .set_support_level(4); + CINN_REGISTER_OP(arange) .describe("Returns evenly spaced values within a given interval.") .set_num_inputs(0) @@ -1481,6 +1763,8 @@ CINN_REGISTER_HELPER(elementwise_ops) { .set_num_outputs(1) 
.set_attr( "CINNStrategy", cinn::hlir::op::StrategyForLogicalNot) + .set_attr( + "CINNStrategySymbolic", cinn::hlir::op::StrategyForLogicalNotSymbolic) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForElementwise)) .set_attr("inferdtype", @@ -1491,5 +1775,16 @@ CINN_REGISTER_HELPER(elementwise_ops) { "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise) .set_support_level(4); + CINN_REGISTER_OP(tril) + .describe( + "Filters out the upper portion of an input tensor on one side of a " + "diagonal") + .set_num_inputs(2) + .set_num_outputs(1) + .set_attr( + "CINNStrategySymbolic", cinn::hlir::op::StrategyForTril) + .set_attr( + "OpPattern", cinn::hlir::framework::OpPatternKind::kElementWise); + return true; } diff --git a/paddle/cinn/hlir/op/nn.cc b/paddle/cinn/hlir/op/nn.cc index 60cbc1c89e222..2b1ce342e0810 100644 --- a/paddle/cinn/hlir/op/nn.cc +++ b/paddle/cinn/hlir/op/nn.cc @@ -305,7 +305,8 @@ std::shared_ptr StrategyForConv2d( dilation[1], tensor_name); } else { - LOG(FATAL) << "Only support NCHW and NHWC data layout\n"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Only support NCHW and NHWC data layout\n")); } auto stages = CreateStages({A.as_tensor_ref(), B.as_tensor_ref()}); @@ -368,7 +369,9 @@ std::shared_ptr StrategyForConv2d( } else if (target.arch == Target::Arch::X86) { CINN_NOT_IMPLEMENTED } - LOG(FATAL) << "This target [" << target << "] is not supported yet."; + std::stringstream ss; + ss << "This target [" << target << "] is not supported yet."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); }); auto strategy = std::make_shared(); @@ -713,8 +716,8 @@ std::shared_ptr StrategyForConv2dNCHWc( strategy->AddImpl( conv2d_compute, conv2d_schedule, "strategy.conv2d_NCHWc.x86", 1); } else { - LOG(FATAL) - << "conv2d_NCHWc op with dtype != float32 is not implemented yet!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "conv2d_NCHWc op with dtype != float32 is not implemented yet!")); } return strategy; } @@ -894,7 +897,8 @@ std::shared_ptr StrategyForDepthwiseConv2d( stride[1], tensor_name); } else { - LOG(FATAL) << "Only support NCHW and NHWC data layout\n"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Only support NCHW and NHWC data layout\n")); } auto stages = CreateStages({A.as_tensor_ref(), B.as_tensor_ref()}); @@ -1008,7 +1012,8 @@ std::vector InferShapeForDepthwiseConv2d( out_shape_w, inputs_shape[1][1] * inputs_shape[0][3]}}; } else { - LOG(FATAL) << "Only support NCHW and NHWC data layout\n"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Only support NCHW and NHWC data layout\n")); } return res; } @@ -1093,7 +1098,8 @@ std::shared_ptr StrategyForBatchNorm( "strategy.batchnorm.x86", 1); } else { - LOG(FATAL) << "BatchNorm op with dtype != float32 is not implemented yet!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "BatchNorm op with dtype != float32 is not implemented yet!")); } return strategy; } @@ -1303,7 +1309,9 @@ std::vector> InferShapeForPool1d( } else if (data_format == "NWC") { width_axis = 1; } else { - LOG(FATAL) << "unsupported data_format: " << data_format << std::endl; + std::stringstream ss; + ss << "unsupported data_format: " << data_format << std::endl; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } if (ceil_mode) { @@ -1406,8 +1414,8 @@ std::shared_ptr StrategyForPool2d( width_index = 3; data_format = "NCHW"; } else { - LOG(FATAL) - << "Only support 'NCHW' or 'NHWC' or 'AnyLayout' data_format.\n"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Only support 'NCHW' or 'NHWC' or 'AnyLayout' 
data_format.\n")); } kernel_size = {A_tensor->shape[height_index].as_int32(), A_tensor->shape[width_index].as_int32()}; @@ -2206,7 +2214,8 @@ std::vector InferShapeForBatchNormTrain( if (attrs.find("data_layout") != attrs.end()) { data_layout = absl::get(attrs.at("data_layout")); } else { - LOG(FATAL) << "data_layout is not found, please check!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "data_layout is not found, please check!")); } CHECK_EQ(inputs_shape[0].size(), 4) << "x dimension size is not required!"; @@ -2237,7 +2246,9 @@ std::vector InferShapeForBatchNormTrain( CHECK_EQ(inputs_shape[0][3], inputs_shape[4][0]) << "x and moving_variance dimension size is not equal!"; } else { - LOG(FATAL) << "data_layout " << data_layout << " is not support!"; + std::stringstream ss; + ss << "data_layout " << data_layout << " is not support!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return {inputs_shape[0], @@ -2271,8 +2282,9 @@ std::shared_ptr StrategyForGradOp( const std::vector &out_type, const std::vector> &output_shapes, const Target &target) { - LOG(FATAL) << "Gradient operator will be decomposed into several primitive " - "operators. Please Use Decomposer Program Pass."; + PADDLE_THROW(phi::errors::Fatal( + "Gradient operator will be decomposed into several primitive " + "operators. Please Use Decomposer Program Pass.")); } // batch norm grad @@ -2285,7 +2297,8 @@ std::vector InferShapeForBatchNormGrad( if (attrs.find("data_layout") != attrs.end()) { data_layout = absl::get(attrs.at("data_layout")); } else { - LOG(FATAL) << "data_layout is not found, please check!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "data_layout is not found, please check!")); } CHECK_EQ(inputs_shape[0].size(), 4) << "dy dimension size is not required!"; @@ -2313,7 +2326,9 @@ std::vector InferShapeForBatchNormGrad( CHECK_EQ(inputs_shape[0][3], inputs_shape[4][0]) << "dy and moving_variance dimension size is not equal!"; } else { - LOG(FATAL) << "data_layout " << data_layout << " is not support!"; + std::stringstream ss; + ss << "data_layout " << data_layout << " is not support!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return {inputs_shape[0], inputs_shape[2], inputs_shape[2]}; diff --git a/paddle/cinn/hlir/op/op_util.cc b/paddle/cinn/hlir/op/op_util.cc index 6cad9f4cb75f1..b0976f22c38cb 100644 --- a/paddle/cinn/hlir/op/op_util.cc +++ b/paddle/cinn/hlir/op/op_util.cc @@ -100,8 +100,9 @@ std::string GetExternFuncName(const cinn::common::Target& target, } else if (target.arch == cinn::common::Target::Arch::X86) { func_proto_name.append("host_"); } else { - LOG(FATAL) << func_name - << " only supports X86 and NVGPU! Please Check.\n"; + std::stringstream ss; + ss << func_name << " only supports X86 and NVGPU! Please Check.\n"; + PADDLE_THROW(phi::errors::Fatal(ss.str())); } } func_proto_name.append(func_name); @@ -138,11 +139,22 @@ std::string GetExternFuncName(const cinn::common::Target& target, } else if (type.is_uint(64)) { func_proto_name.append("uint64"); } else { - LOG(FATAL) << "Can not find type: " << type - << " for extern function. Please Check.\n"; + std::stringstream ss; + ss << "Can not find type: " << type + << " for extern function. 
Please Check.\n"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return func_proto_name; } +std::vector ToCinnExprs(const std::vector& args) { + std::vector exprs; + std::transform(args.begin(), + args.end(), + std::back_inserter(exprs), + [](const ir::Dim& arg) { return arg->dim_expr; }); + return exprs; +} + } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/op/op_util.h b/paddle/cinn/hlir/op/op_util.h index a0521e26f1b72..ee5ec1cad4531 100644 --- a/paddle/cinn/hlir/op/op_util.h +++ b/paddle/cinn/hlir/op/op_util.h @@ -20,6 +20,7 @@ #include "paddle/cinn/common/target.h" #include "paddle/cinn/hlir/framework/node.h" +#include "paddle/cinn/ir/dim.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/lang/packed_func.h" #include "paddle/cinn/utils/type_defs.h" @@ -60,6 +61,8 @@ std::vector ToCinnExprs(const std::vector &args) { return exprs; } +std::vector ToCinnExprs(const std::vector &args); + template std::vector ToPodVector(const std::vector &args) { if (args.empty()) { @@ -125,7 +128,9 @@ std::vector ToPodVector(const std::vector &args) { shape_v.push_back(static_cast(e.as_double())); } } else { - LOG(FATAL) << "Not support " << type; + std::stringstream ss; + ss << "Not support " << type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return shape_v; } diff --git a/paddle/cinn/hlir/op/reduction.cc b/paddle/cinn/hlir/op/reduction.cc index a8fda43e0ceb5..d5a378dc809e6 100644 --- a/paddle/cinn/hlir/op/reduction.cc +++ b/paddle/cinn/hlir/op/reduction.cc @@ -88,7 +88,7 @@ std::shared_ptr StrategyForReduce( CHECK_NE(reduce_axes[idx - 1], reduce_axes[idx]); } } else { - LOG(FATAL) << "reduce dimension is not set!"; + PADDLE_THROW(phi::errors::InvalidArgument("reduce dimension is not set!")); } bool keep_dim = false; @@ -270,7 +270,7 @@ std::shared_ptr StrategyForReduce( CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else { - LOG(FATAL) << "Unkown Reduce Type!"; + PADDLE_THROW(phi::errors::InvalidArgument("Unkown Reduce Type!")); } } else { if (arg_pack.size() == 2) { @@ -304,7 +304,7 @@ std::shared_ptr StrategyForReduce( CINNValue(ir_sch.GetModule().GetExprs().at(0))}; *ret = CINNValuePack{res}; } else { - LOG(FATAL) << "Unkown Reduce Type!"; + PADDLE_THROW(phi::errors::InvalidArgument("Unkown Reduce Type!")); } } } else { @@ -352,7 +352,7 @@ std::shared_ptr StrategyForReduceSymbolic( CHECK_NE(reduce_axes[idx - 1], reduce_axes[idx]); } } else { - LOG(FATAL) << "reduce dimension is not set!"; + PADDLE_THROW(phi::errors::InvalidArgument("reduce dimension is not set!")); } bool keep_dim = false; diff --git a/paddle/cinn/hlir/op/transform.cc b/paddle/cinn/hlir/op/transform.cc index 113c2b2f1cd82..21754487e7846 100644 --- a/paddle/cinn/hlir/op/transform.cc +++ b/paddle/cinn/hlir/op/transform.cc @@ -27,6 +27,9 @@ #include "paddle/cinn/hlir/pe/schedule.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/utils/string.h" +#include "paddle/common/enforce.h" +#include "paddle/common/errors.h" +#include "paddle/phi/core/enforce.h" namespace cinn { namespace hlir { @@ -286,9 +289,9 @@ std::vector> InferShapeForSplit( if (attrs.find("num_or_sections") != attrs.end()) { sections = absl::get>(attrs.at("num_or_sections")); } else { - LOG(FATAL) - << "The Split op doesn't find [num_or_sections] attribute! It it " - "a mandatory attribute ! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The Split op doesn't find [num_or_sections] attribute! It it " + "a mandatory attribute ! 
Please check.")); } if (inputs_shape.empty()) { @@ -337,11 +340,13 @@ std::vector> InferShapeForSplit( neg_index = i; } else { if (sections[i] == 0) { - LOG(FATAL) << "The attribute 'num_or_sections' should not has 0 ! " - "Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The attribute 'num_or_sections' should not has 0 ! " + "Please check.")); } else { - LOG(FATAL) << "The attribute 'num_or_sections' can only have at most " - "one '-1' ! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The attribute 'num_or_sections' can only have at most " + "one '-1' ! Please check.")); } } } @@ -373,9 +378,9 @@ std::vector InferDtypeForSplit(const std::vector &inputs_type, if (attrs.find("num_or_sections") != attrs.end()) { sections = absl::get>(attrs.at("num_or_sections")); } else { - LOG(FATAL) - << "The Split op doesn't find [num_or_sections] attribute! It it " - "a mandatory attribute ! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The Split op doesn't find [num_or_sections] attribute! It it " + "a mandatory attribute ! Please check.")); } int output_size = sections.size(); @@ -399,9 +404,9 @@ std::vector> InferLayoutForSplit( sections = absl::get>(attrs.attr_store.at("num_or_sections")); } else { - LOG(FATAL) - << "The Split op doesn't find [num_or_sections] attribute! It it " - "a mandatory attribute ! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The Split op doesn't find [num_or_sections] attribute! It it " + "a mandatory attribute ! Please check.")); } int output_size = sections.size(); @@ -923,7 +928,8 @@ std::shared_ptr StrategyForReverse( for (auto &e : axis) { if (e >= static_cast(output_shapes[0].size()) || e < -1 * static_cast(output_shapes[0].size())) { - LOG(FATAL) << "axis is not in [0, n_dim), Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "axis is not in [0, n_dim), Please check.")); } if (e < 0) { e += output_shapes[0].size(); @@ -970,7 +976,8 @@ std::vector InferShapeForReverse( for (auto &e : axis) { if (e >= static_cast(inputs_shape[0].size()) || e < -1 * static_cast(inputs_shape[0].size())) { - LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "axis is not in [-n_dim, n_dim), Please check.")); } if (e < 0) { e += inputs_shape[0].size(); @@ -990,7 +997,8 @@ std::vector> InferLayoutForReverse( for (auto &e : axis) { if (e >= static_cast(input_shapes[0].size()) || e < -1 * static_cast(input_shapes[0].size())) { - LOG(FATAL) << "axis is not in [-n_dim, n_dim), Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "axis is not in [-n_dim, n_dim), Please check.")); } } } @@ -1043,7 +1051,8 @@ std::shared_ptr StrategyForTranspose( << "output shape is not equal! Please check!\n"; } } else { - LOG(FATAL) << "axis is not be set! Please check."; + PADDLE_THROW( + phi::errors::InvalidArgument("axis is not be set! Please check.")); } framework::CINNCompute transpose_compute([=](lang::Args args, @@ -1072,6 +1081,84 @@ std::shared_ptr StrategyForTranspose( return strategy; } +std::shared_ptr StrategyForTransposeSymbolic( + const framework::NodeAttr &attrs, + const std::vector &inputs, + const std::vector &out_type, + const std::vector> &output_shapes, + const Target &target) { + // check output shape + PADDLE_ENFORCE_EQ(output_shapes.empty(), + false, + ::common::errors::InvalidArgument( + "Output shape is empty! 
Please check.\n")); + PADDLE_ENFORCE_EQ(output_shapes[0].empty(), + false, + ::common::errors::InvalidArgument( + "Output shape is empty! Please check.\n")); + + std::vector axis; + auto input_shape = inputs[0]->shape; + if (attrs.attr_store.find("axis") != attrs.attr_store.end()) { + axis = absl::get>(attrs.attr_store.at("axis")); + PADDLE_ENFORCE_EQ(axis.size(), + output_shapes[0].size(), + ::common::errors::InvalidArgument( + "axis size is not equal output_shapes size! Please " + "check setting.\n")); + // check axis and shape + for (int idx = 0; idx < axis.size(); ++idx) { + PADDLE_ENFORCE(axis[idx] >= 0 && axis[idx] < axis.size(), + ::common::errors::InvalidArgument( + "axis is not in the tensor shape.")); + for (int idy = idx + 1; idy < axis.size(); ++idy) { + PADDLE_ENFORCE_NE(axis[idx], + axis[idy], + ::common::errors::InvalidArgument( + "The same axis parameter exists!")); + } + } + } else { + PADDLE_THROW( + ::common::errors::InvalidArgument("axis is not be set! Please check.")); + } + + framework::CINNCompute transpose_compute([=](lang::Args args, + lang::RetValue *ret) { + PADDLE_ENFORCE( + !args.empty(), + ::common::errors::InvalidArgument("The input argument of transpose " + "compute is empty! Please check.\n")); + CINNValuePack input_args = args[0]; + PADDLE_ENFORCE(!input_args.empty(), + ::common::errors::InvalidArgument( + "at least one input tensor for transpose compute.\n")); + Expr A = input_args[0]; + PADDLE_ENFORCE( + A.as_tensor(), + ::common::errors::InvalidArgument("The input argument is not Tensor.")); + PADDLE_ENFORCE_EQ(input_args.size(), + 2, + ::common::errors::InvalidArgument( + "The input args size must be equal to 2.")); + PADDLE_ENFORCE( + input_args[1].is_string(), + ::common::errors::InvalidArgument( + "The second argument must be of type string and is the name " + "of the output tensor.")); + std::string tensor_name = input_args[1].operator std::string(); + + auto out = pe::Transpose(A.as_tensor_ref(), axis, tensor_name); + auto stages = CreateStages({out}); + *ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}}; + }); + + auto strategy = std::make_shared(); + strategy->AddImpl( + transpose_compute, lang::PackedFunc(), "strategy.transpose.x86", 1); + return strategy; +} + std::vector InferShapeForTranspose( const std::vector &inputs_shape, const framework::AttrMapType &attrs) { @@ -1092,7 +1179,8 @@ std::vector InferShapeForTranspose( } result.push_back(output_shape); } else { - LOG(FATAL) << "axis is not be set! Please check."; + PADDLE_THROW( + phi::errors::InvalidArgument("axis is not be set! Please check.")); } return result; } @@ -1117,7 +1205,8 @@ std::vector> InferLayoutForTranspose( } } } else { - LOG(FATAL) << "axis is not be set! Please check."; + PADDLE_THROW( + phi::errors::InvalidArgument("axis is not be set! 
Please check.")); } std::vector new_input_layouts = input_layouts; @@ -2010,6 +2099,8 @@ CINN_REGISTER_HELPER(transform_ops) { .set_num_outputs(1) .set_attr( "CINNStrategy", cinn::hlir::op::StrategyForTranspose) + .set_attr( + "CINNStrategySymbolic", cinn::hlir::op::StrategyForTransposeSymbolic) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForTranspose)) .set_attr("inferdtype", diff --git a/paddle/cinn/hlir/pass/alterlayout.cc b/paddle/cinn/hlir/pass/alterlayout.cc index 4e7df28e7994a..8ca3475c2d7e3 100644 --- a/paddle/cinn/hlir/pass/alterlayout.cc +++ b/paddle/cinn/hlir/pass/alterlayout.cc @@ -139,7 +139,7 @@ std::vector UpdateInferInfos( } void AlterLayoutPass(Graph* graph) { - // alterlayout only in X86 for it's specific layout requirements + // alter layout only in X86 for it's specific layout requirements if (graph->target_.arch == Target::Arch::X86) { auto store_nodes = std::get<0>(graph->topological_order()); auto& shape_dict = graph->GetMutableAttrs< @@ -261,9 +261,10 @@ void AlterLayoutPass(Graph* graph) { } else if (input_shape.size() == 5) { ic = input_shape[1] * input_shape[4]; } else { - LOG(FATAL) - << "conv2d's input shape should be 4D/5D. Wrong input shape: " - << utils::Join(input_shape, ", "); + std::stringstream ss; + ss << "conv2d's input shape should be 4D/5D. Wrong input shape: " + << utils::Join(input_shape, ", "); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } if (weight_shape.size() == 4) { @@ -273,9 +274,10 @@ void AlterLayoutPass(Graph* graph) { oc = weight_shape[0] * weight_shape[5]; fc = weight_shape[1] * weight_shape[4]; } else { - LOG(FATAL) - << "conv2d's weight shape should be 4D/6D. Wrong weight shape: " - << utils::Join(weight_shape, ", "); + std::stringstream ss; + ss << "conv2d's weight shape should be 4D/6D. 
Wrong weight shape: " + << utils::Join(weight_shape, ", "); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } VLOG(3) << "oc: " << oc; VLOG(3) << "ic: " << ic; @@ -323,7 +325,7 @@ void AlterLayoutPass(Graph* graph) { src_input_layout, dst_input_layout, cinn::common::UniqName(node->op()->name + - "_input_layout_tranform")); + "_input_layout_transform")); UpdateInferInfos(input_trans_node, {input_shape}, {input_type}, @@ -371,7 +373,7 @@ void AlterLayoutPass(Graph* graph) { src_kernel_layout, dst_kernel_layout, cinn::common::UniqName(node->op()->name + - "_weight_layout_tranform")); + "_weight_layout_transform")); UpdateInferInfos(weight_trans_node, {weight_shape}, {weight_type}, @@ -512,7 +514,8 @@ void AlterLayoutPass(Graph* graph) { layout_dict[source->id()] = src_layout; auto input_data = source->safe_as(); CHECK(input_data); - VLOG(3) << source->id() << " do layout_tranform from C to NCHW"; + VLOG(3) << source->id() + << " do layout_transform from C to NCHW"; std::string op_type = "broadcast_to"; auto trans_node = new Node( Operator::Get(op_type), @@ -543,7 +546,7 @@ void AlterLayoutPass(Graph* graph) { NodeData* new_output_data; Node* new_trans_node; VLOG(3) << new_input_data->id() - << " do layout_tranform from NCHW to NCHWxc"; + << " do layout_transform from NCHW to NCHWxc"; std::tie(new_trans_node, new_output_data) = InsertLayoutTransformNodeAfter( graph, @@ -553,7 +556,7 @@ void AlterLayoutPass(Graph* graph) { new_src_layout, new_input_layouts[i], cinn::common::UniqName(new_input_data->id() + - "_layout_tranform")); + "_layout_transform")); UpdateInferInfos(new_trans_node, {shape_dict[new_input_data->id()]}, {input_types[i]}, @@ -577,7 +580,7 @@ void AlterLayoutPass(Graph* graph) { NodeData* output_data; Node* trans_node; VLOG(3) << source->id() - << " do layout_tranform from NCHW to NCHWxc"; + << " do layout_transform from NCHW to NCHWxc"; std::tie(trans_node, output_data) = InsertLayoutTransformNodeAfter( graph, @@ -587,7 +590,7 @@ void AlterLayoutPass(Graph* graph) { src_layout, new_input_layouts[i], cinn::common::UniqName(source->id() + - "_layout_tranform")); + "_layout_transform")); UpdateInferInfos(trans_node, {input_shapes[i]}, {input_types[i]}, @@ -602,7 +605,7 @@ void AlterLayoutPass(Graph* graph) { } else if (input_shape_size == 5 && new_input_layouts[i].size() == 4) { // NCHWxc -> NCHW - // insert layout tranfrom + // insert layout transform auto source = inlinks[i]->source(); auto src_layout = input_layouts[i]; layout_dict[source->id()] = src_layout; @@ -611,7 +614,7 @@ void AlterLayoutPass(Graph* graph) { NodeData* output_data; Node* trans_node; VLOG(3) << source->id() - << " do layout_tranform from NCHWxc to NCHW"; + << " do layout_transform from NCHWxc to NCHW"; std::tie(trans_node, output_data) = InsertLayoutTransformNodeAfter( graph, @@ -621,7 +624,7 @@ void AlterLayoutPass(Graph* graph) { src_layout, new_input_layouts[i], cinn::common::UniqName(source->id() + - "_layout_tranform")); + "_layout_transform")); UpdateInferInfos(trans_node, {input_shapes[i]}, {input_types[i]}, @@ -709,7 +712,7 @@ void AlterLayoutPass(Graph* graph) { src_layout, dst_layout, cinn::common::UniqName(node->op()->name + - "_final_layout_tranform")); + "_final_layout_transform")); shape_dict[temp_out->id()] = shape; type_dict[temp_out->id()] = type; layout_dict[temp_out->id()] = src_layout; diff --git a/paddle/cinn/hlir/pass/constant_folding_pass_test.cc b/paddle/cinn/hlir/pass/constant_folding_pass_test.cc index 0cf95ea0a12e5..a30ea35953629 100644 --- 
a/paddle/cinn/hlir/pass/constant_folding_pass_test.cc +++ b/paddle/cinn/hlir/pass/constant_folding_pass_test.cc @@ -369,7 +369,7 @@ TEST(Constant_Folding, fold_expand_dims_to_fill_constant_2) { TEST(Constant_Folding, fold_expand_dims_to_fill_constant_3) { NetBuilder net_builder("fold_expand_dims_to_fill_constant_3"); - // create model, ExpandDims axes have nagetive value + // create model, ExpandDims axes have negative value int h = 32, w = 32; auto A = net_builder.FillConstant({h, w}, 1.0f, "A"); auto B = net_builder.ExpandDims(A, {1, -1}); diff --git a/paddle/cinn/hlir/pass/dense_merge_pass.cc b/paddle/cinn/hlir/pass/dense_merge_pass.cc index 82341cb8469bf..a726aa1a36c1a 100644 --- a/paddle/cinn/hlir/pass/dense_merge_pass.cc +++ b/paddle/cinn/hlir/pass/dense_merge_pass.cc @@ -26,7 +26,7 @@ using framework::Node; using framework::NodeAttr; // Dense Merge Pass: merge those gemm which has same var as input into a batched -// cubals call op. A * B, A * C, A * D,... after A * [B, C, D,...] Using cublas +// cublas call op. A * B, A * C, A * D,... after A * [B, C, D,...] Using cublas // batched gemm can avoid do concat and slice. class DenseMergePassHelper : public FusionHelperBase { diff --git a/paddle/cinn/hlir/pass/fusion_merge_pass.cc b/paddle/cinn/hlir/pass/fusion_merge_pass.cc index eb251fca8608e..fd023662f9050 100644 --- a/paddle/cinn/hlir/pass/fusion_merge_pass.cc +++ b/paddle/cinn/hlir/pass/fusion_merge_pass.cc @@ -55,7 +55,7 @@ class FusionMergePassHelper : public FusionHelperBase { } GroupList operator()() { - // run fusion merge untill no update. + // run fusion merge until no update. DoFusionMerge(); for (auto& group : fusion_groups_) { VLOG(3) << "Fusion Group -> " << group->group_id; @@ -170,7 +170,7 @@ class FusionMergePassHelper : public FusionHelperBase { } } if (is_ring) { - LOG(FATAL) << "Exists Ring, Please Check!"; + PADDLE_THROW(phi::errors::Fatal("Exists Ring, Please Check!")); } } } @@ -199,13 +199,13 @@ class FusionMergePassHelper : public FusionHelperBase { // check dependency if (IsDependencySimplify(producer, candidate, candidates)) { VLOG(4) << "IsDependencySimplify, Can't fuse " << candidate->group_id - << ", As it depency others!"; + << ", As it dependency others!"; continue; } if (IsDependency(producer, candidate, candidates)) { VLOG(4) << "IsDependency, Can't fuse " << candidate->group_id - << ", As it depency others!"; + << ", As it dependency others!"; continue; } @@ -414,7 +414,7 @@ class FusionMergePassHelper : public FusionHelperBase { std::unordered_set fuse_consumers_unsafe; std::unordered_set fuse_consumers; for (const auto& consumer : consumers) { - VLOG(4) << "Check consuemr " << consumer->group_id + VLOG(4) << "Check consumer " << consumer->group_id << " can fuse to producer " << producer->group_id; // if can't fuse if (!relation.vertical_relation.count(consumer->op_pattern_kind)) { @@ -698,7 +698,7 @@ class FusionMergePassHelper : public FusionHelperBase { sub_group->nodes.insert(sub_group->nodes.begin(), producer->CollectNodes()[0]); sub_group->nodes_set.insert(producer->CollectNodes()[0]); - // remove depency. + // remove dependency. 
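
For context on the recurring change in these hunks: streamed LOG(FATAL) sites are rewritten into typed PADDLE_THROW errors, collecting the message in a std::stringstream whenever it was previously built with operator<<. A minimal standalone sketch of that idiom, not part of the patch — std::runtime_error stands in for Paddle's enforce machinery, and the conv2d rank check is only a hypothetical example of a converted call site:

#include <cstddef>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Stand-in for PADDLE_THROW(phi::errors::InvalidArgument(...)); the real macro
// also records the throw site and error type, which this sketch omits.
[[noreturn]] void ThrowInvalidArgument(const std::string& msg) {
  throw std::runtime_error("InvalidArgument: " + msg);
}

std::string Join(const std::vector<int>& v, const std::string& sep) {
  std::string out;
  for (std::size_t i = 0; i < v.size(); ++i) {
    if (i != 0) out += sep;
    out += std::to_string(v[i]);
  }
  return out;
}

// Hypothetical check in the style of the converted conv2d shape validation.
void CheckConv2dInputRank(const std::vector<int>& input_shape) {
  if (input_shape.size() != 4 && input_shape.size() != 5) {
    // Before: LOG(FATAL) << "..." << utils::Join(input_shape, ", ");
    // After:  build the same message in a stringstream, then throw it.
    std::stringstream ss;
    ss << "conv2d's input shape should be 4D/5D. Wrong input shape: "
       << Join(input_shape, ", ");
    ThrowInvalidArgument(ss.str());
  }
}

int main() {
  try {
    CheckConv2dInputRank({32, 16, 8});  // rank 3, so this throws
  } catch (const std::runtime_error&) {
    return 0;  // what(): "InvalidArgument: conv2d's input shape should be 4D/5D. ..."
  }
  return 1;
}

The real PADDLE_THROW additionally wraps the message in a phi::errors error object and records file/line, which is what the converted call sites rely on.
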
consumer->input_nodes.erase(producer->CollectNodes()[0]); consumer->mut_producer_groups()->erase(producer); producer->mut_consumer_groups()->erase(consumer); @@ -1081,7 +1081,7 @@ class FusionMergePassHelper : public FusionHelperBase { void FusionMergePassInternal(Graph* graph) { if (graph->fusion_groups.size() <= 1) { - VLOG(3) << "Don't do Fusoin Merge Pass...!"; + VLOG(3) << "Don't do Fusion Merge Pass...!"; return; } diff --git a/paddle/cinn/hlir/pass/fusion_merge_pass_util.h b/paddle/cinn/hlir/pass/fusion_merge_pass_util.h index 219d08d7d08e6..5541ec09bc178 100644 --- a/paddle/cinn/hlir/pass/fusion_merge_pass_util.h +++ b/paddle/cinn/hlir/pass/fusion_merge_pass_util.h @@ -330,7 +330,7 @@ inline bool horizontal_relation( }; auto selected_nodes = select_node_set(second_set, op_pattern_kind); - auto check_depency = [&](const Node* node) { + auto check_dependency = [&](const Node* node) { std::queue candidates; std::unordered_set visited_set; candidates.push(node); @@ -360,7 +360,7 @@ inline bool horizontal_relation( }; for (auto node : selected_nodes) { - if (check_depency(node)) { + if (check_dependency(node)) { return false; } } diff --git a/paddle/cinn/hlir/pass/general_fusion_merge_pass.cc b/paddle/cinn/hlir/pass/general_fusion_merge_pass.cc index 65d0d9eb7c243..b9d553019a459 100644 --- a/paddle/cinn/hlir/pass/general_fusion_merge_pass.cc +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass.cc @@ -212,7 +212,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { } } if (is_ring) { - LOG(FATAL) << "Exists Ring, Please Check!"; + PADDLE_THROW(phi::errors::Fatal("Exists Ring, Please Check!")); } } } @@ -244,7 +244,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool GeneralHorizontalFuse(const GroupPtr& producer) { VLOG(3) << "GeneralHorizontalFuse handling producer : " << producer->group_id; - const auto& GetFusableConsumerGroupLists = + const auto& GetFusibleConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; const auto& MarkFusible = [&](const OpGroupList& candidates) { @@ -255,8 +255,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { EnableFusedHorizontalGroups(&fuse_ctx); return tagged_lists; }; - const auto& GetFusableConsumerGroupList = [&]() -> std::vector { - const auto& group_lists = GetFusableConsumerGroupLists(); + const auto& GetFusibleConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = GetFusibleConsumerGroupLists(); if (group_lists.empty()) { return std::vector{}; } @@ -271,7 +271,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return ret; }; - const auto& group_lists = GetFusableConsumerGroupList(); + const auto& group_lists = GetFusibleConsumerGroupList(); if (group_lists.empty()) { return false; } @@ -303,7 +303,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool CallGeneralInputFusePass( const std::unordered_set& consumers) { VLOG(3) << "CallGeneralInputFusePass...!"; - const auto& GetFusableConsumerGroupLists = + const auto& GetFusibleConsumerGroupLists = [&]() -> std::vector { std::vector tagged_lists; const auto& MarkFusible = [&](const OpGroupList& candidates) { @@ -318,8 +318,8 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { EnableFusedInputGroups(&fuse_ctx); return tagged_lists; }; - const auto& GetFusableConsumerGroupList = [&]() -> std::vector { - const auto& group_lists = GetFusableConsumerGroupLists(); + const auto& GetFusibleConsumerGroupList = [&]() -> std::vector { + const auto& group_lists = 
GetFusibleConsumerGroupLists(); if (group_lists.empty()) { return std::vector{}; } @@ -334,7 +334,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return ret; }; - const auto& group_lists = GetFusableConsumerGroupList(); + const auto& group_lists = GetFusibleConsumerGroupList(); if (group_lists.empty()) { return false; } @@ -522,7 +522,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { bool GeneralVerticalFuse(const GroupPtr& producer) { VLOG(3) << "GeneralVerticalFuse...!"; using GroupSets = std::vector>; - const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { + const auto& GetFusibleConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; const auto& MarkFusible = [&](const OpGroupPtr& first, const OpGroupPtr& second) { @@ -534,9 +534,9 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return tagged_sets; }; - auto GetFusableConsumerGroupSet = + auto GetFusibleConsumerGroupSet = [&]() -> std::unordered_set { - const auto& group_sets = GetFusableConsumerOpGroupSets(); + const auto& group_sets = GetFusibleConsumerOpGroupSets(); if (group_sets.empty()) { return {}; } @@ -548,7 +548,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { }; bool update = false; - auto consumer_groups = GetFusableConsumerGroupSet(); + auto consumer_groups = GetFusibleConsumerGroupSet(); if (consumer_groups.size()) { SelectConsumerToFuse(producer, &consumer_groups); } @@ -771,7 +771,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { VLOG(3) << "GeneralRecomputeFuse handling producer : " << producer->group_id; using GroupSets = std::set>; - const auto& GetFusableConsumerOpGroupSets = [&]() -> GroupSets { + const auto& GetFusibleConsumerOpGroupSets = [&]() -> GroupSets { GroupSets tagged_sets; const auto& MarkFusible = [&](const OpGroupPtr& first, const OpGroupPtr& second) { @@ -783,9 +783,9 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { return tagged_sets; }; - auto GetFusableConsumerGroupSet = + auto GetFusibleConsumerGroupSet = [&]() -> std::unordered_set { - const auto& group_sets = GetFusableConsumerOpGroupSets(); + const auto& group_sets = GetFusibleConsumerOpGroupSets(); if (group_sets.empty()) { return {}; } @@ -797,7 +797,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { }; bool update = false; - auto consumer_groups = GetFusableConsumerGroupSet(); + auto consumer_groups = GetFusibleConsumerGroupSet(); if (consumer_groups.size() > 0) { CHECK(consumer_groups.size() == producer->mut_consumer_groups()->size()) << "Recompute requires fuse all consumers!"; @@ -833,7 +833,7 @@ class GeneralFusionMergePassHelper : public FusionHelperBase { sub_group->nodes.insert(sub_group->nodes.begin(), producer->CollectNodes()[0]); sub_group->nodes_set.insert(producer->CollectNodes()[0]); - // remove depency. + // remove dependency. 
consumer->input_nodes.erase(producer->CollectNodes()[0]); consumer->mut_producer_groups()->erase(producer); producer->mut_consumer_groups()->erase(consumer); diff --git a/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_horizontal_fuse_pass.cc b/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_horizontal_fuse_pass.cc index e953caf20ab7a..642ad8acf6aec 100644 --- a/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_horizontal_fuse_pass.cc +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_horizontal_fuse_pass.cc @@ -62,7 +62,7 @@ class DefaultHorizontalFusePass final : public HorizontalFusePass { bool fusionable = false; for (auto& groups : fusionable_consumers) { auto& last = groups.back(); - if (!HorizontalFuseUtil::DetectFusabilityByKind( + if (!HorizontalFuseUtil::DetectFusibilityByKind( ctx, candidate, last)) { continue; } diff --git a/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_input_fuse_pass.cc b/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_input_fuse_pass.cc index 7dc68d65599f9..1f251af14e212 100644 --- a/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_input_fuse_pass.cc +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_input_fuse_pass.cc @@ -63,7 +63,7 @@ class DefaultInputFusePass final : public InputFusePass { bool fusionable = false; for (auto& groups : fusionable_consumers) { auto& last = groups.back(); - if (!HorizontalFuseUtil::DetectFusabilityByKind( + if (!HorizontalFuseUtil::DetectFusibilityByKind( ctx, candidate, last)) { continue; } diff --git a/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_recompute_fuse_pass.cc b/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_recompute_fuse_pass.cc index 137a470d5993d..c1eab18569a8c 100644 --- a/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_recompute_fuse_pass.cc +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_recompute_fuse_pass.cc @@ -44,7 +44,7 @@ class DefaultRecomputeFusePass final : public RecomputeFusePass { std::vector candidates; for (int i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); - if (!VerticalFuseUtil::DetectFusabilityByKind(ctx, producer, consumer)) { + if (!VerticalFuseUtil::DetectFusibilityByKind(ctx, producer, consumer)) { continue; } unsafe_candidates.push_back(consumer); diff --git a/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_vertical_fuse_pass.cc b/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_vertical_fuse_pass.cc index fcffcb6be03f8..eb74a622db21d 100644 --- a/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_vertical_fuse_pass.cc +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass/default_vertical_fuse_pass.cc @@ -46,7 +46,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { std::vector candidates; for (int i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); - if (!VerticalFuseUtil::DetectFusabilityByKind(ctx, producer, consumer)) { + if (!VerticalFuseUtil::DetectFusibilityByKind(ctx, producer, consumer)) { break; } candidates.push_back(consumer); @@ -58,7 +58,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass { for (int i = 0; i < consumers.size(); ++i) { const auto& consumer = consumers.at(i); - if (!VerticalFuseUtil::DetectFusabilityByKind(ctx, producer, consumer)) { + if (!VerticalFuseUtil::DetectFusibilityByKind(ctx, producer, consumer)) { continue; } if (ctx->fuse_helper().DetectCycleIfFuse(producer, consumer)) { diff --git 
a/paddle/cinn/hlir/pass/general_fusion_merge_pass/horizontal_fuse_util.h b/paddle/cinn/hlir/pass/general_fusion_merge_pass/horizontal_fuse_util.h index 81b170637e54d..56612879b6770 100644 --- a/paddle/cinn/hlir/pass/general_fusion_merge_pass/horizontal_fuse_util.h +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass/horizontal_fuse_util.h @@ -29,7 +29,7 @@ template struct HorizontalFuseUtil { using KindKeyT = std::pair; - static bool DetectFusabilityByKind(FusePassCtxT* ctx, + static bool DetectFusibilityByKind(FusePassCtxT* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { const KindKeyT kind_pair(src.kind(), dst.kind()); diff --git a/paddle/cinn/hlir/pass/general_fusion_merge_pass/vertical_fuse_util.h b/paddle/cinn/hlir/pass/general_fusion_merge_pass/vertical_fuse_util.h index 4845af9ea94eb..9c754d59bac42 100644 --- a/paddle/cinn/hlir/pass/general_fusion_merge_pass/vertical_fuse_util.h +++ b/paddle/cinn/hlir/pass/general_fusion_merge_pass/vertical_fuse_util.h @@ -29,7 +29,7 @@ using framework::OpPatternKind; struct VerticalFuseUtil { using KindKeyT = std::pair; - static bool DetectFusabilityByKind(LightwareFusePassCtx* ctx, + static bool DetectFusibilityByKind(LightwareFusePassCtx* ctx, const OpGroupPtr& src, const OpGroupPtr& dst) { const KindKeyT kind_pair(src.kind(), dst.kind()); diff --git a/paddle/cinn/hlir/pass/op_fusion_pass_util.h b/paddle/cinn/hlir/pass/op_fusion_pass_util.h index c8af3db911689..12eece98e1327 100644 --- a/paddle/cinn/hlir/pass/op_fusion_pass_util.h +++ b/paddle/cinn/hlir/pass/op_fusion_pass_util.h @@ -124,7 +124,7 @@ CONDITION_FUNC(reduce_fuse_reduce) { } CONDITION_FUNC(is_horizontal_relation) { - auto check_depency = [&](const Node* node) { + auto check_dependency = [&](const Node* node) { std::queue candidates; std::unordered_set visited_set; candidates.push(node); @@ -157,7 +157,7 @@ CONDITION_FUNC(is_horizontal_relation) { if (helper->GetOpKind(node) != consumer->op_pattern_kind) { continue; } - if (check_depency(node)) { + if (check_dependency(node)) { return false; } } @@ -207,17 +207,17 @@ CONDITION_FUNC(horizontal_or_vertical_reduce_relation) { return false; } - int succesive_reduce_dimension = reduce_shape.at(reduce_axes.back()); + int successive_reduce_dimension = reduce_shape.at(reduce_axes.back()); for (int idx = reduce_axes.size() - 2; idx >= 0; --idx) { if (reduce_axes[idx] == reduce_axes[idx + 1] - 1) { - succesive_reduce_dimension *= reduce_shape[reduce_axes[idx]]; + successive_reduce_dimension *= reduce_shape[reduce_axes[idx]]; continue; } break; } return helper->target_ == cinn::common::DefaultNVGPUTarget() - ? (succesive_reduce_dimension <= helper->target_.max_num_threads() + ? (successive_reduce_dimension <= helper->target_.max_num_threads() ? 
true : false) : true; diff --git a/paddle/cinn/hlir/pass/opfusion.cc b/paddle/cinn/hlir/pass/opfusion.cc index 537b9abb45881..c8690c0625fbb 100644 --- a/paddle/cinn/hlir/pass/opfusion.cc +++ b/paddle/cinn/hlir/pass/opfusion.cc @@ -83,7 +83,7 @@ class DomTree { const std::vector& nodes) { int size = nodes.size(); dom_nodes_.resize(nodes.size()); - // construct postdom tree, reverse topological_order + // construct post dom tree, reverse topological_order for (int i = size - 1; i >= 0; i--) { auto* dom_node = CreateDomNode(nodes[i]); CHECK(dom_node); @@ -160,7 +160,7 @@ class DomTree { parent = dom_node; CHECK(parent); } else { - // if the out_var links to more than one opnode, then we need to find + // if the out_var links to more than one op_node, then we need to find // the LCA parent = LCA(parent, dom_node, pattern); } @@ -170,7 +170,7 @@ class DomTree { VLOG(2) << sink->id() << "'s op pattern is " << op_pattern; if (op_node->attrs.attr_store.count("pre_run") && absl::get(op_node->attrs.attr_store["pre_run"])) { - // not fuse pre_run opnode + // not fuse pre_run op_node op_pattern = framework::kNonFusible; VLOG(3) << op_node->op()->name << " do pre_run and not fuse"; } @@ -264,7 +264,7 @@ class GraphPartition { auto pattern = op_pattern_dict[op_node->op()]; if (op_node->attrs.attr_store.count("pre_run") && absl::get(op_node->attrs.attr_store["pre_run"])) { - // not fuse pre_run opnode + // not fuse pre_run op_node pattern = framework::kNonFusible; VLOG(3) << op_node->op()->name << " do pre_run and not fuse"; } @@ -412,7 +412,8 @@ class GraphPartition { parent->master_node = child->master_node; if (child->pattern > framework::kBroadcast && parent->pattern > framework::kBroadcast) { - LOG(FATAL) << "can't fuse 2 groups both with complex pattern"; + PADDLE_THROW(phi::errors::InvalidArgument( + "can't fuse 2 groups both with complex pattern")); } else { parent->pattern = child->pattern > parent->pattern ? 
child->pattern : parent->pattern; @@ -549,7 +550,7 @@ class GraphPartition { void OpFusionPass(Graph* graph) { auto store_nodes = std::get<0>(graph->topological_order()); int node_size = store_nodes.size(); - // construct postdom tree, reverse topological_order + // construct post dom tree, reverse topological_order DomTree tree; auto& dom_nodes = tree.CreatePostDomTree(store_nodes); // graph partition diff --git a/paddle/cinn/hlir/pass/reduce_split_pass.cc b/paddle/cinn/hlir/pass/reduce_split_pass.cc index 1f8c500cc9be0..899c233866ca5 100644 --- a/paddle/cinn/hlir/pass/reduce_split_pass.cc +++ b/paddle/cinn/hlir/pass/reduce_split_pass.cc @@ -71,7 +71,7 @@ uint32_t NextPowerOf2(uint32_t n) { class ReduceSplitPass { public: - // Find the reduce op with nwhc format and large shape, split it into two ops + // Find the reduce op with NWHC format and large shape, split it into two ops static int Apply(framework::Graph* graph) { int MAX_NUM_THREADS = cinn::common::DefaultNVGPUTarget().max_num_threads(); constexpr int MAX_ITER_PER_THREAD = 32; // empirical value diff --git a/paddle/cinn/hlir/pass/single_group_optimize_pass.cc b/paddle/cinn/hlir/pass/single_group_optimize_pass.cc index 816943b38cee0..db67b990cd76e 100644 --- a/paddle/cinn/hlir/pass/single_group_optimize_pass.cc +++ b/paddle/cinn/hlir/pass/single_group_optimize_pass.cc @@ -201,7 +201,7 @@ void SingleGroupOptimizePass::InitNodeToGroups() { CINN_REGISTER_HELPER(SingleGroupOptimizePass) { CINN_REGISTER_PASS(SingleGroupOptimizePass) - .describe("Optimize singel group to improve performance.") + .describe("Optimize single group to improve performance.") .set_change_structure(true) .set_body(cinn::hlir::pass::SingleGroupOptimizePassImpl); diff --git a/paddle/cinn/hlir/pe/CMakeLists.txt b/paddle/cinn/hlir/pe/CMakeLists.txt index 6ac7787749fd4..3ecab5a4d1c76 100755 --- a/paddle/cinn/hlir/pe/CMakeLists.txt +++ b/paddle/cinn/hlir/pe/CMakeLists.txt @@ -16,9 +16,7 @@ gather_srcs( transform.cc vision.cc) -if(NOT CINN_ONLY) - gather_srcs(cinnapi_src SRCS map_expr_to_ir.cc) -endif() +gather_srcs(cinnapi_src SRCS map_expr_to_ir.cc) cinn_cc_test(test_cinn_pe_elementwise SRCS pe_elementwise_test.cc DEPS cinncore) cinn_cc_test(test_cinn_pe_broadcast SRCS pe_broadcast_test.cc DEPS cinncore) diff --git a/paddle/cinn/hlir/pe/broadcast.cc b/paddle/cinn/hlir/pe/broadcast.cc index 439ff30e2691c..fb47ed737fdf3 100644 --- a/paddle/cinn/hlir/pe/broadcast.cc +++ b/paddle/cinn/hlir/pe/broadcast.cc @@ -23,6 +23,7 @@ #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/lang/builtin.h" #include "paddle/cinn/lang/compute.h" +#include "paddle/common/errors.h" PD_DECLARE_bool(cinn_bucket_compile); namespace cinn { @@ -145,9 +146,11 @@ void GetBroadcastShape(const std::vector& shape1, broadcast_flag1->emplace_back(true); broadcast_flag2->emplace_back(false); } else { - LOG(FATAL) << "Incompatible broadcast dims " << shape1_new[size1 - i] - << " and " << shape2_new[size2 - i] << " in: " << shape1_new - << " and " << shape2_new << std::endl; + std::stringstream ss; + ss << "Incompatible broadcast dims " << shape1_new[size1 - i] << " and " + << shape2_new[size2 - i] << " in: " << shape1_new << " and " + << shape2_new << std::endl; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } } @@ -357,14 +360,16 @@ Tensor BroadcastTo(const Tensor& A, [=](const std::vector& indice) { std::vector broadcast_indice; for (int idx = 0; idx < axes.size(); ++idx) { - int a_shape_i = A_shape[idx].as_int32(); + int a_shape_i = A_shape[idx].as_int64(); if (a_shape_i == 1) 
{ broadcast_indice.push_back(ir::Expr(0)); } else if (a_shape_i == out_shape[axes[idx]]) { broadcast_indice.push_back(indice[axes[idx]]); } else { - LOG(FATAL) << "fail to broad cast input shape " << a_shape_i - << " to output shape " << out_shape[axes[idx]]; + std::stringstream ss; + ss << "fail to broad cast input shape " << a_shape_i + << " to output shape " << out_shape[axes[idx]]; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } return A(broadcast_indice); @@ -374,36 +379,31 @@ Tensor BroadcastTo(const Tensor& A, Tensor BroadcastTo(const Tensor& A, const std::vector& out_shape, - const std::vector& broadcast_axes, const std::string& out_name) { auto A_shape = A->shape; - CHECK_EQ(A_shape.size(), broadcast_axes.size()) - << "broadcast_axes's size should be same with the input shape's size"; - CHECK_GE(out_shape.size(), broadcast_axes.size()) - << "broadcast_axes's size should be no more than out_shape's size"; - auto axes = broadcast_axes; - for (auto& axis : axes) { - // if axis < 0, plus out_shape.size - if (axis < 0) { - axis = out_shape.size() + axis; - } - CHECK_LT(axis, out_shape.size()); - } - std::sort(axes.begin(), axes.end()); + PADDLE_ENFORCE_GE( + out_shape.size(), + A_shape.size(), + ::common::errors::InvalidArgument( + "broadcast_to's out_shape's size should be GreaterEqual " + "with the input shape's size")); return Compute( ToCinnExprs(out_shape), [=](const std::vector& indice) { std::vector broadcast_indice; - for (int idx = 0; idx < axes.size(); ++idx) { - ir::Expr a_shape_i = A_shape[idx]; + int out_A_offset = out_shape.size() - A_shape.size(); + for (int idx = out_A_offset; idx < out_shape.size(); ++idx) { + ir::Expr a_shape_i = A_shape[idx - out_A_offset]; if (MathEqual(a_shape_i, ir::Expr(1))) { broadcast_indice.push_back(ir::Expr(0)); - } else if (MathEqual(a_shape_i, out_shape[axes[idx]])) { - broadcast_indice.push_back(indice[axes[idx]]); + } else if (MathEqual(a_shape_i, out_shape[idx])) { + broadcast_indice.push_back(indice[idx]); } else { - LOG(FATAL) << "fail to broad cast input shape " << a_shape_i - << " to output shape " << out_shape[axes[idx]]; + std::stringstream ss; + ss << "fail to broad cast input shape " << a_shape_i + << " to output shape " << out_shape[idx]; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } return A(broadcast_indice); diff --git a/paddle/cinn/hlir/pe/broadcast.h b/paddle/cinn/hlir/pe/broadcast.h index efdafee9c9dce..f2cb2649ad499 100644 --- a/paddle/cinn/hlir/pe/broadcast.h +++ b/paddle/cinn/hlir/pe/broadcast.h @@ -118,7 +118,6 @@ ir::Tensor BroadcastTo( ir::Tensor BroadcastTo( const ir::Tensor& A, const std::vector& out_shape, - const std::vector& broadcast_axes, const std::string& out_name = cinn::common::UniqName("T_broadcast_to_out")); // This operator checks if all x and y satisfy the condition: |x - y| <= atol + diff --git a/paddle/cinn/hlir/pe/elementwise.cc b/paddle/cinn/hlir/pe/elementwise.cc index 60933cd66c4b0..559014658de0e 100644 --- a/paddle/cinn/hlir/pe/elementwise.cc +++ b/paddle/cinn/hlir/pe/elementwise.cc @@ -197,29 +197,45 @@ ir::Tensor Reshape(const ir::Tensor& A, const std::vector& A_expr_shape = A->shape; int input_total_size = 1; int output_total_size = 1; - for (auto& i : A_expr_shape) { - CHECK(i.is_constant()) << "Input tensor's shape should be constant value."; - input_total_size *= static_cast(i.get_constant()); + std::vector A_stride_info; + int stride_base = 1; + A_stride_info.push_back(Expr(stride_base)); + + for (int i = A_expr_shape.size() - 1; i > 0; i--) { + stride_base *= 
static_cast(A_expr_shape[i].get_constant()); + A_stride_info.insert(A_stride_info.begin(), Expr(stride_base)); } + + std::vector new_stride_info; + stride_base = 1; + new_stride_info.push_back(Expr(stride_base)); + + for (int i = new_shape.size() - 1; i > 0; --i) { + stride_base *= new_shape[i]; + + new_stride_info.insert(new_stride_info.begin(), Expr(stride_base)); + } + for (auto& i : new_shape) { output_total_size *= i; new_expr_shape.push_back(Expr(i)); } - CHECK_EQ(input_total_size, output_total_size) - << "In op reshape, the input tensor and output tensor's total size " - "should be equal, please check!"; + auto res = Compute( new_expr_shape, [=](const std::vector& indice) { - Expr offset = Expr(0); - for (int i = 0; i < indice.size(); i++) { - offset = offset * new_expr_shape[i] + indice[i]; + Expr offset = indice[0] * new_stride_info[0]; + for (int i = 1; i < indice.size(); i++) { + offset = offset + indice[i] * new_stride_info[i]; } std::vector indice_a; for (int i = A_expr_shape.size() - 1; i >= 0; i--) { - auto temp = common::AutoSimplify(offset % A_expr_shape[i]); + auto inner_offset = offset; + if (i != (A_expr_shape.size() - 1)) { + inner_offset = inner_offset / A_stride_info[i]; + } + auto temp = inner_offset % A_expr_shape[i]; indice_a.insert(indice_a.begin(), temp); - offset = (offset - temp) / A_expr_shape[i]; } return A(indice_a); }, @@ -232,32 +248,45 @@ ir::Tensor Reshape(const ir::Tensor& A, const std::string& name) { std::vector new_expr_shape; const std::vector& A_expr_shape = A->shape; - ir::Expr input_total_size(1); - for (auto& i : A_expr_shape) { - // CHECK(i.is_constant()) << "Input tensor's shape should be constant - // value."; - input_total_size = ir::Mul::Make(input_total_size, i); + Expr input_total_size(1); + Expr output_total_size(1); + + std::vector A_stride_info; + Expr stride_base(1); + A_stride_info.push_back(stride_base); + for (int i = A_expr_shape.size() - 1; i > 0; i--) { + stride_base = stride_base * A_expr_shape[i]; + A_stride_info.insert(A_stride_info.begin(), Expr(stride_base)); + } + + std::vector new_stride_info; + stride_base = Expr(1); + new_stride_info.push_back(Expr(stride_base)); + for (int i = new_shape.size() - 1; i > 0; --i) { + stride_base = stride_base * new_shape[i]->dim_expr; + new_stride_info.insert(new_stride_info.begin(), Expr(stride_base)); } - ir::Expr output_total_size(1); + for (auto& i : new_shape) { - output_total_size = ir::Mul::Make(output_total_size, i->dim_expr); + output_total_size = output_total_size * i->dim_expr; new_expr_shape.push_back(i->dim_expr); } - // CHECK_EQ(input_total_size, output_total_size) - // << "In op reshape, the input tensor and output tensor's total size " - // "should be equal, please check!"; + auto res = Compute( new_expr_shape, [=](const std::vector& indice) { - Expr offset = Expr(0); - for (int i = 0; i < indice.size(); i++) { - offset = offset * new_expr_shape[i] + indice[i]; + Expr offset = indice[0] * new_stride_info[0]; + for (int i = 1; i < indice.size(); i++) { + offset = offset + indice[i] * new_stride_info[i]; } std::vector indice_a; for (int i = A_expr_shape.size() - 1; i >= 0; i--) { - auto temp = offset % A_expr_shape[i]; + auto inner_offset = offset; + if (i != (A_expr_shape.size() - 1)) { + inner_offset = inner_offset / A_stride_info[i]; + } + auto temp = inner_offset % A_expr_shape[i]; indice_a.insert(indice_a.begin(), temp); - offset = (offset - temp) / A_expr_shape[i]; } return A(indice_a); }, @@ -277,6 +306,14 @@ ir::Tensor Cast(const ir::Tensor& A, return res; } 
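
The rewritten Reshape above computes row-major strides for both the input and the output shape: output indices are flattened into a single linear offset using the new shape's strides, and that offset is decomposed back into input indices with the old shape's strides. A minimal standalone sketch of that index arithmetic on plain integers, not part of the patch (the ir::Expr machinery is assumed away):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major strides: stride[i] is the product of all dimensions to the right.
std::vector<int64_t> Strides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> stride(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    stride[i] = stride[i + 1] * shape[i + 1];
  }
  return stride;
}

// Map an output index under new_shape back to an input index under old_shape,
// mirroring the offset / inner_offset computation in the rewritten Reshape.
std::vector<int64_t> MapIndex(const std::vector<int64_t>& out_index,
                              const std::vector<int64_t>& old_shape,
                              const std::vector<int64_t>& new_shape) {
  const std::vector<int64_t> old_stride = Strides(old_shape);
  const std::vector<int64_t> new_stride = Strides(new_shape);
  int64_t offset = 0;
  for (std::size_t i = 0; i < out_index.size(); ++i) {
    offset += out_index[i] * new_stride[i];  // flatten with output strides
  }
  std::vector<int64_t> in_index(old_shape.size());
  for (std::size_t i = 0; i < old_shape.size(); ++i) {
    in_index[i] = (offset / old_stride[i]) % old_shape[i];  // decompose
  }
  return in_index;
}

int main() {
  // Reshape [2, 6] -> [3, 4]: output element (2, 1) has linear offset
  // 2 * 4 + 1 = 9, which is input element (1, 3) because 9 = 1 * 6 + 3.
  const std::vector<int64_t> idx = MapIndex({2, 1}, {2, 6}, {3, 4});
  assert(idx[0] == 1 && idx[1] == 3);
  return 0;
}

With constant shapes this yields the same mapping as the previous multiply/modulo loop; the stride form is what lets the symbolic-shape overload reuse the identical structure with Dim expressions.
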
+ir::Tensor Store(const ir::Tensor& A, const std::string& name) { + auto res = Compute( + A->shape, + [=](const std::vector& indices) { return A(indices); }, + name); + return res; +} + ir::Tensor Arange(const float start, const float stop, const float step, @@ -295,6 +332,28 @@ ir::Tensor Arange(const float start, return res; } +ir::Tensor Tril(const ir::Tensor& A, + const int diagonal, + const std::vector& out_shape, + const std::string& name) { + ir::Tensor res = Compute( + ToCinnExprs(out_shape), + [=](const std::vector& indice) { + PADDLE_ENFORCE_GE(indice.size(), + size_t(2), + phi::errors::InvalidArgument( + "The Tril op input tensor must have a rank " + "greater than or equal to 2.")); + std::vector new_indice(indice.end() - 2, indice.end()); + Expr col_indice = indice.back(); + return ir::Select::Make(new_indice[0] >= new_indice[1] - diagonal, + A(indice), + ir::Zero(A->type())); + }, + name); + return res; +} + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/pe/elementwise.h b/paddle/cinn/hlir/pe/elementwise.h index a9bbb71193255..fe8db5cf775d0 100644 --- a/paddle/cinn/hlir/pe/elementwise.h +++ b/paddle/cinn/hlir/pe/elementwise.h @@ -139,6 +139,9 @@ ir::Tensor Cast(const ir::Tensor& A, const Type& dtype, const std::string& name = UniqName("T_Elementwise_Cast_out")); +ir::Tensor Store(const ir::Tensor& A, + const std::string& name = UniqName("T_Elementwise_Store_out")); + ir::Tensor Arange( const float start, const float stop, @@ -146,6 +149,11 @@ ir::Tensor Arange( const Type& dtype, const std::string& name = UniqName("T_Elementwise_Arange_out")); +ir::Tensor Tril(const ir::Tensor& A, + const int diagonal, + const std::vector& out_shape, + const std::string& name = UniqName("T_Elementwise_Tril_out")); + } // namespace pe } // namespace hlir } // namespace cinn diff --git a/paddle/cinn/hlir/pe/ir_schedule_pe.cc b/paddle/cinn/hlir/pe/ir_schedule_pe.cc index 36052d25f8a44..d224a5fd1e1ca 100644 --- a/paddle/cinn/hlir/pe/ir_schedule_pe.cc +++ b/paddle/cinn/hlir/pe/ir_schedule_pe.cc @@ -200,7 +200,7 @@ std::vector IRCudaScheduleMatMul( ir_sch.MergeExprs(); // Generally, there are 2 ScheduleBlocks in the lowered function, // the first is for reduce_init and the second is the real compute block, - // here we use loops of the first block to Bind GPU index in top spatial axies + // here we use loops of the first block to Bind GPU index in top spatial axes auto init_block = ir_sch.GetAllBlocks().front(); VLOG(3) << "Matmul lowered expr:\n" << ir_sch.GetModule().GetExprs().front(); @@ -784,7 +784,8 @@ void IRCudaScheduleBlockShuffleReduce(ir::IRSchedule &ir_sch, // NOLINT } return loop_var_count; } - LOG(FATAL) << "Can't find var in tensor indexes!"; + PADDLE_THROW( + phi::errors::InvalidArgument("Can't find var in tensor indexes!")); }; auto loop_var_count = get_loop_index(ir_sch.GetLoops(reduce_out->name).back(), ir_sch.GetBlock(reduce_out->name)); diff --git a/paddle/cinn/hlir/pe/map_expr_to_ir.cc b/paddle/cinn/hlir/pe/map_expr_to_ir.cc index 2f1e854672fd4..e7a2de5150026 100644 --- a/paddle/cinn/hlir/pe/map_expr_to_ir.cc +++ b/paddle/cinn/hlir/pe/map_expr_to_ir.cc @@ -158,8 +158,9 @@ class MapExprToIrTranslator { DoEach(expr); break; default: - LOG(FATAL) << "Visit node_type = " << expr.node_type() - << ", not supported!"; + std::stringstream ss; + ss << "Visit node_type = " << expr.node_type() << ", not supported!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); break; } } @@ -220,7 +221,7 @@ class MapExprToIrTranslator { } else { return 
NoInlineTranslator::Call(internal_stmt); } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } std::optional TranslateOpExprImpl( @@ -233,7 +234,8 @@ class MapExprToIrTranslator { std::vector TranslateTensorIndexImpl( const OpCall& op_call, const IterExprs4TensorT& IterExprs4Tensor) const { - LOG(FATAL) << "Dead code, no TensorIndexExpr for OpCall"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Dead code, no TensorIndexExpr for OpCall")); } std::vector TranslateTensorIndexImpl( @@ -381,7 +383,7 @@ class MapExprToIrTranslator { return (this->*make_store_rvalue_expr)( store_rvalue, op_expr_children, IterExprs4Tensor); } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } std::optional TranslateOpCallImpl( @@ -685,13 +687,13 @@ class MapExprToIrTranslator { std::tuple GetForTypeAndInfoImpl(const Vectorize& loop_type, const LoopDescriptor& ld) const { - LOG(FATAL) << "Vectorize not supported yet"; + PADDLE_THROW(phi::errors::InvalidArgument("Vectorize not supported yet")); } std::tuple GetForTypeAndInfoImpl(const Unroll& loop_type, const LoopDescriptor& ld) const { - LOG(FATAL) << "Unroll not supported yet"; + PADDLE_THROW(phi::errors::InvalidArgument("Unroll not supported yet")); } std::tuple GetForTypeAndInfo( @@ -704,7 +706,7 @@ class MapExprToIrTranslator { ir::Expr Accumulate(const std::vector& ir_exprs) const { if (ir_exprs.size() == 0) { - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } else if (ir_exprs.size() == 1) { return ir_exprs.at(0); } else { @@ -714,12 +716,12 @@ class MapExprToIrTranslator { } return ret; } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } ir::Expr Multiply(const std::vector& ir_exprs) const { if (ir_exprs.size() == 0) { - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } else if (ir_exprs.size() == 1) { return ir_exprs.at(0); } else { @@ -729,7 +731,7 @@ class MapExprToIrTranslator { } return ret; } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } ir::Expr GetStride(const List& dims, int start) const { @@ -820,16 +822,16 @@ class MapExprToIrTranslator { } ir::Expr TranslateDimExprImpl(const ::symbol::Max& dim_expr) const { - LOG(FATAL) << "Not Supported yet"; + PADDLE_THROW(phi::errors::Unimplemented("Not supported yet")); } ir::Expr TranslateDimExprImpl(const ::symbol::Min& dim_expr) const { - LOG(FATAL) << "Not Supported yet"; + PADDLE_THROW(phi::errors::Unimplemented("Not supported yet")); } ir::Expr TranslateDimExprImpl( const ::symbol::Broadcast& dim_expr) const { - LOG(FATAL) << "Not Supported yet"; + PADDLE_THROW(phi::errors::Unimplemented("Not supported yet")); } ir::Expr TranslateDimExpr(const Value& value) const { @@ -859,7 +861,9 @@ class MapExprToIrTranslator { } else if (Match(value)) { return TranslateBI(value); } else { - LOG(FATAL) << "Not supported yet! " << ToTxtString(value); + std::stringstream ss; + ss << "Not supported yet! 
" << ToTxtString(value); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } diff --git a/paddle/cinn/hlir/pe/nn.cc b/paddle/cinn/hlir/pe/nn.cc index 9c10e1ad137c2..9e48b26ae9392 100644 --- a/paddle/cinn/hlir/pe/nn.cc +++ b/paddle/cinn/hlir/pe/nn.cc @@ -54,7 +54,9 @@ std::string Type2StrForNN(cinn::common::Type type) { } else if (type.is_float16()) { return "fp16"; } - LOG(FATAL) << "NN Not Support " << type; + std::stringstream ss; + ss << "NN Not Support " << type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return ""; } @@ -1397,7 +1399,9 @@ std::vector Pool1d(const Tensor &tensor, } else if (data_format == "NWC") { width_axis = 1; } else { - LOG(FATAL) << "Unsupported data format: " << data_format << std::endl; + std::stringstream ss; + ss << "Unsupported data format: " << data_format << std::endl; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } CHECK_EQ(tensor->shape.size(), 3U) << "pool1d requires tensor's shape_size to be 3\n"; @@ -1459,7 +1463,7 @@ std::vector GlobalPool2d(const Tensor &tensor, UniqName(output_name)); return {ret, temp}; } else { - LOG(FATAL) << "unsupported pooling type."; + PADDLE_THROW(phi::errors::InvalidArgument("unsupported pooling type.")); } return {}; } @@ -1486,7 +1490,9 @@ std::vector Pool2d(const Tensor &tensor, height_axis = 2; width_axis = 3; } else { - LOG(FATAL) << "Unsupported data format: " << data_format << std::endl; + std::stringstream ss; + ss << "Unsupported data format: " << data_format << std::endl; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } CHECK(tensor->shape.size() == 4U || tensor->shape.size() == 5U) << "pool2d requires tensor's shape_size to be 4 or 5\n"; @@ -1524,7 +1530,9 @@ std::vector Pool3d(const Tensor &tensor, height_axis = 2; width_axis = 3; } else { - LOG(FATAL) << "Unsupported data format: " << data_format << std::endl; + std::stringstream ss; + ss << "Unsupported data format: " << data_format << std::endl; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } CHECK_EQ(tensor->shape.size(), 5U) << "pool1d requires tensor's shape_size to be 5\n"; @@ -1558,8 +1566,9 @@ Tensor DropoutInfer(const ir::Tensor &tensor, // fusion schedule. return Identity(tensor, output_name).front(); } else { - LOG(FATAL) << "dropout_implementation attr must be 'downgrade_in_infer' or " - "'upscale_in_train'\n"; + PADDLE_THROW(phi::errors::InvalidArgument( + "dropout_implementation attr must be 'downgrade_in_infer' or " + "'upscale_in_train'\n")); } } diff --git a/paddle/cinn/hlir/pe/reduction.cc b/paddle/cinn/hlir/pe/reduction.cc index 7e33a1475e48b..b831d1b588472 100644 --- a/paddle/cinn/hlir/pe/reduction.cc +++ b/paddle/cinn/hlir/pe/reduction.cc @@ -90,7 +90,9 @@ std::string Type2StrForReduce(cinn::common::Type type) { } else if (type.is_bool()) { return ""; } - LOG(FATAL) << "Reduce Not Support " << type; + std::stringstream ss; + ss << "Reduce Not Support " << type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return ""; } @@ -129,6 +131,13 @@ void GetOutputShape(const std::vector& real_axes, if (output_shape->empty()) { output_shape->push_back(cinn::common::make_one()); } + + CHECK(!tensor->shape.empty()); + if (tensor->shape[0]->type() == Int(64)) { + for (auto& shape_item : *output_shape) { + shape_item->convert_int32_to_int64(); + } + } } /*! @@ -166,6 +175,14 @@ Tensor DoReduce(const Tensor& tensor, int indice_cnt = 0; int reduce_cnt = 0; + // Set keepdim flags of indices. 
+ if (tensor->shape.size() == indices.size()) { + for (const auto& i : real_axes) { + VLOG(4) << "Set is_keepdim = true for var(" << i << ")"; + indices[i].as_var_ref()->is_keepdim = true; + } + } + for (size_t i = 0; i < tensor->shape.size(); ++i) { bool squeeze_i = std::find(squeeze_axes.begin(), squeeze_axes.end(), i) != squeeze_axes.end(); @@ -1081,9 +1098,15 @@ std::string CrossThreadReduceExternalFuncName(const ir::Expr& op, const ir::Expr& tensor) { CHECK_NOTNULL(tensor.as_tensor()); if (op.As()) { + if (tensor.as_tensor()->type().is_bool()) { + return "cinn_block_reduce_any_internal_shm"; + } return "cinn_block_reduce_sum" + Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm"; } else if (op.As()) { + if (tensor.as_tensor()->type().is_bool()) { + return "cinn_block_reduce_all_internal_shm"; + } return "cinn_block_reduce_prod" + Type2StrForReduce(tensor.as_tensor()->type()) + "_internal_shm"; } else if (op.As()) { @@ -1097,7 +1120,9 @@ std::string CrossThreadReduceExternalFuncName(const ir::Expr& op, } else if (op.As()) { return "cinn_block_reduce_any_internal_shm"; } else { - LOG(FATAL) << "Reduce type: " << op << " Not supported yet!"; + std::stringstream ss; + ss << "Reduce type: " << op << " Not supported yet!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return ""; } diff --git a/paddle/cinn/hlir/pe/schedule.cc b/paddle/cinn/hlir/pe/schedule.cc index 3c3067ce436ab..3e4af70e1b1cc 100644 --- a/paddle/cinn/hlir/pe/schedule.cc +++ b/paddle/cinn/hlir/pe/schedule.cc @@ -47,8 +47,8 @@ ScheduleParam::ScheduleParam(cinn::common::Target::Arch arch) { break; } default: { - LOG(FATAL) - << "Schedule params must be initialized with target x86 or nvgpu."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Schedule params must be initialized with target x86 or nvgpu.")); } } } @@ -290,7 +290,7 @@ void MatmulScheduleCPU(poly::StageMap stages, for (int i = 0; i < all_axes_inner.size(); ++i) { all_axes.push_back(all_axes_inner[i]); } - // int axies + // int axes CHECK_EQ(all_axes.size(), out_axis_dims); if (is_k_splited) { if (is_m_splited || is_n_splited) { @@ -2454,8 +2454,9 @@ void CudaScheduleConv2(poly::StageMap stages, } else if (stages[PR]->n_out_dims() == 19) { stages[PR]->Fuse({13, 14, 15, 16, 17, 18}); } else { - LOG(FATAL) << "PR number of output dims is wrong: " - << stages[PR]->n_out_dims(); + std::stringstream ss; + ss << "PR number of output dims is wrong: " << stages[PR]->n_out_dims(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } if (stages[KR]->n_out_dims() == 18) { @@ -2463,8 +2464,9 @@ void CudaScheduleConv2(poly::StageMap stages, } else if (stages[KR]->n_out_dims() == 19) { stages[KR]->Fuse({13, 14, 15, 16, 17, 18}); } else { - LOG(FATAL) << "KR number of output dims is wrong: " - << stages[KR]->n_out_dims(); + std::stringstream ss; + ss << "KR number of output dims is wrong: " << stages[KR]->n_out_dims(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } int thread_z = f_param[2]; int thread_x = x_param[2]; @@ -2768,8 +2770,11 @@ void CudaScheduleInjective(poly::Stage *stage, if (new_num_thread % 32 != 0) { new_num_thread = MaxFactorLessThan(prod_size, num_thread); } - if (new_num_thread == 1) - LOG(FATAL) << "prod_size out of range: " << prod_size; + if (new_num_thread == 1) { + std::stringstream ss; + ss << "prod_size out of range: " << prod_size; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); + } CHECK_GT(prod_size, new_num_thread); stage->Split(0, new_num_thread); diff --git a/paddle/cinn/hlir/pe/transform.cc 
b/paddle/cinn/hlir/pe/transform.cc index 2e78caca83206..3cd4120f89a1b 100644 --- a/paddle/cinn/hlir/pe/transform.cc +++ b/paddle/cinn/hlir/pe/transform.cc @@ -1070,18 +1070,25 @@ ir::Tensor SliceSymbolic(const ir::Tensor& A, input_shape.emplace_back(shape); } - std::vector new_starts(starts); + std::vector new_starts; + std::transform(starts.begin(), + starts.end(), + std::back_inserter(new_starts), + [](const int start) { return ir::Expr(start); }); + for (int i = 0; i < axes.size(); i++) { - CHECK(input_shape[axes[i]].is_constant()) - << "Not supported Slice in dynamic dimensions, because the " - "relationship between slice range and symbol size cannot be " - "determined at compile time"; - if (new_starts[i] < -input_shape[axes[i]].as_int64()) { - new_starts[i] = 0; - } else if (new_starts[i] < 0) { - new_starts[i] = input_shape[axes[i]].as_int64() + new_starts[i]; - } else if (new_starts[i] > input_shape[axes[i]].as_int64()) { - new_starts[i] = input_shape[axes[i]].as_int64() - 1; + if (input_shape[axes[i]].is_constant()) { + if (new_starts[i].as_int64() < -input_shape[axes[i]].as_int64()) { + new_starts[i] = ir::Expr(0); + } else if (new_starts[i].as_int64() < 0) { + new_starts[i] = input_shape[axes[i]].as_int64() + new_starts[i]; + } else if (new_starts[i].as_int64() > input_shape[axes[i]].as_int64()) { + new_starts[i] = input_shape[axes[i]].as_int64() - ir::Expr(1); + } + } else { + if (new_starts[i].as_int64() < 0) { + new_starts[i] = ir::Add::Make(input_shape[axes[i]], new_starts[i]); + } } } @@ -1269,7 +1276,8 @@ ir::Tensor ScatterAssign(const ir::Tensor& input, } else if (target.arch == cinn::common::Target::Arch::X86) { extern_fun_name.assign("cinn_host_find_int"); } else { - LOG(FATAL) << "ScatterAssign only support X86 and NVGPU ! Please Check.\n"; + PADDLE_THROW(phi::errors::Fatal( + "ScatterAssign only support X86 and NVGPU ! 
Please Check.\n")); } auto pos_axis = axis; diff --git a/paddle/cinn/ir/group_schedule/CMakeLists.txt b/paddle/cinn/ir/group_schedule/CMakeLists.txt index d53ce85347b61..c23653da8d6e9 100644 --- a/paddle/cinn/ir/group_schedule/CMakeLists.txt +++ b/paddle/cinn/ir/group_schedule/CMakeLists.txt @@ -4,4 +4,5 @@ gather_srcs(cinnapi_src SRCS base_group_scheduler.cc) gather_srcs(cinnapi_src SRCS st_shape_group_scheduler.cc) gather_srcs(cinnapi_src SRCS dy_shape_group_scheduler.cc) +add_subdirectory(config) add_subdirectory(tactic) diff --git a/paddle/cinn/ir/group_schedule/base_group_scheduler.cc b/paddle/cinn/ir/group_schedule/base_group_scheduler.cc index a740ad268cb09..8a96fe840f85a 100644 --- a/paddle/cinn/ir/group_schedule/base_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/base_group_scheduler.cc @@ -23,13 +23,14 @@ std::unique_ptr GroupScheduler::Make( ir::IRSchedule* ir_sch, const std::unordered_set& output_tensor_names, const cinn::common::Target& target, - bool is_dy_shape) { + bool is_dy_shape, + const std::shared_ptr& group_info) { if (is_dy_shape) { return std::make_unique( - ir_sch, output_tensor_names, target); + ir_sch, output_tensor_names, target, group_info); } else { return std::make_unique( - ir_sch, output_tensor_names, target); + ir_sch, output_tensor_names, target, group_info); } } diff --git a/paddle/cinn/ir/group_schedule/base_group_scheduler.h b/paddle/cinn/ir/group_schedule/base_group_scheduler.h index 33cce051f1845..ef77397066351 100644 --- a/paddle/cinn/ir/group_schedule/base_group_scheduler.h +++ b/paddle/cinn/ir/group_schedule/base_group_scheduler.h @@ -14,9 +14,21 @@ #pragma once #include "paddle/cinn/common/target.h" +#include "paddle/cinn/ir/group_schedule/config/group_tile_config.h" +#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule_block_graph.h" +namespace cinn { +namespace hlir { +namespace framework { +namespace pir { +struct GroupInfo; +} +} // namespace framework +} // namespace hlir +} // namespace cinn + namespace cinn { namespace ir { @@ -27,12 +39,15 @@ using SymbolicPredicate = Expr; */ class GroupScheduler { public: - GroupScheduler(ir::IRSchedule* ir_sch, - const std::unordered_set& output_tensor_names, - const cinn::common::Target& target) + GroupScheduler( + ir::IRSchedule* ir_sch, + const std::unordered_set& output_tensor_names, + const cinn::common::Target& target, + const std::shared_ptr& group_info) : ir_sch_(ir_sch), output_tensor_names_(output_tensor_names), - target_(target) { + target_(target), + group_info_(group_info) { schedule_block_graph_ = std::make_unique(*ir_sch_); } @@ -40,7 +55,9 @@ class GroupScheduler { ir::IRSchedule* ir_sch, const std::unordered_set& output_tensor_names, const cinn::common::Target& target, - bool is_dy_shape = false); + bool is_dy_shape = false, + const std::shared_ptr& group_info = + nullptr); virtual ~GroupScheduler() = default; @@ -57,6 +74,8 @@ class GroupScheduler { // Graph in units of ScheduleBlockNode, each node corresponds to a // ScheduleBlock in IR. 
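
GroupScheduler::Make and both concrete schedulers now accept a shared GroupInfo, so whichever scheduler the factory builds can consult group-level facts (reduce axes, output variable names, and so on) when choosing tile configs. A minimal standalone sketch of that factory-with-shared-context shape, not part of the patch; the class and member names below are illustrative placeholders rather than Paddle's:

#include <cstdint>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Placeholder for the real hlir::framework::pir::GroupInfo.
struct GroupInfo {
  std::vector<int64_t> reduce_axis;
  std::set<std::string> reduce_var_names;
};

class SchedulerBase {
 public:
  explicit SchedulerBase(std::shared_ptr<GroupInfo> group_info)
      : group_info_(std::move(group_info)) {}
  virtual ~SchedulerBase() = default;
  virtual void Schedule() = 0;

  // Factory: pick the dynamic- or static-shape scheduler and thread the same
  // shared GroupInfo into whichever one is built.
  static std::unique_ptr<SchedulerBase> Make(
      bool is_dy_shape,
      const std::shared_ptr<GroupInfo>& group_info = nullptr);

 protected:
  std::shared_ptr<GroupInfo> group_info_;
};

class DynamicShapeScheduler final : public SchedulerBase {
 public:
  using SchedulerBase::SchedulerBase;
  void Schedule() override { /* bucketed, symbolic-shape path */ }
};

class StaticShapeScheduler final : public SchedulerBase {
 public:
  using SchedulerBase::SchedulerBase;
  void Schedule() override { /* fixed-shape path */ }
};

std::unique_ptr<SchedulerBase> SchedulerBase::Make(
    bool is_dy_shape, const std::shared_ptr<GroupInfo>& group_info) {
  if (is_dy_shape) {
    return std::make_unique<DynamicShapeScheduler>(group_info);
  }
  return std::make_unique<StaticShapeScheduler>(group_info);
}

int main() {
  auto info = std::make_shared<GroupInfo>();
  info->reduce_axis = {-1};
  auto scheduler = SchedulerBase::Make(/*is_dy_shape=*/true, info);
  scheduler->Schedule();
  return 0;
}
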
std::unique_ptr schedule_block_graph_; + + std::shared_ptr group_info_; }; } // namespace ir diff --git a/paddle/cinn/ir/group_schedule/config/CMakeLists.txt b/paddle/cinn/ir/group_schedule/config/CMakeLists.txt new file mode 100644 index 0000000000000..394e17eae21a7 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/config/CMakeLists.txt @@ -0,0 +1,3 @@ +core_gather_headers() + +gather_srcs(cinnapi_src SRCS group_tile_config.cc) diff --git a/paddle/cinn/ir/group_schedule/config/group_tile_config.cc b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc new file mode 100644 index 0000000000000..0d443086bdce9 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/config/group_tile_config.cc @@ -0,0 +1,339 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/config/group_tile_config.h" +#include "paddle/cinn/hlir/framework/pir/op_lowering_impl.h" + +namespace cinn { +namespace ir { + +const int kMaxNumel = INT32_MAX; + +int64_t Next2Power(int64_t n) { + if (n == 1) { + return 1; + } + return int64_t(std::pow(2.0, std::ceil(std::log2(n)))); +} + +std::shared_ptr InitBasicInfo( + const std::shared_ptr& group_info) { + std::shared_ptr base_info = + std::make_shared(); + base_info->reduce_tensor_names = group_info->reduce_var_names; + base_info->shared_var_names = group_info->shared_var_names; + base_info->direct_output_var_names = group_info->direct_output_var_names; + base_info->broadcast_info = group_info->broadcast_info; + base_info->broadcast_to_elementwise = group_info->broadcast_to_elementwise; + base_info->data_rank = group_info->data_space.size(); + + std::set reduce_dim_loc; + for (auto dim : group_info->reduce_axis) { + if (dim < 0) { + dim += base_info->data_rank; + } + base_info->reduce_axis.push_back(dim); + reduce_dim_loc.insert(dim); + } + + base_info->spatial_numel = 1; + base_info->reduce_numel = 1; + for (int64_t i = 0; i < base_info->data_rank; ++i) { + if (reduce_dim_loc.count(i)) { + if (group_info->data_space[i] == -1) base_info->has_dynamic_reduce = true; + base_info->reduce_numel *= group_info->data_space[i]; + } else { + if (group_info->data_space[i] == -1) + base_info->has_dynamic_spatial = true; + base_info->spatial_numel *= group_info->data_space[i]; + } + } + base_info->is_reduce_all = + (base_info->reduce_axis.size() == base_info->data_rank); + + return base_info; +} + +std::unordered_map +BuildPureStaticShapeConfig( + const std::shared_ptr& base_info, + const common::Target& target) { + if (base_info->spatial_numel == 1) { // reduce all + BucketInfo bucket_info{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ 1, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ kMaxNumel}; + ScheduleConfig::TileConfig tile_config{ + /* warp_num = */ 8, + /* tree_reduce_num = */ 256, + /* spatial_inner_num = */ 1, + /* reduce_method = */ BlockReduceMethod()}; + return {{bucket_info, tile_config}}; + } else if (base_info->reduce_numel == 1) { // no reduce + int64_t 
spatial_block = Next2Power(base_info->spatial_numel); + if (spatial_block > 1024) { + spatial_block = 1024; + } + int64_t warp_num = spatial_block / 128; + if (warp_num == 0) { + warp_num = 1; + } + BucketInfo bucket_info{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ kMaxNumel, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 1}; + ScheduleConfig::TileConfig tile_config{ + /* warp_num = */ warp_num, + /* tree_reduce_num = */ 1, + /* spatial_inner_num = */ 4, + /* reduce_method = */ NoneReduceMethod()}; + return {{bucket_info, tile_config}}; + } else if (base_info->reduce_numel <= 256) { + // warp reduce + int64_t reduce_block = Next2Power(base_info->reduce_numel); + int64_t spatial_inner_num = 256 / reduce_block; + int64_t tree_reduce_num = 32; + int64_t warp_num = 8; + BucketInfo bucket_info{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ kMaxNumel, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 256}; + ScheduleConfig::TileConfig tile_config{ + /* warp_num = */ warp_num, + /* tree_reduce_num = */ tree_reduce_num, + /* spatial_inner_num = */ spatial_inner_num, + /* reduce_method = */ WarpReduceMethod()}; + return {{bucket_info, tile_config}}; + } else if (base_info->reduce_numel <= 2048) { + int64_t spatial_block = 1; + int64_t reduce_block = + int64_t(std::ceil(base_info->reduce_numel * 1.0 / 256.0)) * 256; + int64_t warp_num = reduce_block / 256; + int64_t spatial_inner_num = 1; + int64_t reduce_inner_num = 8; + int64_t tree_reduce_num = reduce_block / reduce_inner_num; + BucketInfo bucket_info{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ kMaxNumel, + /* rb_lower_bound = */ 257, + /* rb_upper_bound = */ 2048}; + ScheduleConfig::TileConfig tile_config{ + /* warp_num = */ warp_num, + /* tree_reduce_num = */ tree_reduce_num, + /* spatial_inner_num = */ spatial_inner_num, + /* reduce_method = */ BlockReduceMethod()}; + return {{bucket_info, tile_config}}; + } else { + int64_t spatial_block = 1; + int64_t reduce_block = 2048; + int64_t warp_num = 8; + int64_t reduce_inner_num = + int64_t(std::ceil(base_info->reduce_numel * 1.0 / 256.0)); + int64_t spatial_inner_num = 1; + int64_t tree_reduce_num = reduce_block / reduce_inner_num; + BucketInfo bucket_info{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ kMaxNumel, + /* rb_lower_bound = */ 2049, + /* rb_upper_bound = */ kMaxNumel}; + ScheduleConfig::TileConfig tile_config{ + /* warp_num = */ warp_num, + /* tree_reduce_num = */ tree_reduce_num, + /* spatial_inner_num = */ spatial_inner_num, + /* reduce_method = */ NoneReduceMethod()}; + return {{bucket_info, tile_config}}; + } +} + +std::unordered_map +BuildStaticSpatialConfig( + const std::shared_ptr& base_info, + const common::Target& target) { + if (base_info->spatial_numel == 1) { // reduce all + BucketInfo bucket_info{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ 1, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ kMaxNumel}; + ScheduleConfig::TileConfig tile_config{ + /* warp_num = */ 8, + /* tree_reduce_num = */ 256, + /* spatial_inner_num = */ 1, + /* reduce_method = */ BlockReduceMethod()}; + return {{bucket_info, tile_config}}; + } else { + BucketInfo bucket_info_1_256{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ kMaxNumel, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 256}; + ScheduleConfig::TileConfig tile_config_1_256{ + /* warp_num = */ 8, + /* tree_reduce_num = */ 32, + /* spatial_inner_num = */ 1, + /* reduce_method = */ WarpReduceMethod()}; + + BucketInfo bucket_info_257_2048{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ 
kMaxNumel, + /* rb_lower_bound = */ 257, + /* rb_upper_bound = */ 2048}; + ScheduleConfig::TileConfig tile_config_257_2048{ + /* warp_num = */ 8, + /* tree_reduce_num = */ 128, + /* spatial_inner_num = */ 1, + /* reduce_method = */ BlockReduceMethod()}; + + BucketInfo bucket_info_2049_INF{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ kMaxNumel, + /* rb_lower_bound = */ 2049, + /* rb_upper_bound = */ kMaxNumel}; + ScheduleConfig::TileConfig tile_config_2049_INF{ + /* warp_num = */ 8, + /* tree_reduce_num = */ 256, + /* spatial_inner_num = */ 1, + /* reduce_method = */ BlockReduceMethod()}; + + return {{bucket_info_1_256, tile_config_1_256}, + {bucket_info_257_2048, tile_config_257_2048}, + {bucket_info_2049_INF, tile_config_2049_INF}}; + } +} + +std::unordered_map +BuildStaticReduceConfig( + const std::shared_ptr& base_info, + const common::Target& target) { + if (base_info->reduce_numel == 1) { + BucketInfo bucket_info__1_1023{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ 1023, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 1}; + ScheduleConfig::TileConfig tile_config__1_1023{ + /* warp_num = */ -1, + /* tree_reduce_num = */ 1, + /* spatial_inner_num = */ 1, + /* reduce_method = */ NoneReduceMethod()}; + BucketInfo bucket_info__1024_1M{/* sp_lower_bound = */ 1024, + /* sp_upper_bound = */ 1024 * 1024 - 1, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 1}; + ScheduleConfig::TileConfig tile_config__1024_1M{ + /* warp_num = */ 32, + /* tree_reduce_num = */ 1, + /* spatial_inner_num = */ 4, + /* reduce_method = */ NoneReduceMethod()}; + BucketInfo bucket_info__1M_INF{/* sp_lower_bound = */ 1024 * 1024, + /* sp_upper_bound = */ kMaxNumel, + /* rb_lower_bound = */ 1, + /* rb_upper_bound = */ 1}; + ScheduleConfig::TileConfig tile_config__1M_INF{ + /* warp_num = */ 32, + /* tree_reduce_num = */ 1, + /* spatial_inner_num = */ 4, + /* reduce_method = */ NoneReduceMethod()}; + return {{bucket_info__1_1023, tile_config__1_1023}, + {bucket_info__1024_1M, tile_config__1024_1M}, + {bucket_info__1M_INF, tile_config__1M_INF}}; + } else if (base_info->reduce_numel <= 256) { + BucketInfo bucket_info{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ kMaxNumel, + /* rb_lower_bound = */ 2, + /* rb_upper_bound = */ 256}; + ScheduleConfig::TileConfig tile_config{ + /* warp_num = */ 8, + /* tree_reduce_num = */ 32, + /* spatial_inner_num = */ (256 / Next2Power(base_info->reduce_numel)), + /* reduce_method = */ WarpReduceMethod()}; + return {{bucket_info, tile_config}}; + } else if (base_info->reduce_numel <= 2048) { + int64_t reduce_block = + int64_t(std::ceil(base_info->reduce_numel * 1.0 / 256.0)) * 256; + int64_t warp_num = reduce_block / 256; + int64_t spatial_inner_num = 1; + int64_t reduce_inner_num = 8; + int64_t tree_reduce_num = reduce_block / reduce_inner_num; + BucketInfo bucket_info{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ kMaxNumel, + /* rb_lower_bound = */ 257, + /* rb_upper_bound = */ 2048}; + ScheduleConfig::TileConfig tile_config{ + /* warp_num = */ warp_num, + /* tree_reduce_num = */ tree_reduce_num, + /* spatial_inner_num = */ spatial_inner_num, + /* reduce_method = */ BlockReduceMethod()}; + return {{bucket_info, tile_config}}; + } else { + int64_t reduce_block = 2048; + int64_t warp_num = 8; + int64_t reduce_inner_num = + int64_t(std::ceil(base_info->reduce_numel * 1.0 / 256.0)); + int64_t spatial_inner_num = 1; + int64_t tree_reduce_num = reduce_block / reduce_inner_num; + BucketInfo bucket_info{/* sp_lower_bound = */ 1, + /* sp_upper_bound = */ kMaxNumel, + 
/* rb_lower_bound = */ 2049, + /* rb_upper_bound = */ kMaxNumel}; + ScheduleConfig::TileConfig tile_config{ + /* warp_num = */ warp_num, + /* tree_reduce_num = */ tree_reduce_num, + /* spatial_inner_num = */ spatial_inner_num, + /* reduce_method = */ BlockReduceMethod()}; + return {{bucket_info, tile_config}}; + } +} + +std::unordered_map +BuildDynamicShapeConfig( + const std::shared_ptr& base_info, + const common::Target& target) { + CINN_NOT_IMPLEMENTED; +} + +std::unordered_map +CombineBaseInfoAndConfig( + const std::unordered_map& config_map, + const std::shared_ptr& base_info) { + std::unordered_map combined; + for (const auto& bucket_config : config_map) { + ScheduleConfig sch_config{base_info, std::move(bucket_config.second)}; + combined.insert({std::move(bucket_config.first), std::move(sch_config)}); + } + return combined; +} + +std::unordered_map +BuildScheduleConfig( + const std::shared_ptr& group_info, + const common::Target& target) { + std::shared_ptr base_info = + InitBasicInfo(group_info); + if (!base_info->has_dynamic_reduce && !base_info->has_dynamic_spatial) { + VLOG(6) << "Building static sptial and static reduce config."; + return CombineBaseInfoAndConfig( + BuildPureStaticShapeConfig(base_info, target), base_info); + } else if (base_info->has_dynamic_reduce && !base_info->has_dynamic_spatial) { + VLOG(6) << "Building static sptial and dynamic reduce config."; + return CombineBaseInfoAndConfig(BuildStaticSpatialConfig(base_info, target), + base_info); + } else if (!base_info->has_dynamic_reduce && base_info->has_dynamic_spatial) { + VLOG(6) << "Building dynamic sptial and static reduce config."; + return CombineBaseInfoAndConfig(BuildStaticReduceConfig(base_info, target), + base_info); + } else { // (base_info->has_dynamic_reduce && base_info->has_dynamic_spatial) + VLOG(6) << "Building dynamic sptial and dynamic reduce config."; + return CombineBaseInfoAndConfig(BuildDynamicShapeConfig(base_info, target), + base_info); + } +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/config/group_tile_config.h b/paddle/cinn/ir/group_schedule/config/group_tile_config.h new file mode 100644 index 0000000000000..176084b458a06 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/config/group_tile_config.h @@ -0,0 +1,90 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
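The BuildPureStaticShapeConfig / BuildStaticReduceConfig helpers added above pick a tile configuration per bucket by rounding the reduce extent up to a power of two (warp-reduce buckets) or to a multiple of 256 (block-reduce buckets) and deriving warp_num, tree_reduce_num and spatial_inner_num from that. The standalone C++ sketch below reproduces only that arithmetic; TileChoice and SelectTile are illustrative names, not part of the patch, and the reduce-all and dynamic-shape cases are omitted.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative only: mirrors the arithmetic of the static-shape tile buckets.
struct TileChoice {
  int64_t warp_num;
  int64_t tree_reduce_num;
  int64_t spatial_inner_num;
};

int64_t Next2Power(int64_t n) {
  if (n == 1) return 1;
  return static_cast<int64_t>(std::pow(2.0, std::ceil(std::log2(n))));
}

TileChoice SelectTile(int64_t spatial_numel, int64_t reduce_numel) {
  if (reduce_numel == 1) {  // elementwise: spend the threads on the spatial dim
    int64_t spatial_block = std::min<int64_t>(Next2Power(spatial_numel), 1024);
    int64_t warp_num = std::max<int64_t>(spatial_block / 128, 1);
    return {warp_num, 1, 4};
  }
  if (reduce_numel <= 256) {  // warp reduce: 32 lanes per reduction
    int64_t reduce_block = Next2Power(reduce_numel);
    return {8, 32, 256 / reduce_block};
  }
  if (reduce_numel <= 2048) {  // block reduce: round up to a multiple of 256
    // (n + 255) / 256 is the integer form of ceil(n / 256.0) used in the patch.
    int64_t reduce_block = (reduce_numel + 255) / 256 * 256;
    return {reduce_block / 256, reduce_block / 8, 1};
  }
  // Very large reductions: cap the block at 2048 and serialize the remainder.
  int64_t reduce_inner_num = (reduce_numel + 255) / 256;
  return {8, 2048 / reduce_inner_num, 1};
}

int main() {
  TileChoice c = SelectTile(/*spatial_numel=*/4096, /*reduce_numel=*/768);
  std::printf("warps=%lld tree_reduce=%lld spatial_inner=%lld\n",
              static_cast<long long>(c.warp_num),
              static_cast<long long>(c.tree_reduce_num),
              static_cast<long long>(c.spatial_inner_num));
  return 0;
}

For example, SelectTile(4096, 768) falls into the 257-2048 bucket and yields warp_num = 3, tree_reduce_num = 96 and spatial_inner_num = 1, the same values the patch computes for that bucket.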
+ +#pragma once +#include +#include "paddle/cinn/adt/adt.h" +#include "paddle/cinn/common/target.h" +#include "paddle/cinn/ir/schedule/schedule_base.h" + +namespace cinn { + +namespace hlir::framework::pir { +struct GroupInfo; +} // namespace hlir::framework::pir + +namespace ir { + +struct ScheduleConfig { + struct BaseInfo { + std::vector reduce_axis; + int64_t data_rank; + int64_t reduce_numel; + int64_t spatial_numel; + bool has_dynamic_spatial{false}; + bool has_dynamic_reduce{false}; + bool is_reduce_all{false}; + + std::set reduce_tensor_names; + std::set temp_var_names; + std::set shared_var_names; + std::set direct_output_var_names; + + std::unordered_map broadcast_info; + std::unordered_map broadcast_to_elementwise; + }; + + struct TileConfig { + int64_t warp_num{1}; + int64_t tree_reduce_num{1}; + int64_t spatial_inner_num{1}; + ReduceMethod reduce_method{NoneReduceMethod()}; + }; + + std::shared_ptr base_info; + TileConfig tile_config; +}; + +struct BucketInfo { + int64_t sp_lower_bound = 1; + int64_t sp_upper_bound = INT64_MAX; + int64_t rb_lower_bound = 1; + int64_t rb_upper_bound = INT64_MAX; + + bool operator==(const BucketInfo& other) const { + return this->sp_lower_bound == other.sp_lower_bound && + this->sp_upper_bound == other.sp_upper_bound && + this->rb_lower_bound == other.rb_lower_bound && + this->rb_upper_bound == other.rb_upper_bound; + } +}; + +struct BucketInfoHash { + std::size_t operator()(const BucketInfo& bucket_info) const noexcept { + std::size_t hash_spl = std::hash{}(bucket_info.sp_lower_bound); + std::size_t hash_spu = std::hash{}(bucket_info.sp_upper_bound); + std::size_t hash_rbl = std::hash{}(bucket_info.rb_lower_bound); + std::size_t hash_rbu = std::hash{}(bucket_info.rb_upper_bound); + return adt::hash_combine(adt::hash_combine(hash_spl, hash_spu), + adt::hash_combine(hash_rbl, hash_rbu)); + } +}; + +std::unordered_map +BuildScheduleConfig( + const std::shared_ptr& group_info, + const common::Target& target); + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc index d5a64b6d8f7f1..e604055cf3b93 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.cc @@ -18,11 +18,15 @@ #include "paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h" +#include "paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h" +#include "paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h" #include "paddle/cinn/ir/group_schedule/tactic/tile_tactic.h" #include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h" #include "paddle/cinn/ir/op/ir_operators.h" +PD_DECLARE_bool(cinn_bucket_compile); + namespace cinn { namespace ir { @@ -32,12 +36,10 @@ void DynamicShapeGroupScheduler::Init() { VLOG(4) << "original group func body: \n" << ir_sch_->GetModule().GetExprs()[0]; InitBuckets(); - tactics_.emplace_back(new AlignIterSpaceTactic()); - tactics_.emplace_back(new ComputeInlineTactic()); - tactics_.emplace_back(new TileTactic()); - tactics_.emplace_back(new OptimizeReductionTactic()); - tactics_.emplace_back(new BindCudaTactic()); - tactics_.emplace_back(new ArrangeStorageTactic()); + tactics_.emplace_back(CreateLoopReorderAlignmentTactic()); + 
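+  // The previous fixed pipeline of six tactics (align iter space, compute
+  // inline, tile, optimize reduction, bind CUDA, arrange storage) is replaced
+  // by two coarser passes driven by ScheduleConfig: loop reorder alignment
+  // (added above) and the tile-first general tactic (added below).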
VLOG(4) << "CreateLoopReorderAlignmentTactic End"; + tactics_.emplace_back(CreateTileFirstGeneralTactic()); + VLOG(4) << "CreateTileFirstGeneralTactic End"; } void DynamicShapeGroupScheduler::InitBuckets() { @@ -47,13 +49,16 @@ void DynamicShapeGroupScheduler::InitBuckets() { [](ir::Expr extent, int lower_bound, int upper_bound) -> bool { if (!extent.is_constant()) return false; int extent_value = static_cast(extent.get_constant()); - if (extent_value < lower_bound || extent_value >= upper_bound) { + VLOG(5) << "extent_value: " << extent_value + << ",lower_bound: " << lower_bound + << ",upper_bound: " << upper_bound; + if (extent_value < lower_bound || extent_value > upper_bound) { return true; } return false; }; - auto InitBucket = [&](BucketInfo&& bucket_info) { + auto InitBucket = [&](BucketInfo&& bucket_info, ScheduleConfig&& config) { std::unique_ptr ir_sch = std::make_unique(*ir_sch_); std::unique_ptr schedule_block_graph = @@ -61,21 +66,30 @@ void DynamicShapeGroupScheduler::InitBuckets() { ir::ScheduleBlockNode* global_master = FindGlobalMasterNode(schedule_block_graph); IterativeSpaceInfo iter_space_info = ConstructIterSpaceInfo(global_master); + VLOG(4) << "iter_space_info.total_sp_extent: " + << iter_space_info.total_sp_extent; + VLOG(4) << "iter_space_info.total_rb_extent: " + << iter_space_info.total_rb_extent; + VLOG(4) << "bucket_info.sp_lower_bound: " << bucket_info.sp_lower_bound; + VLOG(4) << "bucket_info.sp_upper_bound: " << bucket_info.sp_upper_bound; + VLOG(4) << "bucket_info.rb_lower_bound: " << bucket_info.rb_lower_bound; + VLOG(4) << "bucket_info.rb_upper_bound: " << bucket_info.rb_upper_bound; if (OutOfRange(iter_space_info.total_sp_extent, bucket_info.sp_lower_bound, bucket_info.sp_upper_bound) || OutOfRange(iter_space_info.total_rb_extent, bucket_info.rb_lower_bound, bucket_info.rb_upper_bound)) { + VLOG(4) << "Out of range"; return; } SymbolicPredicate sp_lower_bound_predicate = ir::GE::Make( iter_space_info.total_sp_extent, ir::Expr(bucket_info.sp_lower_bound)); - SymbolicPredicate sp_upper_bound_predicate = ir::LT::Make( + SymbolicPredicate sp_upper_bound_predicate = ir::LE::Make( iter_space_info.total_sp_extent, ir::Expr(bucket_info.sp_upper_bound)); SymbolicPredicate rb_lower_bound_predicate = ir::GE::Make( iter_space_info.total_rb_extent, ir::Expr(bucket_info.rb_lower_bound)); - SymbolicPredicate rb_upper_bound_predicate = ir::LT::Make( + SymbolicPredicate rb_upper_bound_predicate = ir::LE::Make( iter_space_info.total_rb_extent, ir::Expr(bucket_info.rb_upper_bound)); SymbolicPredicate sp_predicate = ir::And::Make(sp_lower_bound_predicate, sp_upper_bound_predicate); @@ -85,7 +99,8 @@ void DynamicShapeGroupScheduler::InitBuckets() { ScheduleContext schedule_context{output_names, target_, std::move(iter_space_info), - std::move(bucket_info)}; + std::move(bucket_info), + std::move(config)}; BucketContext bucket_context{std::move(predicate), std::move(ir_sch), std::move(schedule_block_graph), @@ -93,30 +108,15 @@ void DynamicShapeGroupScheduler::InitBuckets() { bucket_contexts_.emplace_back(std::move(bucket_context)); }; - // naive buckets - // 1. {sp_extent[1 - 1024], rb_extent[1 - 256]} - InitBucket({/* sp_lower_bound = */ 1, - /* sp_upper_bound = */ 1024, - /* rb_lower_bound = */ 1, - /* rb_upper_bound = */ 256}); - // 2. {sp_extent[1024 - +oo], rb_extent[1 - 256]} - InitBucket({/* sp_lower_bound = */ 1024, - /* sp_upper_bound = */ INT_MAX, - /* rb_lower_bound = */ 1, - /* rb_upper_bound = */ 256}); - // 3. 
{sp_extent[1 - 1024], rb_extent[256 - +oo]} - InitBucket({/* sp_lower_bound = */ 1, - /* sp_upper_bound = */ 1024, - /* rb_lower_bound = */ 256, - /* rb_upper_bound = */ INT_MAX}); - // 4. {sp_extent[1024 - +oo], rb_extent[256 - +oo]} - InitBucket({/* sp_lower_bound = */ 1024, - /* sp_upper_bound = */ INT_MAX, - /* rb_lower_bound = */ 256, - /* rb_upper_bound = */ INT_MAX}); + std::unordered_map configs = + BuildScheduleConfig(group_info_, target_); + for (std::pair&& config : configs) { + InitBucket(std::move(config.first), std::move(config.second)); + } } void DynamicShapeGroupScheduler::Schedule() { + VLOG(4) << "bucket_context_.size() = " << bucket_contexts_.size(); for (BucketContext& bucket_context : bucket_contexts_) { VLOG(4) << "===========================Apply tactics on Bucket [" << bucket_context.predicate << "]=========================="; diff --git a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h index e226059011b63..0e5205a419973 100644 --- a/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h +++ b/paddle/cinn/ir/group_schedule/dy_shape_group_scheduler.h @@ -28,8 +28,9 @@ class DynamicShapeGroupScheduler : public GroupScheduler { DynamicShapeGroupScheduler( ir::IRSchedule* ir_sch, const std::unordered_set& output_tensor_names, - const cinn::common::Target& target) - : GroupScheduler(ir_sch, output_tensor_names, target) { + const cinn::common::Target& target, + const std::shared_ptr& group_info) + : GroupScheduler(ir_sch, output_tensor_names, target, group_info) { Init(); } diff --git a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc index 7c999205f646f..1dc21ce8a3180 100644 --- a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc +++ b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc @@ -24,34 +24,11 @@ #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h" #include "paddle/cinn/optim/replace_var_with_expr.h" +#include "paddle/cinn/utils/external_func_names.h" namespace cinn { namespace ir { -static const std::unordered_set - kProhibitScheduleExternalFuncNames = { -#define CINN_NVGPU_FUNC2STRING(str) #str -#define CINN_NVGPU_FUNC_TYPE(FUNC, TYPE) \ - CINN_NVGPU_FUNC2STRING(cinn_nvgpu_##FUNC##TYPE) - -#define GEN_FUNC_NAME(_, impl) \ - _(impl, gt_num) \ - _(impl, lt_num) \ - _(impl, index_add) \ - _(impl, next_smallest) - -#define GEN_FUNC_NAME_WITH_TYPE(_, ...) 
\ - _(__VA_ARGS__, _bool), _(__VA_ARGS__, _fp16), _(__VA_ARGS__, _fp32), \ - _(__VA_ARGS__, _fp64), _(__VA_ARGS__, _uint8), _(__VA_ARGS__, _int8), \ - _(__VA_ARGS__, _int16), _(__VA_ARGS__, _int32), _(__VA_ARGS__, _int64), - - GEN_FUNC_NAME(GEN_FUNC_NAME_WITH_TYPE, CINN_NVGPU_FUNC_TYPE) -#undef GEN_FUNC_NAME -#undef GEN_FUNC_NAME_WITH_TYPE -#undef CINN_NVGPU_FUNC_TYPE -#undef CINN_NVGPU_FUNC2STRING -}; - static bool IsProhibitScheduleExternCallBlock(ir::Expr block) { ir::ScheduleBlockRealize* sch_block_realize = block.As(); @@ -64,7 +41,8 @@ static bool IsProhibitScheduleExternCallBlock(ir::Expr block) { sch_block->body, [&](const Expr* x) { return x->As(); }); for (ir::Expr call : find_call) { ir::Call* call_node = call.As(); - if (kProhibitScheduleExternalFuncNames.count(call_node->name) != 0) { + if (cinn::utils::GetProhibitScheduleExternalFuncNames().count( + call_node->name) != 0) { return true; } } @@ -1039,8 +1017,9 @@ void StaticShapeGroupScheduler::AllocateStorage() { consumer_block_name)) { // TODO(BiynXu): Return error information to the front-end instead of // terminating the program. - LOG(FATAL) << "Fusion requires synchronization across blocks, but " - "currently we do not support it."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Fusion requires synchronization across blocks, but " + "currently we do not support it.")); break; } else if (IsCrossThread(store_indice_value, load_indice_value, diff --git a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h index 337817995eb0f..4a2724fe11c67 100644 --- a/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h +++ b/paddle/cinn/ir/group_schedule/st_shape_group_scheduler.h @@ -46,8 +46,9 @@ class StaticShapeGroupScheduler : public GroupScheduler { StaticShapeGroupScheduler( ir::IRSchedule* ir_sch, const std::unordered_set& output_tensor_names, - const cinn::common::Target& target) - : GroupScheduler(ir_sch, output_tensor_names, target) {} + const cinn::common::Target& target, + const std::shared_ptr& group_info) + : GroupScheduler(ir_sch, output_tensor_names, target, group_info) {} void Schedule() override; diff --git a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt index e8205f7244bb1..b6a2f06760646 100644 --- a/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt +++ b/paddle/cinn/ir/group_schedule/tactic/CMakeLists.txt @@ -6,3 +6,5 @@ gather_srcs(cinnapi_src SRCS compute_inline_tactic.cc) gather_srcs(cinnapi_src SRCS optimize_reduction_tactic.cc) gather_srcs(cinnapi_src SRCS bind_cuda_tactic.cc) gather_srcs(cinnapi_src SRCS arrange_storage_tactic.cc) +gather_srcs(cinnapi_src SRCS loop_reorder_alignment_tactic.cc) +gather_srcs(cinnapi_src SRCS tile_first_general_tactic.cc) diff --git a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc index 14fde3b148a52..dcc72e4a217d8 100644 --- a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.cc @@ -23,6 +23,18 @@ namespace cinn { namespace ir { +class AlignIterSpaceTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "AlignIterSpaceTactic"; } + + private: + ScheduleContext* context_; +}; + void AlignIterSpaceTactic::Init(ScheduleContext* 
context) { context_ = context; } @@ -84,5 +96,9 @@ void AlignIterSpaceTactic::Apply(ir::IRSchedule* sch, } } +std::unique_ptr CreateAlignIterSpaceTactic() { + return std::make_unique(); +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h index ef30f80ce470b..2ac65d114c7f5 100644 --- a/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/align_iter_space_tactic.h @@ -20,17 +20,7 @@ namespace cinn { namespace ir { -class AlignIterSpaceTactic final : public ScheduleTactic { - public: - void Init(ScheduleContext* context) override; - - void Apply(ir::IRSchedule* sch, const std::string& block_id) override; - - std::string TacticName() const override { return "AlignIterSpaceTactic"; } - - private: - ScheduleContext* context_; -}; +std::unique_ptr CreateAlignIterSpaceTactic(); } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc index 5c5398533513d..661ab9e624d94 100644 --- a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.cc @@ -24,6 +24,18 @@ namespace cinn { namespace ir { +class ArrangeStorageTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "ArrangeStorageTactic"; } + + private: + std::unordered_set output_names_; +}; + // [block_name, [var, for_node]] using VarToForMap = std::unordered_map>; @@ -385,11 +397,12 @@ void ArrangeStorageTactic::Apply(ir::IRSchedule* sch, } else if (cross_type.value() == CudaAxisType::kCudaThread) { memory_type = ir::MemoryType::GPUShared; } else if (cross_type.value() == CudaAxisType::kCudaBlock) { - LOG(FATAL) << "Fusion requires synchronization across blocks, but " - "currently we do not support it."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Fusion requires synchronization across blocks, but " + "currently we do not support it.")); break; } else { - LOG(FATAL) << "dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } } @@ -420,5 +433,9 @@ void ArrangeStorageTactic::Apply(ir::IRSchedule* sch, } } +std::unique_ptr CreateArrangeStorageTactic() { + return std::make_unique(); +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h index 994108d1662b9..25fe8047efcd0 100644 --- a/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/arrange_storage_tactic.h @@ -21,17 +21,7 @@ namespace cinn { namespace ir { -class ArrangeStorageTactic final : public ScheduleTactic { - public: - void Init(ScheduleContext* context) override; - - void Apply(ir::IRSchedule* sch, const std::string& block_id) override; - - std::string TacticName() const override { return "ArrangeStorageTactic"; } - - private: - std::unordered_set output_names_; -}; +std::unique_ptr CreateArrangeStorageTactic(); } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc index 0fe53e779aeae..50556da0db033 100644 --- 
a/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.cc @@ -19,6 +19,18 @@ namespace cinn { namespace ir { +class BindCudaTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "BindCudaTactic"; } + + private: + ScheduleContext* context_; +}; + void BindCudaTactic::Init(ScheduleContext* context) { context_ = context; } const std::unordered_map @@ -56,5 +68,9 @@ void BindCudaTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) { } } +std::unique_ptr CreateBindCudaTactic() { + return std::make_unique(); +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h index b66c7d1fb802c..ae2ed3985bef1 100644 --- a/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/bind_cuda_tactic.h @@ -20,17 +20,7 @@ namespace cinn { namespace ir { -class BindCudaTactic final : public ScheduleTactic { - public: - void Init(ScheduleContext* context) override; - - void Apply(ir::IRSchedule* sch, const std::string& block_id) override; - - std::string TacticName() const override { return "BindCudaTactic"; } - - private: - ScheduleContext* context_; -}; +std::unique_ptr CreateBindCudaTactic(); } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc index 8da8f44d32695..5076d1ded1e69 100644 --- a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.cc @@ -25,6 +25,19 @@ namespace cinn { namespace ir { +class ComputeInlineTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "ComputeInlineTactic"; } + + private: + std::unordered_set output_names_; + cinn::common::Target target_; +}; + void ComputeInlineTactic::Init(ScheduleContext* context) { output_names_ = context->output_names; target_ = context->target; @@ -48,5 +61,9 @@ void ComputeInlineTactic::Apply(ir::IRSchedule* sch, << sch->GetModule().GetExprs().front(); } +std::unique_ptr CreateComputeInlineTactic() { + return std::make_unique(); +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h index b03e28d579bc8..821126bfc7ecc 100644 --- a/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/compute_inline_tactic.h @@ -22,18 +22,7 @@ namespace cinn { namespace ir { -class ComputeInlineTactic final : public ScheduleTactic { - public: - void Init(ScheduleContext* context) override; - - void Apply(ir::IRSchedule* sch, const std::string& block_id) override; - - std::string TacticName() const override { return "ComputeInlineTactic"; } - - private: - std::unordered_set output_names_; - cinn::common::Target target_; -}; +std::unique_ptr CreateComputeInlineTactic(); } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc 
b/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc new file mode 100644 index 0000000000000..416537c41e5c6 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.cc @@ -0,0 +1,188 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h" +#include +#include +#include "paddle/cinn/ir/ir.h" + +namespace cinn { +namespace ir { + +class LoopReorderAlignmentTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { + return "LoopReorderAlignmentTactic"; + } + + private: + bool NeedReorderLoops(); + + std::vector GetNewOrder(); + + void UpdateBaseRank(ir::IRSchedule* sch, const std::string& block_id); + + void DoBroadcastLoop(ir::IRSchedule* sch, const std::string& block_id); + + void DoReorder(ir::IRSchedule* sch, const std::string& block_id); + + private: + ScheduleContext* context_; + size_t base_rank_; + bool need_reorder_loops_; + std::vector new_order_; +}; + +void LoopReorderAlignmentTactic::Init(ScheduleContext* context) { + context_ = context; + base_rank_ = 0; + need_reorder_loops_ = NeedReorderLoops(); + new_order_ = GetNewOrder(); +} + +void LoopReorderAlignmentTactic::Apply(ir::IRSchedule* sch, + const std::string& block_id) { + DoBroadcastLoop(sch, block_id); + + if (!ir::IsReduceInitTensorName(block_id)) { + UpdateBaseRank(sch, block_id); + } + + if (need_reorder_loops_ && !ir::IsReduceInitTensorName(block_id)) { + DoReorder(sch, block_id); + } +} + +void LoopReorderAlignmentTactic::UpdateBaseRank(ir::IRSchedule* sch, + const std::string& block_id) { + auto loops = sch->GetLoops(block_id); + if (base_rank_ == 0) { + base_rank_ = loops.size(); + } else { + if (base_rank_ != loops.size()) { + throw std::runtime_error("loops rank not same "); + } + } +} + +bool LoopReorderAlignmentTactic::NeedReorderLoops() { + const auto HasReduceAxis = [&]() { + return context_->config.base_info->reduce_axis.size() > 0; + }; + if (!HasReduceAxis()) { + return false; + } + + const auto HasNonLastDimReduce = [&]() { + std::vector vec_reduce_axis = + context_->config.base_info->reduce_axis; + std::sort(vec_reduce_axis.begin(), vec_reduce_axis.end()); + return vec_reduce_axis.front() != + context_->config.base_info->data_rank - vec_reduce_axis.size(); + }; + + return HasNonLastDimReduce(); +} + +std::vector LoopReorderAlignmentTactic::GetNewOrder() { + std::set reduce_set(context_->config.base_info->reduce_axis.begin(), + context_->config.base_info->reduce_axis.end()); + + std::vector new_order; + for (int32_t i = 0; i < context_->config.base_info->data_rank; ++i) { + if (!reduce_set.count(i)) { + new_order.push_back(i); + } + } + for (auto axis : context_->config.base_info->reduce_axis) { + new_order.push_back(axis); + } + + return new_order; +} + +void 
LoopReorderAlignmentTactic::DoBroadcastLoop(ir::IRSchedule* sch, + const std::string& block_id) { + const auto HasBroadcastInfo = [&](const std::string& block_id) { + return context_->config.base_info->broadcast_info.count(block_id) > 0; + }; + const auto HasBroadcastToElementwiseInfo = [&](const std::string& block_id) { + return context_->config.base_info->broadcast_to_elementwise.count( + block_id) > 0; + }; + const auto IsFullBroadcast = [&](const std::string& block_id) { + return context_->config.base_info->broadcast_info[block_id].full_broadcast; + }; + const auto IsSplitFirst = [&](const std::string& block_id) { + return context_->config.base_info->broadcast_info[block_id].split_first; + }; + + if (HasBroadcastInfo(block_id)) { + if (IsFullBroadcast(block_id)) { + std::vector vec_out_split( + context_->config.base_info->broadcast_info[block_id] + .output_shape.size(), + 1); + + auto loops = sch->GetLoops(block_id); + sch->Split(loops[0], vec_out_split); + loops = sch->GetLoops(block_id); + } else if (IsSplitFirst(block_id)) { + for (auto& info : + context_->config.base_info->broadcast_info[block_id].split_info) { + auto axis = info.first; + auto split_res = info.second; + + auto loops = sch->GetLoops(block_id); + sch->Split(loops[axis], split_res); + loops = sch->GetLoops(block_id); + } + } else { + // Do nothing + } + + sch->Broadcast(block_id, + context_->config.base_info->broadcast_info[block_id]); + } + + if (HasBroadcastToElementwiseInfo(block_id)) { + sch->BroadcastToElementwise( + block_id, + context_->config.base_info->broadcast_to_elementwise[block_id] + .broadcast_axes); + } +} + +void LoopReorderAlignmentTactic::DoReorder(ir::IRSchedule* sch, + const std::string& block_id) { + const auto IsReduceBlock = [&](const std::string& block_id) { + return context_->config.base_info->reduce_tensor_names.count(block_id) > 0; + }; + if (IsReduceBlock(block_id)) { + return; + } + + sch->Reorder(block_id, new_order_); +} + +std::unique_ptr CreateLoopReorderAlignmentTactic() { + return std::make_unique(); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h b/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h new file mode 100644 index 0000000000000..ee4864a5ecf92 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/loop_reorder_alignment_tactic.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
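LoopReorderAlignmentTactic only permutes loops when some reduce axis is not already trailing, and the permutation it builds keeps the spatial axes in their original order and appends the reduce axes. A minimal standalone sketch of that bookkeeping, assuming the same semantics as NeedReorderLoops and GetNewOrder (the free functions below are illustrative, not part of the patch):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

bool NeedsReorder(int64_t data_rank, std::vector<int64_t> reduce_axis) {
  if (reduce_axis.empty()) return false;
  std::sort(reduce_axis.begin(), reduce_axis.end());
  // Already aligned when the reduce axes occupy the trailing positions.
  return reduce_axis.front() !=
         data_rank - static_cast<int64_t>(reduce_axis.size());
}

std::vector<int32_t> NewOrder(int64_t data_rank,
                              const std::vector<int64_t>& reduce_axis) {
  std::vector<int32_t> order;
  for (int32_t i = 0; i < data_rank; ++i) {
    if (std::find(reduce_axis.begin(), reduce_axis.end(), i) ==
        reduce_axis.end()) {
      order.push_back(i);  // spatial axes keep their relative order
    }
  }
  for (int64_t axis : reduce_axis) {
    order.push_back(static_cast<int32_t>(axis));  // reduce axes go last
  }
  return order;
}

int main() {
  // Rank-4 tensor reduced over axes {1, 2}: loops become {0, 3, 1, 2}.
  std::vector<int64_t> reduce_axis{1, 2};
  std::printf("needs reorder: %d\n", NeedsReorder(4, reduce_axis));
  for (int32_t i : NewOrder(4, reduce_axis)) std::printf("%d ", i);
  std::printf("\n");
  return 0;
}

For a rank-4 tensor reduced over axes {1, 2} this reports that a reorder is needed and produces {0, 3, 1, 2}, which is the reduce-axes-last layout the tile-first tactic assumes later in this patch.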
+ +#pragma once + +#include +#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h" + +namespace cinn { +namespace ir { + +std::unique_ptr CreateLoopReorderAlignmentTactic(); + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc index c9f435704be9f..445ac32c94ab1 100644 --- a/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.cc @@ -19,6 +19,18 @@ namespace cinn { namespace ir { +class OptimizeReductionTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "OptimizeReductionTactic"; } + + private: + ScheduleContext* context_; +}; + void OptimizeReductionTactic::Init(ScheduleContext* context) { context_ = context; } @@ -151,5 +163,9 @@ void OptimizeReductionTactic::Apply(ir::IRSchedule* sch, << sch->GetModule().GetExprs()[0]; } +std::unique_ptr CreateOptimizeReductionTactic() { + return std::make_unique(); +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h index 108f674ee2253..aa2405530f917 100644 --- a/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/optimize_reduction_tactic.h @@ -20,17 +20,7 @@ namespace cinn { namespace ir { -class OptimizeReductionTactic final : public ScheduleTactic { - public: - void Init(ScheduleContext* context) override; - - void Apply(ir::IRSchedule* sch, const std::string& block_id) override; - - std::string TacticName() const override { return "OptimizeReductionTactic"; } - - private: - ScheduleContext* context_; -}; +std::unique_ptr CreateOptimizeReductionTactic(); } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h index 68f4ae31c7a7c..b76d1684bc399 100644 --- a/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h @@ -16,6 +16,8 @@ #include #include "paddle/cinn/common/integer_set.h" +#include "paddle/cinn/ir/group_schedule/config/group_tile_config.h" +#include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/schedule/ir_schedule.h" #include "paddle/cinn/ir/schedule_block_graph.h" @@ -64,18 +66,13 @@ struct IterativeSpaceInfo { } }; -struct BucketInfo { - int sp_lower_bound = 0; - int sp_upper_bound = UINT_MAX; - int rb_lower_bound = 0; - int rb_upper_bound = UINT_MAX; -}; - struct ScheduleContext { + // TODO(BiynXu): Unify fields with similar meanings std::unordered_set output_names; Target target; IterativeSpaceInfo iter_space_info; BucketInfo bucket_info; + ScheduleConfig config; }; class ScheduleTactic { diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc new file mode 100644 index 0000000000000..8a3c2dfa71356 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.cc @@ -0,0 +1,355 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h" +#include "paddle/cinn/adt/adt.h" +#include "paddle/cinn/common/integer_set.h" +#include "paddle/cinn/common/target.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" + +PD_DECLARE_bool(support_reduce_stride_read); + +namespace cinn { +namespace ir { + +bool IsInnerThreadSpatialLoopGT(const ScheduleConfig& config, int num) { + return config.tile_config.spatial_inner_num > num; +} + +bool IsReduceBlock(const ScheduleConfig& config, const std::string& block_id) { + return config.base_info->reduce_tensor_names.count(block_id) > 0; +} + +bool HasReduceAxis(const ScheduleConfig& config) { + return config.base_info->reduce_axis.size() > 0; +} + +bool IsWarpReduce(const ScheduleConfig& config) { + const auto& MatchWarpReduce = cinn::adt::match{ + [&](const ir::NoneReduceMethod&) { return false; }, + [&](const ir::WarpReduceMethod&) { return true; }, + [&](const ir::BlockReduceMethod&) { return false; }, + }; + return std::visit(MatchWarpReduce, config.tile_config.reduce_method); +} + +class TileFirstGeneralTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "TileFirstGeneralTactic"; } + + private: + void MergeFlattenAxis(ir::IRSchedule* sch, const std::string& block_id); + void MergeReduceAxis(ir::IRSchedule* sch, const std::string& block_id); + void SplitSptialInner(ir::IRSchedule* sch, const std::string& block_id); + void SplitReduceInner(ir::IRSchedule* sch, const std::string& block_id); + void ReorderFlattenInnerWithReduceAxis(ir::IRSchedule* sch, + const std::string& block_id); + void SplitWarpNumber(ir::IRSchedule* sch, const std::string& block_id); + void Unroll(ir::IRSchedule* sch, const std::string& block_id); + void VariableTypeAssignment(ir::IRSchedule* sch, const std::string& block_id); + void SetReduceType(ir::IRSchedule* sch, const std::string& block_id); + void BindCudaInfo(ir::IRSchedule* sch, const std::string& block_id); + + private: + ScheduleContext* context_; + std::vector vec_flatten_axis_; + std::vector vec_reduce_axis_; + int reduce_current_axis_{0}; +}; + +void TileFirstGeneralTactic::Init(ScheduleContext* context) { + context_ = context; + reduce_current_axis_ = + IsInnerThreadSpatialLoopGT(context_->config, 1) ? 
2 : 1; + if (context_->config.base_info->is_reduce_all) { + reduce_current_axis_ = 1; + } + // reduce axis have be re-order to last + vec_flatten_axis_.clear(); + vec_reduce_axis_.clear(); + int32_t reduce_start_idx = context_->config.base_info->data_rank - + context_->config.base_info->reduce_axis.size(); + for (int32_t i = 0; i < context_->config.base_info->data_rank; ++i) { + if (i >= reduce_start_idx) { + vec_reduce_axis_.push_back(i); + } else { + vec_flatten_axis_.push_back(i); + } + } +} + +void TileFirstGeneralTactic::Apply(ir::IRSchedule* sch, + const std::string& block_id) { + if (ir::IsReduceInitTensorName(block_id)) return; + MergeReduceAxis(sch, block_id); + VLOG(6) << "After MergeReduceAxis on block: [" << block_id + << "], loop nest:\n" + << sch->GetLoops(block_id)[0]; + MergeFlattenAxis(sch, block_id); + VLOG(6) << "After MergeFlattenAxis on block: [" << block_id + << "], loop nest:\n" + << sch->GetLoops(block_id)[0]; + SplitSptialInner(sch, block_id); + VLOG(6) << "After SplitSptialInner on block: [" << block_id + << "], loop nest:\n" + << sch->GetLoops(block_id)[0]; + SplitReduceInner(sch, block_id); + VLOG(6) << "After SplitReduceInner on block: [" << block_id + << "], loop nest:\n" + << sch->GetLoops(block_id)[0]; + ReorderFlattenInnerWithReduceAxis(sch, block_id); + VLOG(6) << "After ReorderFlattenInnerWithReduceAxis on block: [" << block_id + << "], loop nest:\n" + << sch->GetLoops(block_id)[0]; + SplitWarpNumber(sch, block_id); + VLOG(6) << "After SplitWarpNumber on block: [" << block_id + << "], loop nest:\n" + << sch->GetLoops(block_id)[0]; + BindCudaInfo(sch, block_id); + VLOG(6) << "After BindCudaInfo on block: [" << block_id << "], loop nest:\n" + << sch->GetLoops(block_id)[0]; + VariableTypeAssignment(sch, block_id); + Unroll(sch, block_id); + VLOG(6) << "After Unroll on block: [" << block_id << "], loop nest:\n" + << sch->GetLoops(block_id)[0]; + SetReduceType(sch, block_id); +} + +void TileFirstGeneralTactic::MergeFlattenAxis(ir::IRSchedule* sch, + const std::string& block_id) { + if (vec_flatten_axis_.size() >= 2) { + sch->Fuse(block_id, vec_flatten_axis_); + } +} + +void TileFirstGeneralTactic::MergeReduceAxis(ir::IRSchedule* sch, + const std::string& block_id) { + if (vec_reduce_axis_.size() >= 2 && !ir::IsReduceInitTensorName(block_id)) { + sch->Fuse(block_id, vec_reduce_axis_); + } +} + +void TileFirstGeneralTactic::SplitSptialInner(ir::IRSchedule* sch, + const std::string& block_id) { + if (IsInnerThreadSpatialLoopGT(context_->config, 1)) { + auto loops = sch->GetLoops(block_id); + auto split_loops = + sch->Split(loops[0], + std::vector( + {-1, + static_cast( + context_->config.tile_config.spatial_inner_num)})); + } +} + +void TileFirstGeneralTactic::SplitReduceInner(ir::IRSchedule* sch, + const std::string& block_id) { + if (!HasReduceAxis(context_->config)) return; + + auto loops = sch->GetLoops(block_id); + auto reduce_loop = loops[reduce_current_axis_].As(); + + if (FLAGS_support_reduce_stride_read) { + if (context_->config.base_info->reduce_numel <= 256) { + std::vector split_factors{ + -1, static_cast(context_->config.tile_config.tree_reduce_num)}; + sch->Split(loops[reduce_current_axis_], split_factors); + loops = sch->GetLoops(block_id); + sch->Reorder( + {loops[reduce_current_axis_ + 1], loops[reduce_current_axis_]}); + } else { + // split warp num first + std::vector split_factors{ + static_cast(context_->config.tile_config.warp_num), -1, 32}; + sch->Split(loops[reduce_current_axis_], split_factors); + loops = sch->GetLoops(block_id); + 
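+        // The {warp_num, -1, 32} split plus the swap below keeps the 32-lane
+        // loop next to the warp loop (the two are fused right after), leaving
+        // the serial reduce loop innermost: each thread strides through its
+        // chunk while adjacent lanes read adjacent elements.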
sch->Reorder( + {loops[reduce_current_axis_ + 2], loops[reduce_current_axis_ + 1]}); + loops = sch->GetLoops(block_id); + sch->Fuse({loops[reduce_current_axis_], loops[reduce_current_axis_ + 1]}); + } + } else { + std::vector split_factors{ + static_cast(context_->config.tile_config.tree_reduce_num), -1}; + sch->Split(loops[reduce_current_axis_], split_factors); + } + loops = sch->GetLoops(block_id); + if (IsReduceBlock(context_->config, block_id)) { + sch->FactorizeReduction(loops[reduce_current_axis_], + 0, + /* with_write_back_block_init = */ false); + } +} + +void TileFirstGeneralTactic::ReorderFlattenInnerWithReduceAxis( + ir::IRSchedule* sch, const std::string& block_id) { + // re-order flatten inner num with last dim + auto loops = sch->GetLoops(block_id); + if (IsInnerThreadSpatialLoopGT(context_->config, 1) && + HasReduceAxis(context_->config)) { + sch->Reorder({loops[2], loops[1]}); + if (IsReduceBlock(context_->config, block_id) && + sch->HasBlock(block_id + "_rf")) { + loops = sch->GetLoops(block_id + "_rf"); + sch->Reorder({loops[2], loops[1]}); + } + } +} + +void TileFirstGeneralTactic::SplitWarpNumber(ir::IRSchedule* sch, + const std::string& block_id) { + const auto IsWarpNumGT = [&](int64_t num) { + return context_->config.tile_config.warp_num > num; + }; + if (!IsWarpNumGT(1)) return; + + const auto LimitWarpNum = [&](const ir::Expr& loop, ScheduleConfig* config) { + ir::Expr extent = loop.As()->extent; + common::cas_intervals_t var_intervals = + common::CollectVarIntervalsOfExprs({extent}); + common::SymbolicExprAnalyzer analyzer(var_intervals); + const auto& proved_gt = + analyzer.ProveGT(ir::Expr(config->tile_config.warp_num), extent); + if (proved_gt.value_or(false)) { + ir::Expr upper_bound = analyzer.UpperBound(extent); + if (upper_bound.is_constant()) { + config->tile_config.warp_num = upper_bound.get_constant(); + } + } + }; + + auto loops = sch->GetLoops(block_id); + if (!HasReduceAxis(context_->config)) { + if (context_->config.tile_config.warp_num == + -1) { // only in bucket spatial_numel <= 1024 + sch->Split(loops[0], std::vector({1, -1})); + } else { + sch->Split( + loops[0], + std::vector( + {-1, + static_cast(context_->config.tile_config.warp_num * 32)})); + } + } else if (IsWarpReduce(context_->config)) { + // get num warp from flatten num + LimitWarpNum(loops[0], &(context_->config)); + int thread_y = context_->config.tile_config.warp_num * 32 / + context_->config.tile_config.tree_reduce_num; + sch->Split(loops[0], std::vector({-1, thread_y})); + + if (IsReduceBlock(context_->config, block_id) && + sch->HasBlock(block_id + "_rf")) { + auto loops = sch->GetLoops(block_id + "_rf"); + sch->Split(loops[0], std::vector({-1, thread_y})); + } + } else { + return; + } +} + +void TileFirstGeneralTactic::Unroll(ir::IRSchedule* sch, + const std::string& block_id) { + std::vector unroll_loops_idx = [&] { + if (IsWarpReduce(context_->config)) { + return std::vector{3, 4}; + } else { + return std::vector{2, 3}; + } + }(); + + const auto DoUnroll = [&](const std::vector& loops) { + for (size_t loop_idx : unroll_loops_idx) { + if (loops.size() > loop_idx && + loops[loop_idx].As()->extent.is_constant()) { + sch->Unroll(loops[loop_idx]); + } + } + }; + + DoUnroll(sch->GetLoops(block_id)); + if (IsReduceBlock(context_->config, block_id) && + sch->HasBlock(block_id + "_rf")) { + DoUnroll(sch->GetLoops(block_id + "_rf")); + } +} + +void TileFirstGeneralTactic::VariableTypeAssignment( + ir::IRSchedule* sch, const std::string& block_id) { + const auto IsOutputTensor = 
[&](const std::string& tensor_name) { + return context_->config.base_info->direct_output_var_names.count( + tensor_name) > 0; + }; + + auto block = sch->GetBlock(block_id); + if (!IsOutputTensor(block_id)) { + sch->SetBuffer(block, "local", false); + } + + if (IsReduceBlock(context_->config, block_id) && + sch->HasBlock(block_id + "_rf")) { + auto block = sch->GetBlock(block_id + "_rf"); + sch->SetBuffer(block, "local", false); + } +} + +void TileFirstGeneralTactic::SetReduceType(ir::IRSchedule* sch, + const std::string& block_id) { + if (IsReduceBlock(context_->config, block_id)) { + auto block = sch->GetBlock(block_id) + .As() + ->schedule_block.As(); + block->reduce_method = context_->config.tile_config.reduce_method; + } +} + +void TileFirstGeneralTactic::BindCudaInfo(ir::IRSchedule* sch, + const std::string& block_id) { + auto loops = sch->GetLoops(block_id); + if (loops.size() == 1 || context_->config.base_info->is_reduce_all) { + sch->Split(loops[0], std::vector({1, -1})); + } + + const auto DoBind = [&](const std::vector& loops) { + sch->Bind(loops[0], "blockIdx.x"); + if (IsWarpReduce(context_->config)) { + sch->Bind(loops[1], "threadIdx.y"); + sch->Bind(loops[2], "threadIdx.x"); + } else { + sch->Bind(loops[1], "threadIdx.x"); + } + }; + + DoBind(sch->GetLoops(block_id)); + + if (IsReduceBlock(context_->config, block_id) && + sch->HasBlock(block_id + "_rf")) { + auto loops = sch->GetLoops(block_id + "_rf"); + if (context_->config.base_info->is_reduce_all) { + sch->Split(loops[0], std::vector({1, -1})); + } + DoBind(sch->GetLoops(block_id + "_rf")); + } +} + +std::unique_ptr CreateTileFirstGeneralTactic() { + return std::make_unique(); +} + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h new file mode 100644 index 0000000000000..cda680c8ecf90 --- /dev/null +++ b/paddle/cinn/ir/group_schedule/tactic/tile_first_general_tactic.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
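TileFirstGeneralTactic assumes the loop reorder tactic has already pushed the reduce axes to the end, so its Init splits the axis range [0, data_rank) at data_rank - reduce_axis.size(), and the warp-reduce path later sizes threadIdx.y as warp_num * 32 / tree_reduce_num before BindCudaInfo assigns blockIdx.x, threadIdx.y and threadIdx.x. The standalone sketch below reproduces just those two calculations; PartitionAxes and WarpReduceThreadY are illustrative names and the numbers in main are only one example bucket.

#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

// Illustrative only: axis partition as in TileFirstGeneralTactic::Init and the
// threadIdx.y extent as in the warp-reduce branch of SplitWarpNumber.
std::pair<std::vector<int32_t>, std::vector<int32_t>> PartitionAxes(
    int32_t data_rank, int32_t num_reduce_axes) {
  std::vector<int32_t> flatten_axes;
  std::vector<int32_t> reduce_axes;
  const int32_t reduce_start = data_rank - num_reduce_axes;
  for (int32_t i = 0; i < data_rank; ++i) {
    (i >= reduce_start ? reduce_axes : flatten_axes).push_back(i);
  }
  return {flatten_axes, reduce_axes};
}

int64_t WarpReduceThreadY(int64_t warp_num, int64_t tree_reduce_num) {
  // warp_num * 32 threads per block, tree_reduce_num of them along threadIdx.x.
  return warp_num * 32 / tree_reduce_num;
}

int main() {
  auto [flatten, reduce] =
      PartitionAxes(/*data_rank=*/3, /*num_reduce_axes=*/1);
  std::printf("flatten: ");
  for (int32_t i : flatten) std::printf("%d ", i);
  std::printf("| reduce: ");
  for (int32_t i : reduce) std::printf("%d ", i);
  std::printf("| threadIdx.y = %lld\n",
              static_cast<long long>(WarpReduceThreadY(8, 32)));
  return 0;
}

With warp_num = 8 and tree_reduce_num = 32 (the warp-reduce bucket) this gives an 8 x 32 thread block: threadIdx.x carries the 32-lane tree reduction and threadIdx.y covers eight spatial rows, matching the bindings applied in BindCudaInfo.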
+ +#pragma once + +#include +#include "paddle/cinn/ir/group_schedule/tactic/schedule_tactic.h" + +namespace cinn { +namespace ir { + +std::unique_ptr CreateTileFirstGeneralTactic(); + +} // namespace ir +} // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc index e0e84d0bcd5b1..114a539e4e3f6 100644 --- a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc +++ b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.cc @@ -19,6 +19,18 @@ namespace cinn { namespace ir { +class TileTactic final : public ScheduleTactic { + public: + void Init(ScheduleContext* context) override; + + void Apply(ir::IRSchedule* sch, const std::string& block_id) override; + + std::string TacticName() const override { return "TileTactic"; } + + private: + ScheduleContext* context_; +}; + void TileTactic::Init(ScheduleContext* context) { context_ = context; // TODO(BiynXu): Create schedule config and bucket info based on hardware @@ -114,5 +126,9 @@ void TileTactic::Apply(ir::IRSchedule* sch, const std::string& block_id) { << sch->GetModule().GetExprs()[0]; } +std::unique_ptr CreateTileTactic() { + return std::make_unique(); +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h index 8a6d2bb8dd766..223287372ddf3 100644 --- a/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h +++ b/paddle/cinn/ir/group_schedule/tactic/tile_tactic.h @@ -20,17 +20,7 @@ namespace cinn { namespace ir { -class TileTactic final : public ScheduleTactic { - public: - void Init(ScheduleContext* context) override; - - void Apply(ir::IRSchedule* sch, const std::string& block_id) override; - - std::string TacticName() const override { return "TileTactic"; } - - private: - ScheduleContext* context_; -}; +std::unique_ptr CreateTileTactic(); } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/ir.cc b/paddle/cinn/ir/ir.cc index 2e194200d1993..a121806e6f3bf 100644 --- a/paddle/cinn/ir/ir.cc +++ b/paddle/cinn/ir/ir.cc @@ -218,11 +218,13 @@ Expr _Var_::Make(Expr lower_bound, Expr upper_bound, const std::string &name, bool is_reduce_axis, - bool is_symbolic_constant) { + bool is_symbolic_constant, + bool is_keepdim) { auto *n = make_shared<_Var_>(); n->lower_bound = lower_bound; n->upper_bound = upper_bound; n->is_reduce_axis = is_reduce_axis; + n->is_keepdim = is_keepdim; n->is_symbolic_constant = is_symbolic_constant; n->name = name; n->set_type(lower_bound.type()); @@ -233,6 +235,7 @@ Expr _Var_::Copy() const { auto *n = make_shared<_Var_>(); n->name = name; n->is_reduce_axis = is_reduce_axis; + n->is_keepdim = is_keepdim; n->lower_bound = lower_bound; n->upper_bound = upper_bound; n->set_type(type()); @@ -392,7 +395,6 @@ Expr Store::index() const { return indices[0]; } Expr res = cinn::common::IndiceToAbsOffset(tensor_n->shape, indices); - optim::Simplify(&res); return res; } @@ -630,8 +632,6 @@ Expr Load::index() const { return indices[0]; } Expr res = cinn::common::IndiceToAbsOffset(tensor_n->shape, indices); - VLOG(3) << "Begin Load::index Simplify"; - optim::Simplify(&res); return res; } else { CHECK_EQ(indices.size(), 1UL); diff --git a/paddle/cinn/ir/ir.h b/paddle/cinn/ir/ir.h index c02517f9836fc..d711e93ce61ab 100644 --- a/paddle/cinn/ir/ir.h +++ b/paddle/cinn/ir/ir.h @@ -381,6 +381,7 @@ struct _Var_ : public ExprNode<_Var_> { std::string name; bool is_reduce_axis{false}; + bool is_keepdim{false}; bool is_symbolic_constant{false}; //! 
Lower bound and upper bound of a axis. // @{ @@ -401,7 +402,8 @@ struct _Var_ : public ExprNode<_Var_> { Expr upper_bound, const std::string& name, bool is_reduce, - bool is_symbolic_constant = false); + bool is_symbolic_constant = false, + bool is_keepdim = false); void Verify() const override; @@ -419,12 +421,14 @@ struct Var : public IrNodeRef { Var(Expr lower_bound, Expr upper_bound, const std::string& name, - bool is_reduce = false) - : Var(_Var_::Make(lower_bound, upper_bound, name, is_reduce)) {} + bool is_reduce = false, + bool is_keepdim = false) + : Var(_Var_::Make( + lower_bound, upper_bound, name, is_reduce, false, is_keepdim)) {} Var(int upper_bound, const std::string& name) - : Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false)) {} + : Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false, false)) {} Var(Expr upper_bound, const std::string& name) - : Var(_Var_::Make(Expr(0), upper_bound, name, false)) {} + : Var(_Var_::Make(Expr(0), upper_bound, name, false, false)) {} operator Expr() { return Expr(get()); } operator Expr() const { @@ -962,6 +966,12 @@ struct Block : public ExprNode { static const IrNodeTy _node_type_ = IrNodeTy::Block; }; +struct NoneReduceMethod {}; +struct WarpReduceMethod {}; +struct BlockReduceMethod {}; +using ReduceMethod = + std::variant; + // ScheduleBlock is the unit of schedule IR which represents tensor's // computation struct ScheduleBlock : public ExprNode { @@ -977,6 +987,7 @@ struct ScheduleBlock : public ExprNode { std::map attrs; std::string name; Expr body; + ReduceMethod reduce_method{NoneReduceMethod()}; static Expr Make(const std::vector& iter_vars, const std::vector& read_buffers, diff --git a/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc b/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc index b75f12712853f..a9740c52652e5 100644 --- a/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc +++ b/paddle/cinn/ir/ir_analyzer/ir_analyzer.cc @@ -34,8 +34,8 @@ #include "paddle/cinn/ir/schedule/schedule_desc.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/ir/utils/ir_nodes_collector.h" -#include "paddle/cinn/utils/error.h" #include "paddle/cinn/utils/random_engine.h" +#include "paddle/common/enforce.h" namespace cinn { namespace ir { @@ -74,9 +74,12 @@ std::vector GetLoops(const std::vector& exprs, const Expr& block) { FindLoopsVisitor visitor(block); auto find_loops = visitor(&it_expr); if (!find_loops.empty()) { - if (!result.empty()) - LOG(FATAL) << "Find block with name: \n" - << block_name << " appeared in more than one AST!"; + if (!result.empty()) { + std::stringstream ss; + ss << "Find block with name: \n" + << block_name << " appeared in more than one AST!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); + } result = find_loops; } } @@ -120,8 +123,10 @@ Expr GetBlock(const std::vector& exprs, const std::string& block_name) { return result; } } - LOG(FATAL) << "Didn't find a block with name " << block_name - << " in this ModuleExpr!"; + std::stringstream ss; + ss << "Didn't find a block with name " << block_name + << " in this ModuleExpr!"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } Expr GetRootBlock(const std::vector& exprs, const Expr& expr) { @@ -139,9 +144,9 @@ Expr GetRootBlock(const std::vector& exprs, const Expr& expr) { return it_expr.As()->stmts[0]; } } - LOG(FATAL) << "Didn't find expr \n" - << expr << "in StScheduleImpl:\n" - << exprs[0]; + std::stringstream ss; + ss << "Didn't find expr \n" << expr << "in StScheduleImpl:\n" << exprs[0]; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } DeviceAPI 
GetDeviceAPI(const std::vector& exprs) { @@ -208,9 +213,10 @@ Expr AddUnitLoop(const std::vector& exprs, const Expr& block) { visitor.target_->As()->body = loop; return loop; } else { - LOG(FATAL) << "Can't find block's parent!"; + PADDLE_THROW(phi::errors::InvalidArgument("Can't find block's parent!")); } - LOG(FATAL) << "Shouldn't reach code here in AddUnitLoop"; + PADDLE_THROW( + phi::errors::InvalidArgument("Shouldn't reach code here in AddUnitLoop")); return Expr{nullptr}; } @@ -422,7 +428,15 @@ bool IsBroadcastSBlock(ir::Expr block) { return false; } // each load index can be found in store index and maintain relative order + const auto IsIndexZero = [](const ir::Expr& e) -> bool { + return e.is_constant() && e.get_constant() == 0; + }; + int num_load_index_zero = 0; for (size_t i = 0; i < load->indices.size(); ++i) { + if (IsIndexZero(load->indices[i]) && !IsIndexZero(store->indices[i])) { + ++num_load_index_zero; + continue; + } bool found = false; for (size_t j = i; j < store->indices.size(); ++j) { ir::_Var_* load_var = load->indices[i].as_var(); @@ -439,7 +453,7 @@ bool IsBroadcastSBlock(ir::Expr block) { return false; } } - return load->indices.size() < store->indices.size(); + return load->indices.size() - num_load_index_zero < store->indices.size(); } std::vector IndicesToVars(const std::vector& indices) { diff --git a/paddle/cinn/ir/ir_base.cc b/paddle/cinn/ir/ir_base.cc index e4b1b2f95b180..c1b0580d16562 100644 --- a/paddle/cinn/ir/ir_base.cc +++ b/paddle/cinn/ir/ir_base.cc @@ -22,7 +22,7 @@ #include "paddle/cinn/ir/ir_visitor.h" #include "paddle/cinn/ir/module.h" #include "paddle/cinn/ir/tensor.h" -#include "paddle/cinn/utils/error.h" +#include "paddle/common/enforce.h" namespace cinn { namespace ir { @@ -51,7 +51,7 @@ std::ostream &operator<<(std::ostream &os, IrNodeTy type) { #undef __m default: - LOG(FATAL) << "unknown IrNodeTy found"; + PADDLE_THROW(phi::errors::InvalidArgument("unknown IrNodeTy found")); } return os; diff --git a/paddle/cinn/ir/ir_base.h b/paddle/cinn/ir/ir_base.h index 24a7c2271d1fd..236e8afb67fe8 100644 --- a/paddle/cinn/ir/ir_base.h +++ b/paddle/cinn/ir/ir_base.h @@ -492,7 +492,7 @@ static std::ostream& operator<<(std::ostream& os, MemoryType t) { MEMORY_TYPE_FOR_ALL(__) default: - LOG(FATAL) << "Not supported memory type"; + PADDLE_THROW(phi::errors::InvalidArgument("Not supported memory type")); #undef __ } return os; @@ -500,7 +500,7 @@ static std::ostream& operator<<(std::ostream& os, MemoryType t) { template Expr ExprNode::Copy() const { - LOG(FATAL) << "Not Implemented"; + PADDLE_THROW(phi::errors::Unimplemented("Not Implemented")); return Expr(); } diff --git a/paddle/cinn/ir/ir_printer.cc b/paddle/cinn/ir/ir_printer.cc index 61b90ec6c7825..abd3515a8308a 100644 --- a/paddle/cinn/ir/ir_printer.cc +++ b/paddle/cinn/ir/ir_printer.cc @@ -60,7 +60,9 @@ void IrPrinter::Visit(const IntImm *x) { str_ += "(int8_t)"; str_ += std::to_string(x->value); } else { - LOG(FATAL) << "Not support int type: " << x->type(); + std::stringstream ss; + ss << "Not support int type: " << x->type(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } void IrPrinter::Visit(const UIntImm *x) { @@ -82,7 +84,9 @@ void IrPrinter::Visit(const UIntImm *x) { str_ += "false"; } } else { - LOG(FATAL) << "Not support uint type: " << x->type(); + std::stringstream ss; + ss << "Not support uint type: " << x->type(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } void IrPrinter::Visit(const FloatImm *x) { @@ -119,7 +123,9 @@ void IrPrinter::Visit(const FloatImm 
*x) { ss << std::showpoint; ss << x->value; } else { - LOG(FATAL) << "Not support float type: " << x->type(); + std::stringstream ss; + ss << "Not support float type: " << x->type(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } str_ += ss.str(); } diff --git a/paddle/cinn/ir/ir_visitor.h b/paddle/cinn/ir/ir_visitor.h index 87705597a7b1b..c5377401bbbb5 100644 --- a/paddle/cinn/ir/ir_visitor.h +++ b/paddle/cinn/ir/ir_visitor.h @@ -48,8 +48,10 @@ class IRVisitorRequireReImpl { NODETY_FORALL(__) default: - LOG(FATAL) << "not supported NodeTy, the expr->node_type() = " - << expr->node_type(); + std::stringstream ss; + ss << "not supported NodeTy, the expr->node_type() = " + << expr->node_type(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); #undef __ } return RetTy(); diff --git a/paddle/cinn/ir/layout.cc b/paddle/cinn/ir/layout.cc index f4e4585aa2145..ba0f07d520916 100644 --- a/paddle/cinn/ir/layout.cc +++ b/paddle/cinn/ir/layout.cc @@ -59,7 +59,9 @@ Layout::Layout(const std::string& name) { axes.push_back(ir::Var(factor, std::string(1, c))); factor = 0; } else { - LOG(FATAL) << "Invalid layout: " << name; + std::stringstream ss; + ss << "Invalid layout: " << name; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } name_ = name; diff --git a/paddle/cinn/ir/op/ir_operators.cc b/paddle/cinn/ir/op/ir_operators.cc index fcb0e19a6bb95..d11a26685851f 100644 --- a/paddle/cinn/ir/op/ir_operators.cc +++ b/paddle/cinn/ir/op/ir_operators.cc @@ -88,7 +88,9 @@ Expr operator|(Expr a, Expr b) { auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_or"); return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}}); } else { - LOG(FATAL) << "Unsupport arch: " << target.arch_str() << " for bitwise_or."; + std::stringstream ss; + ss << "Unsupport arch: " << target.arch_str() << " for bitwise_or."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -111,8 +113,9 @@ Expr operator&(Expr a, Expr b) { auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_and"); return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}}); } else { - LOG(FATAL) << "Unsupport arch: " << target.arch_str() - << " for bitwise_and."; + std::stringstream ss; + ss << "Unsupport arch: " << target.arch_str() << " for bitwise_and."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -135,8 +138,9 @@ Expr operator^(Expr a, Expr b) { auto func_name = hlir::GetExternFuncName(target, t_a, "bitwise_xor"); return lang::CallExtern(func_name, {a, b}, {{"vectorizable", false}}); } else { - LOG(FATAL) << "Unsupport arch: " << target.arch_str() - << " for bitwise_xor."; + std::stringstream ss; + ss << "Unsupport arch: " << target.arch_str() << " for bitwise_xor."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } @@ -149,8 +153,9 @@ Expr operator~(Expr a) { auto func_name = hlir::GetExternFuncName(target, a->type(), "bitwise_not"); return lang::CallExtern(func_name, {a}, {{"vectorizable", false}}); } else { - LOG(FATAL) << "Unsupport arch: " << target.arch_str() - << " for bitwise_not."; + std::stringstream ss; + ss << "Unsupport arch: " << target.arch_str() << " for bitwise_not."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } diff --git a/paddle/cinn/ir/schedule/factorize_reduction.h b/paddle/cinn/ir/schedule/factorize_reduction.h index d6252bb0a4663..8b0488e9c883c 100644 --- a/paddle/cinn/ir/schedule/factorize_reduction.h +++ b/paddle/cinn/ir/schedule/factorize_reduction.h @@ -90,6 +90,7 @@ class ReduceBlockCreater { is_rf_block_ ? 
rf_tensor_ : original_update_stmt_.As()->tensor.as_tensor_ref(); + Expr init_value = real_tensor->GetReduceInitVal(); const std::vector& domain = real_tensor->domain_without_reduce_axis(); ir::Tensor init_tensor = lang::Compute( @@ -97,8 +98,21 @@ class ReduceBlockCreater { [=](const std::vector& axis) { return init_value; }, new_init_block_name); init_tensor->Bind(real_tensor->buffer); - Expr init_stmt = ir::Store::Make( - init_tensor, init_value, new_update_stmt_.As()->indices); + std::vector new_indices; + if (new_update_stmt_.As()) { + new_indices = new_update_stmt_.As()->indices; + } else if (new_update_stmt_.As()) { + new_indices = new_update_stmt_.As() + ->true_case.As() + ->stmts[0] + .As() + ->indices; + } else { + throw std::runtime_error("only support store and ifthenelse"); + } + + Expr init_stmt = ir::Store::Make(init_tensor, init_value, new_indices); + new_init_sch_block_ = ScheduleBlock::Make( new_init_iter_vars_, {}, {}, new_init_block_name, init_stmt); new_init_block_realize_ = @@ -111,7 +125,7 @@ class ReduceBlockCreater { VLOG(4) << "new_update_block_realize:\n" << new_update_block_realize_; } - Expr CreateLoops() { + Expr CreateLoops(bool with_init = true) { int num_loops = original_loops_.size(); std::vector new_loops(num_loops); Expr body = new_update_block_realize_; @@ -127,7 +141,7 @@ class ReduceBlockCreater { continue; } // Add reduce init block. - if (!has_add_init_block && is_spatial_loop) { + if (!has_add_init_block && is_spatial_loop && with_init) { body = Block::Make({new_init_block_realize_, body}); has_add_init_block = true; } @@ -201,6 +215,26 @@ class ReduceBlockCreater { Expr new_init_block_realize_; }; +class LoadReplacer : public ir::IRMutator<> { + public: + explicit LoadReplacer(const std::string& src_load_tensor_name, + const ir::Expr& target) + : src_load_tensor_name_(src_load_tensor_name), target_(target) {} + + void operator()(Expr* expr) { IRMutator::Visit(expr, expr); } + + private: + void Visit(const ir::Load* expr, Expr* op) override { + if (expr->tensor.as_tensor()->name == src_load_tensor_name_) { + *op = target_; + } + } + + private: + std::string src_load_tensor_name_; + ir::Expr target_; +}; + // Implement class for building Reduction-Factorized block, // only used for FactorizeReduction schedule primitive. 
class RFBlockCreater : public ReduceBlockCreater { @@ -211,6 +245,7 @@ class RFBlockCreater : public ReduceBlockCreater { const Expr& original_update_stmt, const ir::Tensor& rf_tensor, const std::map& var2loops, + const Expr& bound_check, int rf_axis) : ReduceBlockCreater(original_block, original_loops, @@ -219,7 +254,8 @@ class RFBlockCreater : public ReduceBlockCreater { rf_tensor, true), var2loops_(var2loops), - rf_axis_(rf_axis) {} + rf_axis_(rf_axis), + bound_check_(ir_utils::IRCopy(bound_check)) {} private: void CreateRFIter() override { @@ -235,6 +271,11 @@ class RFBlockCreater : public ReduceBlockCreater { new_init_iter_vars_.push_back(rf_var_); new_init_iter_values_.push_back(rf_loop_.As()->loop_var); new_spatial_loop_var_names_.insert(rf_loop_.As()->loop_var->name); + + std::vector new_iter_exprs{Expr(rf_var_)}; + ReplaceExpr( + &bound_check_, {rf_loop_.As()->loop_var}, new_iter_exprs); + VLOG(4) << "create new_rf_var = " << rf_var_ << ", with iter value = " << new_iter_values_.back(); } @@ -310,29 +351,19 @@ class RFBlockCreater : public ReduceBlockCreater { rf_tensor_access_indices_.insert( rf_tensor_access_indices_.begin() + rf_axis_, rf_var_); Expr original_store_body = original_update_stmt_.As()->value; + std::string original_store_name = + original_update_stmt_.As()->tensor.as_tensor()->name; Expr new_store_body = ir_utils::IRCopy(original_store_body); -#define REPLACE_RF_TENSOR(Op) \ - if (new_store_body.As()) { \ - auto* node = new_store_body.As(); \ - CHECK(node); \ - auto& operand = node->a(); \ - operand = Load::Make(rf_tensor_, rf_tensor_access_indices_); \ - } - - REPLACE_RF_TENSOR(Add) - REPLACE_RF_TENSOR(Mul) - REPLACE_RF_TENSOR(Max) - REPLACE_RF_TENSOR(Min) - REPLACE_RF_TENSOR(And) - REPLACE_RF_TENSOR(Or) - REPLACE_RF_TENSOR(LT) - REPLACE_RF_TENSOR(LE) - REPLACE_RF_TENSOR(GT) - REPLACE_RF_TENSOR(GE) -#undef REPLACE_RF_TENSOR + LoadReplacer load_replacer( + original_store_name, Load::Make(rf_tensor_, rf_tensor_access_indices_)); + load_replacer(&new_store_body); new_update_stmt_ = ir::Store::Make(rf_tensor_, new_store_body, rf_tensor_access_indices_); + + if (!bound_check_.is_constant()) { + new_update_stmt_ = ir::IfThenElse::Make(bound_check_, new_update_stmt_); + } ReplaceExpr(&new_update_stmt_, original_indice2new_expr_); VLOG(4) << "new_update_stmt of rf block: \n" << new_update_stmt_; } @@ -342,6 +373,8 @@ class RFBlockCreater : public ReduceBlockCreater { int rf_axis_; std::map loop_var2block_iters_; + + Expr bound_check_; }; // Implement class for building Writing-Back block, @@ -406,6 +439,9 @@ class RBBlockCreater : public ReduceBlockCreater { void CreateUpdateStmt() override { Expr original_store_body = original_update_stmt_.As()->value; Expr new_store_body = ir_utils::IRCopy(original_store_body); + std::string original_store_name = + original_update_stmt_.As()->tensor.as_tensor()->name; + #define REPLACE_RF_TENSOR(Op) \ if (new_store_body.As()) { \ auto* node = new_store_body.As(); \ diff --git a/paddle/cinn/ir/schedule/impl/base.cc b/paddle/cinn/ir/schedule/impl/base.cc index d27bcd451f508..24583a67374e7 100644 --- a/paddle/cinn/ir/schedule/impl/base.cc +++ b/paddle/cinn/ir/schedule/impl/base.cc @@ -26,10 +26,11 @@ * @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message * printing */ -#define CINN_IR_SCHEDULE_END(err_msg_level) \ - } \ - catch (const utils::ErrorHandler& err_handler) { \ - CINN_THROW(err_handler.FormatErrorMessage(err_msg_level)); \ +#define CINN_IR_SCHEDULE_END(err_msg_level) \ + } \ + catch (const 
utils::ErrorHandler& err_handler) { \ + PADDLE_THROW(phi::errors::InvalidArgument( \ + err_handler.FormatErrorMessage(err_msg_level))); \ } namespace cinn { @@ -40,7 +41,7 @@ void DyScheduleImpl::MergeExprs() { std::string primitive = "MergeExprs"; std::ostringstream os; auto exprs = this->GetModule().GetExprs(); - if (exprs.size() == 1U) return; + if (exprs.size() <= 1U) return; if (!exprs[0].As()) { os << "Expr[0] of module_expr should be a Block!\n"; throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); @@ -428,7 +429,7 @@ Expr DyScheduleImpl::SampleCategorical( std::string primitive = "SampleCategorical"; std::ostringstream os; if (candidates.size() != probs.size()) { - os << "vector params(candidates) and vector prama(probs) must " + os << "vector params(candidates) and vector params(probs) must " "have same size in SampleCategorical!\n"; throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); } @@ -662,11 +663,13 @@ void StScheduleImpl::CopyTransformAndLoopInfo(const Expr& block, } } - if (new_iter_values.empty()) - LOG(FATAL) << "Cannot CopyTransformAndLoopInfo since shape[0] of source " - "and target is not equal! " - << vars[0]->upper_bound << " v.s " - << vars_target[0]->upper_bound; + if (new_iter_values.empty()) { + std::stringstream ss; + ss << "Cannot CopyTransformAndLoopInfo since shape[0] of source " + "and target is not equal! " + << vars[0]->upper_bound << " v.s " << vars_target[0]->upper_bound; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); + } int changed_loop_num = new_iter_values.size(); std::set used_target_loop_vars; diff --git a/paddle/cinn/ir/schedule/impl/compute_location.cc b/paddle/cinn/ir/schedule/impl/compute_location.cc index a077039994e81..09d4f26c7c8cb 100644 --- a/paddle/cinn/ir/schedule/impl/compute_location.cc +++ b/paddle/cinn/ir/schedule/impl/compute_location.cc @@ -26,10 +26,11 @@ * @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message * printing */ -#define CINN_IR_SCHEDULE_END(err_msg_level) \ - } \ - catch (const utils::ErrorHandler& err_handler) { \ - CINN_THROW(err_handler.FormatErrorMessage(err_msg_level)); \ +#define CINN_IR_SCHEDULE_END(err_msg_level) \ + } \ + catch (const utils::ErrorHandler& err_handler) { \ + PADDLE_THROW( \ + phi::errors::Fatal(err_handler.FormatErrorMessage(err_msg_level))); \ } namespace cinn { @@ -42,11 +43,11 @@ void DyScheduleImpl::ComputeAt(const Expr& block, std::string primitive = "ComputeAt"; std::ostringstream os; if (!block.As()) { - os << "Expr prama(block) should be a ScheduleBlockRealize!\n"; + os << "Expr param(block) should be a ScheduleBlockRealize!\n"; throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); } if (!loop.As()) { - os << "Expr prama(loop) should be a For node!\n"; + os << "Expr param(loop) should be a For node!\n"; throw IRScheduleErrorHandler(primitive, os.str(), module_expr_); } Expr root = this->GetRootBlock(block); diff --git a/paddle/cinn/ir/schedule/impl/for_type.cc b/paddle/cinn/ir/schedule/impl/for_type.cc index 53f157eac931a..a53870f09ea46 100644 --- a/paddle/cinn/ir/schedule/impl/for_type.cc +++ b/paddle/cinn/ir/schedule/impl/for_type.cc @@ -29,10 +29,11 @@ namespace ir { * @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message * printing */ -#define CINN_IR_SCHEDULE_END(err_msg_level) \ - } \ - catch (const utils::ErrorHandler& err_handler) { \ - CINN_THROW(err_handler.FormatErrorMessage(err_msg_level)); \ +#define CINN_IR_SCHEDULE_END(err_msg_level) \ + } \ + catch (const utils::ErrorHandler& 
err_handler) { \ + PADDLE_THROW( \ + phi::errors::Fatal(err_handler.FormatErrorMessage(err_msg_level))); \ } void DyScheduleImpl::MutateForType(const Expr& loop, @@ -53,7 +54,7 @@ void DyScheduleImpl::MutateForType(const Expr& loop, << static_cast(for_type) << "!\n"; } - auto loop_copy = ir::ir_utils::IRCopy(loop); + auto loop_copy = ir::ir_utils::IRCopy(loop, /* copy_buffer_node = */ false); auto* new_for_node = loop_copy.As(); CHECK(new_for_node); new_for_node->set_for_type(for_type); diff --git a/paddle/cinn/ir/schedule/impl/ir_schedule.h b/paddle/cinn/ir/schedule/impl/ir_schedule.h index 3fe35854cb4aa..42779c968d827 100644 --- a/paddle/cinn/ir/schedule/impl/ir_schedule.h +++ b/paddle/cinn/ir/schedule/impl/ir_schedule.h @@ -87,7 +87,9 @@ class DyScheduleImpl : public ScheduleBase { void ReverseComputeInline(const Expr& schedule_block); void Bind(const Expr& loop, const std::string& thread_axis); Expr Rfactor(const Expr& rf_loop, int rf_axis); - Expr FactorizeReduction(const Expr& rf_loop, int rf_axis); + Expr FactorizeReduction(const Expr& rf_loop, + int rf_axis, + bool with_write_back_block_init = true); Expr AddUnitLoop(const Expr& block) const; void Annotate(const Expr& block, const std::string& key, const attr_t& value); void Unannotate(Expr& block, const std::string& key); // NOLINT @@ -161,7 +163,9 @@ class StScheduleImpl : public ScheduleBase { void ReverseComputeInline(const Expr& schedule_block); void Bind(const Expr& loop, const std::string& thread_axis); Expr Rfactor(const Expr& rf_loop, int rf_axis); - Expr FactorizeReduction(const Expr& rf_loop, int rf_axis); + Expr FactorizeReduction(const Expr& rf_loop, + int rf_axis, + bool with_write_back_block_init = true); Expr AddUnitLoop(const Expr& block) const; void Annotate(const Expr& block, const std::string& key, const attr_t& value); void Unannotate(Expr& block, const std::string& key); // NOLINT diff --git a/paddle/cinn/ir/schedule/impl/loop_transformation.cc b/paddle/cinn/ir/schedule/impl/loop_transformation.cc index b320f6ace3f69..0b27d66fbbd7a 100644 --- a/paddle/cinn/ir/schedule/impl/loop_transformation.cc +++ b/paddle/cinn/ir/schedule/impl/loop_transformation.cc @@ -28,10 +28,11 @@ * @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message * printing */ -#define CINN_IR_SCHEDULE_END(err_msg_level) \ - } \ - catch (const utils::ErrorHandler& err_handler) { \ - CINN_THROW(err_handler.FormatErrorMessage(err_msg_level)); \ +#define CINN_IR_SCHEDULE_END(err_msg_level) \ + } \ + catch (const utils::ErrorHandler& err_handler) { \ + PADDLE_THROW( \ + phi::errors::Fatal(err_handler.FormatErrorMessage(err_msg_level))); \ } namespace cinn { diff --git a/paddle/cinn/ir/schedule/impl/reduction.cc b/paddle/cinn/ir/schedule/impl/reduction.cc index 6a28b40741388..6dec0ab489cac 100644 --- a/paddle/cinn/ir/schedule/impl/reduction.cc +++ b/paddle/cinn/ir/schedule/impl/reduction.cc @@ -26,10 +26,11 @@ * @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message * printing */ -#define CINN_IR_SCHEDULE_END(err_msg_level) \ - } \ - catch (const utils::ErrorHandler& err_handler) { \ - CINN_THROW(err_handler.FormatErrorMessage(err_msg_level)); \ +#define CINN_IR_SCHEDULE_END(err_msg_level) \ + } \ + catch (const utils::ErrorHandler& err_handler) { \ + PADDLE_THROW( \ + phi::errors::Fatal(err_handler.FormatErrorMessage(err_msg_level))); \ } namespace cinn { @@ -50,7 +51,9 @@ Expr DyScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) { CINN_IR_SCHEDULE_END(this->err_msg_level_); } -Expr 
DyScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { +Expr DyScheduleImpl::FactorizeReduction(const Expr& rf_loop, + int rf_axis, + bool with_write_back_block_init) { CINN_IR_SCHEDULE_BEGIN() std::string primitive = "FactorizeReduction"; std::ostringstream os; @@ -103,6 +106,7 @@ Expr DyScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { original_update_stmt, rf_tensor, var2loops, + Expr(false), rf_axis); rf_block_creater.CreateBlock(); RBBlockCreater wb_block_creater(original_block, @@ -115,7 +119,8 @@ Expr DyScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { wb_block_creater.CreateBlock(); Expr rf_body = rf_block_creater.CreateLoops(); - Expr wb_body = wb_block_creater.CreateLoops(); + Expr wb_body = wb_block_creater.CreateLoops( + /* with_init = */ with_write_back_block_init); Expr new_computational_body = Block::Make({rf_body, wb_body}); @@ -144,7 +149,9 @@ Expr StScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) { return rf_create.CreateRfAllStmts(); } -Expr StScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { +Expr StScheduleImpl::FactorizeReduction(const Expr& rf_loop, + int rf_axis, + bool with_write_back_block_init) { std::string primitive = "FactorizeReduction"; // Get child block of the rf_loop and check. std::vector blocks = GetChildBlocks(rf_loop); @@ -165,6 +172,12 @@ Expr StScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { VLOG(3) << "before FactorizeReduction, original computational body of the " "reduction is:\n" << original_loops[0]; + Expr bound_check(false); + auto first_st = original_loops.back().As()->body.As()->stmts[0]; + if (first_st.As()) { + bound_check = first_st.As()->condition; + } + std::map var2loops; for (const Expr& loop : original_loops) { var2loops[loop.As()->loop_var] = loop; @@ -193,6 +206,7 @@ Expr StScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { original_update_stmt, rf_tensor, var2loops, + bound_check, rf_axis); rf_block_creater.CreateBlock(); RBBlockCreater wb_block_creater(original_block, @@ -205,7 +219,8 @@ Expr StScheduleImpl::FactorizeReduction(const Expr& rf_loop, int rf_axis) { wb_block_creater.CreateBlock(); Expr rf_body = rf_block_creater.CreateLoops(); - Expr wb_body = wb_block_creater.CreateLoops(); + Expr wb_body = wb_block_creater.CreateLoops( + /* with_init = */ with_write_back_block_init); Expr new_computational_body = Block::Make({rf_body, wb_body}); diff --git a/paddle/cinn/ir/schedule/impl/storage.cc b/paddle/cinn/ir/schedule/impl/storage.cc index 0233f8c5caa63..c4642f31c2202 100644 --- a/paddle/cinn/ir/schedule/impl/storage.cc +++ b/paddle/cinn/ir/schedule/impl/storage.cc @@ -26,10 +26,11 @@ * @param err_msg_level A ScheduleErrorMessageLevel enum, level of error message * printing */ -#define CINN_IR_SCHEDULE_END(err_msg_level) \ - } \ - catch (const utils::ErrorHandler& err_handler) { \ - CINN_THROW(err_handler.FormatErrorMessage(err_msg_level)); \ +#define CINN_IR_SCHEDULE_END(err_msg_level) \ + } \ + catch (const utils::ErrorHandler& err_handler) { \ + PADDLE_THROW( \ + phi::errors::Fatal(err_handler.FormatErrorMessage(err_msg_level))); \ } namespace cinn { diff --git a/paddle/cinn/ir/schedule/ir_schedule.cc b/paddle/cinn/ir/schedule/ir_schedule.cc index 7bf684acfc6a9..6143de1f7b433 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.cc +++ b/paddle/cinn/ir/schedule/ir_schedule.cc @@ -85,10 +85,11 @@ std::unique_ptr ScheduleBase::Make(ModuleExpr&& module_expr, * @param err_msg_level A ScheduleErrorMessageLevel 
enum, level of error message * printing */ -#define CINN_IR_SCHEDULE_END(err_msg_level) \ - } \ - catch (const utils::ErrorHandler& err_handler) { \ - CINN_THROW(err_handler.FormatErrorMessage(err_msg_level)); \ +#define CINN_IR_SCHEDULE_END(err_msg_level) \ + } \ + catch (const utils::ErrorHandler& err_handler) { \ + PADDLE_THROW( \ + phi::errors::Fatal(err_handler.FormatErrorMessage(err_msg_level))); \ } void BaseInliner::operator()(Expr* expr) { @@ -449,6 +450,16 @@ Expr IRSchedule::Fuse(const Expr& block, const std::vector& loops_index) { return result; } +void IRSchedule::Broadcast(const std::string& block_name, + const BroadcastInfo& info) { + impl_->Broadcast(block_name, info); +} + +void IRSchedule::BroadcastToElementwise(const std::string& block_name, + const std::vector& axes) { + impl_->BroadcastToElementwise(block_name, axes); +} + void IRSchedule::ComputeAt(const Expr& block, const Expr& loop, bool keep_unit_loops) { @@ -619,12 +630,17 @@ Expr IRSchedule::Rfactor(const Expr& rf_loop, int rf_axis) { return result; } -Expr IRSchedule::FactorizeReduction(const Expr& rf_loop, int rf_axis) { - auto result = impl_->FactorizeReduction(rf_loop, rf_axis); - trace_.Append(ScheduleDesc::Step("FactorizeReduction", - {{"rf_loop", std::vector({rf_loop})}}, - {{"rf_axis", rf_axis}}, - {result})); +Expr IRSchedule::FactorizeReduction(const Expr& rf_loop, + int rf_axis, + bool with_write_back_block_init) { + auto result = + impl_->FactorizeReduction(rf_loop, rf_axis, with_write_back_block_init); + trace_.Append(ScheduleDesc::Step( + "FactorizeReduction", + {{"rf_loop", std::vector({rf_loop})}}, + {{"rf_axis", rf_axis}, + {"with_write_back_block_init", with_write_back_block_init}}, + {result})); return result; } @@ -648,7 +664,9 @@ void IRSchedule::Annotate(const Expr& block, TRACE_ANNOTATE_ITEM(std::string, AnnotateStringAttr) #undef TRACE_ANNOTATE_ITEM - LOG(FATAL) << "Value of attribute:" << key << " input unsupported data type"; + std::stringstream ss; + ss << "Value of attribute:" << key << " input unsupported data type"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } void IRSchedule::Unannotate(Expr& block, const std::string& key) { diff --git a/paddle/cinn/ir/schedule/ir_schedule.h b/paddle/cinn/ir/schedule/ir_schedule.h index 9ea4eb9f59b6f..cab1b0d38d868 100644 --- a/paddle/cinn/ir/schedule/ir_schedule.h +++ b/paddle/cinn/ir/schedule/ir_schedule.h @@ -195,6 +195,12 @@ class IRSchedule { * @param memory_type String that indicates the buffer's storage scope. * @return The buffer's cache. */ + + void Broadcast(const std::string& block_name, const BroadcastInfo& info); + + void BroadcastToElementwise(const std::string& block_name, + const std::vector& axes); + Expr CacheRead(const Expr& block, int read_buffer_index, const std::string& memory_type); @@ -402,7 +408,9 @@ class IRSchedule { * B[i] = B[i] + rf_B[j, i] * \endcode */ - Expr FactorizeReduction(const Expr& rf_loop, int rf_axis); + Expr FactorizeReduction(const Expr& rf_loop, + int rf_axis, + bool with_write_back_block_init = true); /*! 
* \brief Annotate a block with a key-value pair to set as its attribute diff --git a/paddle/cinn/ir/schedule/ir_schedule_error.cc b/paddle/cinn/ir/schedule/ir_schedule_error.cc index 3467df28e5485..0b7a098264632 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_error.cc +++ b/paddle/cinn/ir/schedule/ir_schedule_error.cc @@ -21,7 +21,7 @@ namespace ir { std::string IRScheduleErrorHandler::GeneralErrorMessage() const { std::ostringstream os; - os << "[IRScheduleError] An error occurred in the scheduel primitive < " + os << "[IRScheduleError] An error occurred in the schedule primitive < " << this->primitive_ << " >. " << std::endl; os << indent_str_ << "[Error info] " << this->err_msg_; return os.str(); diff --git a/paddle/cinn/ir/schedule/ir_schedule_util.cc b/paddle/cinn/ir/schedule/ir_schedule_util.cc index ba98382ebbf2f..833e1dfce9226 100644 --- a/paddle/cinn/ir/schedule/ir_schedule_util.cc +++ b/paddle/cinn/ir/schedule/ir_schedule_util.cc @@ -113,7 +113,8 @@ void SetCudaAxisInfo(Expr* lowered_func) { info.set_grid_dim(bind_info.offset, range); } } else { - LOG(FATAL) << "The for loop's bind info should be gpu block or thread!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "The for loop's bind info should be gpu block or thread!")); } } return (x->As() && x->As()->bind_info().valid()); @@ -207,7 +208,7 @@ void ReplaceExpr(Expr* source, const std::vector& candidates) { CHECK_EQ(replaced.size(), candidates.size()) << "In ReplaceExpr, the size of Vars to be replaced must be equal to the " - "size of cadidate Exprs! Please check."; + "size of candidate Exprs! Please check."; if (replaced.empty()) return; std::map replacing_map; for (int i = 0; i < replaced.size(); ++i) { @@ -264,20 +265,14 @@ std::vector ValidateFactors(const std::vector& factors, if (!has_minus_one) { if (product < total_extent) { std::ostringstream os; - os << "In Split, the factors' product should be not larger than or equal " - "to original loop's extent!" - << std::endl; + os << "In Split, the factors' product[" << product + << "] should be not larger than or equal " + "to original loop's extent[" + << total_extent << "]!" << std::endl; throw IRScheduleErrorHandler(primitive, os.str(), module_expr); } return validated_factors; } else { - if (product > total_extent) { - std::ostringstream os; - os << "In Split, the factors' product should be not larger than or equal " - "to original loop's extent!" - << std::endl; - throw IRScheduleErrorHandler(primitive, os.str(), module_expr); - } int minus_one_candidate = static_cast( ceil(static_cast(total_extent) / static_cast(product))); for (int i = 0; i < validated_factors.size(); ++i) { @@ -336,10 +331,11 @@ std::vector GetLoopsOfExpr(const Expr& expr, const Expr& root) { root, [&](const Expr* x) { return x->As() && Contains(*x, expr); }); std::vector result(loop_nodes.begin(), loop_nodes.end()); - if (result.empty()) - LOG(FATAL) << "Didn't find expr's : \n" - << expr << "\n loops in root : \n" - << root; + if (result.empty()) { + std::stringstream ss; + ss << "Didn't find expr's : \n" << expr << "\n loops in root : \n" << root; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); + } std::sort(result.begin(), result.end(), [&](Expr i, Expr j) { return (utils::GetStreamCnt(i).size() > utils::GetStreamCnt(j).size()); }); @@ -587,8 +583,8 @@ const std::set CollectLoopsToSet( CHECK(i.As()) << "loops should be For node! Please check."; auto inserted = for_loops.insert(i); if (!inserted.second) { - LOG(FATAL) - << "There should be no duplicate elements in loops! 
Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "There should be no duplicate elements in loops! Please check.")); } } return for_loops; @@ -614,8 +610,9 @@ std::pair GetBoundaryOfReorderRange( // Then loop_i should be the new top if (visited.count(v_for)) { if (v_for != top) { - LOG(FATAL) << "Loops in GetBoundaryOfReorderRange is not a chain! " - "Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Loops in GetBoundaryOfReorderRange is not a chain! " + "Please check.")); } top = loop_i; break; @@ -644,8 +641,8 @@ std::vector GetLoopsInRange(const Expr& top, const Expr& bottom) { for (auto loop_iter = top; loop_iter != bottom;) { Expr tmp = GetNextForLoop(loop_iter); if (!tmp.defined()) - LOG(FATAL) - << "Loops in GetLoopsInReorderRange is not a chain! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Loops in GetLoopsInReorderRange is not a chain! Please check.")); chain.push_back(loop_iter); loop_iter = tmp; } @@ -764,7 +761,7 @@ Expr ConstructNewLoopChain(const std::vector& chain, // } } // } } // - // We go throuph origin loop and check other body stmts, adding it as another + // We go through origin loop and check other body stmts, adding it as another // chain, such as: // // for (i, 0, 32) { @@ -1022,7 +1019,7 @@ void InsertBlock(Expr& for_loop, const Expr& insertion, int index) { // NOLINT auto dst_it = dst_block->stmts.begin() + index; if (dst_it->As()) { auto* inserted_block = dst_it->As()->true_case.As(); - CHECK(inserted_block) << "the IfThenElse node to be inserted shuold " + CHECK(inserted_block) << "the IfThenElse node to be inserted should " "contain a true_case block"; inserted_block->stmts.insert(inserted_block->stmts.begin(), insertion); } else { @@ -1060,7 +1057,7 @@ std::vector CalculateRequiredRegions( } std::vector required_buffer_range; - // deduce accessed regions of the provided tensor in block by itering each + // deduce accessed regions of the provided tensor in block by iterating each // required block for (const Expr& pro_node : provided_nodes) { std::string provided_tensor_name = diff --git a/paddle/cinn/ir/schedule/schedule_base.cc b/paddle/cinn/ir/schedule/schedule_base.cc index 8e6573edeab0e..b34221d73f052 100644 --- a/paddle/cinn/ir/schedule/schedule_base.cc +++ b/paddle/cinn/ir/schedule/schedule_base.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "paddle/cinn/ir/schedule/schedule_base.h" +#include "paddle/cinn/ir/schedule/ir_schedule_util.h" namespace cinn { namespace ir { @@ -70,5 +71,181 @@ void ScheduleBase::Replace(const Expr& src_sref, const Expr& tgt_stmt) { } } +void ScheduleBase::BroadcastToElementwise(const std::string& block_name, + const std::vector& axes) { + std::vector all_loops = this->GetLoops(block_name); + Expr broadcast_body = all_loops.back().As()->body; + + auto schedule_realize = broadcast_body.As() + ->expr_fields()[0] + ->As(); + auto schedule_block = + schedule_realize->schedule_block.As(); + auto iter_vars = schedule_block->iter_vars; + + auto load_exprs = ir::ir_utils::CollectIRNodesInOrder( + schedule_block->body, [&](const Expr* x) { return x->As(); }); + + for (auto load_expr : load_exprs) { + auto load = load_expr.As(); + load->indices.resize(all_loops.size(), Expr(0)); + + for (size_t i = 0; i < axes.size(); ++i) { + load->indices[axes[i]] = schedule_block->iter_vars[axes[i]]; + } + } +} + +void ScheduleBase::Broadcast(const std::string& block_name, + const BroadcastInfo& info) { + auto axes = info.broadcast_axes; + + if (axes.size() == 0) { + return; + } + std::vector all_loops = this->GetLoops(block_name); + if (axes[0] >= all_loops.size()) { + throw std::runtime_error("axes execeed loop size"); + } + + // Get Last loop + Expr broadcast_body = all_loops.back().As()->body; + + auto schedule_realize = broadcast_body.As() + ->expr_fields()[0] + ->As(); + auto schedule_block = + schedule_realize->schedule_block.As(); + + auto iter_vars = schedule_block->iter_vars; + auto iter_values = schedule_realize->iter_values; + + auto factors = info.output_shape; + auto full_broadcast = info.full_broadcast; + auto first_broadcast = info.first_broadcast; + if (info.split_first) { + // iter value is one + for (size_t i = 0; i < axes.size(); ++i) { + // new_extent + auto axis = axes[i]; + auto loop_temp = all_loops[axis].As(); + int extent = factors[i]; + loop_temp->extent = Expr(extent); + if (extent < 0) { + ir::Dim dim("var_00", info.output_dim_expr[i]); + loop_temp->extent = Expr(dim->dim_expr); + } + + if (info.with_constrain) { + auto check = ir::EQ::Make(loop_temp->loop_var, Expr(0)); + schedule_block->body = + ir::IfThenElse::Make(check, schedule_block->body); + } + } + + // change load and store + // get new offset + all_loops = this->GetLoops(block_name); + auto offset = Expr(0); + auto stride = Expr(1); + auto in_offset = Expr(0); + + std::set brodacast_set(info.broadcast_axes.begin(), + info.broadcast_axes.end()); + for (int i = all_loops.size() - 1; i >= 0; --i) { + auto loop_temp = all_loops[i].As(); + offset = offset + loop_temp->loop_var * stride; + + stride = stride * loop_temp->extent; + if (!brodacast_set.count(i)) { + in_offset = in_offset + loop_temp->loop_var * stride; + } + } + + auto exprs = ir::ir_utils::CollectIRNodesInOrder( + schedule_block->body, + [&](const Expr* x) { return x->As(); }); + for (auto expr : exprs) { + auto store = expr.As(); + store->indices[0] = offset; + } + + exprs = ir::ir_utils::CollectIRNodesInOrder( + schedule_block->body, [&](const Expr* x) { return x->As(); }); + + for (auto expr : exprs) { + auto load = expr.As(); + if (!info.first_broadcast) { + load->indices[0] = offset; + } else { + load->indices[0] = in_offset; + } + } + + return; + } + + for (size_t i = 0; i < axes.size(); ++i) { + // new_extent + auto axis = axes[i]; + auto loop_temp = all_loops[axis].As(); + int extent = factors[i]; + loop_temp->extent = Expr(extent); + if (extent < 0) { + 
ir::Dim dim("var_00", info.output_dim_expr[i]); + loop_temp->extent = Expr(dim->dim_expr); + } + + if (!full_broadcast && (!(info.with_constrain))) { + schedule_realize->iter_values[axis] = loop_temp->loop_var; + } + + if (info.with_constrain) { + auto check = ir::EQ::Make(loop_temp->loop_var, Expr(0)); + schedule_block->body = ir::IfThenElse::Make(check, schedule_block->body); + } + } + + if (first_broadcast && !full_broadcast) { + auto exprs = ir::ir_utils::CollectIRNodesInOrder( + schedule_block->body, [&](const Expr* x) { return x->As(); }); + + if (info.op_name == "cinn_op.reshape") { + for (auto expr : exprs) { + auto load = expr.As(); + for (size_t k = 0; k < load->indices.size(); ++k) { + for (size_t i = 0; i < axes.size(); ++i) { + ReplaceExpr(&load->indices[k], + {schedule_block->iter_vars[axes[i]]}, + {Expr(0)}); + } + } + } + + return; + } + for (auto expr : exprs) { + auto load = expr.As(); + if (load->indices.size() == schedule_realize->iter_values.size()) { + for (size_t i = 0; i < axes.size(); ++i) { + load->indices[axes[i]] = Expr(0); + } + } else if (load->indices.size() < schedule_realize->iter_values.size()) { + // only one element + // replace t zeros + for (size_t k = 0; k < load->indices.size(); ++k) { + for (size_t i = 0; i < axes.size(); ++i) { + ReplaceExpr(&load->indices[k], + {schedule_block->iter_vars[axes[i]]}, + {Expr(0)}); + } + } + } else { + throw std::runtime_error("not support broadcast type yet"); + } + } + } +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/schedule/schedule_base.h b/paddle/cinn/ir/schedule/schedule_base.h index 6ce5caaeaad12..0deb44da000cd 100644 --- a/paddle/cinn/ir/schedule/schedule_base.h +++ b/paddle/cinn/ir/schedule/schedule_base.h @@ -18,12 +18,27 @@ #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/utils/error.h" #include "paddle/cinn/utils/random_engine.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" PD_DECLARE_int32(cinn_error_message_level); namespace cinn { namespace ir { +struct BroadcastInfo { + std::vector broadcast_axes; + std::vector output_shape; + std::vector output_dim_expr; + + bool with_constrain{false}; + bool first_broadcast{false}; + bool full_broadcast{false}; + std::string op_name; + + bool split_first{false}; + std::vector>> split_info; +}; + /** * A struct representing a module that contains Expr. This struct is only used * in Schedule process. 
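Editor's note on the new schedule primitives above: the hunks in schedule_base.cc/.h and ir_schedule.cc/.h introduce a BroadcastInfo descriptor plus Broadcast and BroadcastToElementwise entry points on IRSchedule. The minimal sketch below shows how a caller might drive them; the helper function, the block name "var_2", the axis index and the extent are illustrative assumptions, not values taken from this PR, and the element types of the vectors follow whatever schedule_base.h actually declares.

#include "paddle/cinn/ir/schedule/ir_schedule.h"  // assumed to pull in schedule_base.h

// Hypothetical helper: broadcast the second loop of an existing schedule block.
void ApplyBroadcastExample(cinn::ir::IRSchedule* ir_sch) {
  cinn::ir::BroadcastInfo info;
  info.broadcast_axes = {1};      // broadcast along the second loop of the block
  info.output_shape = {32};       // new extent of that loop after broadcasting
  info.first_broadcast = true;    // first broadcast in the chain: loads keep index 0 on this axis
  ir_sch->Broadcast("var_2", info);  // block name is illustrative
  // Alternatively, rewrite the loads so the block becomes elementwise over loops 0 and 1:
  // ir_sch->BroadcastToElementwise("var_2", {0, 1});
}

The descriptor-plus-primitive split mirrors how the rest of the schedule API is traced: the string block name keeps the call replayable from a ScheduleDesc without holding on to Expr handles.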
@@ -95,6 +110,7 @@ class ScheduleBase { virtual std::vector GetAllBlocks() const = 0; virtual std::vector GetChildBlocks(const Expr& expr) const = 0; virtual Expr GetBlock(const std::string& block_name) const = 0; + virtual std::vector Split(const Expr& loop, const std::vector& factors) = 0; virtual std::vector Split(const Expr& loop, @@ -142,7 +158,9 @@ class ScheduleBase { virtual void ReverseComputeInline(const Expr& schedule_block) = 0; virtual void Bind(const Expr& loop, const std::string& thread_axis) = 0; virtual Expr Rfactor(const Expr& rf_loop, int rf_axis) = 0; - virtual Expr FactorizeReduction(const Expr& rf_loop, int rf_axis) = 0; + virtual Expr FactorizeReduction(const Expr& rf_loop, + int rf_axis, + bool with_write_back_block_init = true) = 0; virtual Expr AddUnitLoop(const Expr& block) const = 0; virtual void Annotate(const Expr& block, const std::string& key, @@ -159,6 +177,12 @@ class ScheduleBase { const std::vector& candidates, const std::vector& probs) = 0; + void Broadcast(const std::string& block_name, + const cinn::ir::BroadcastInfo& info); + + void BroadcastToElementwise(const std::string& block_name, + const std::vector& axes); + protected: void Replace(const Expr& src_sref, const Expr& tgt_stmt); diff --git a/paddle/cinn/ir/schedule/schedule_desc.cc b/paddle/cinn/ir/schedule/schedule_desc.cc index c9a26dfa1643d..fbf2a268054e1 100644 --- a/paddle/cinn/ir/schedule/schedule_desc.cc +++ b/paddle/cinn/ir/schedule/schedule_desc.cc @@ -27,7 +27,7 @@ namespace cinn { namespace ir { -// ------ Following codes are about `Apply` functions registry of variaous types +// ------ Following codes are about `Apply` functions registry of various types // of ScheduleDesc::Step class PackedStepContext; // uniformed function prototype of a scheduling operation in IRSchedule @@ -117,9 +117,11 @@ class PackedStepContext { try { return absl::get(attrs_.at(idx)); } catch (absl::bad_variant_access& ex) { - LOG(FATAL) << "Attribute cast error, idx:" << idx - << ", get tpye:" << typeid(AttrType).name() - << ", real index:" << attrs_.at(idx).index(); + std::stringstream ss; + ss << "Attribute cast error, idx:" << idx + << ", get type:" << typeid(AttrType).name() + << ", real index:" << attrs_.at(idx).index(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); throw ex; } } @@ -197,7 +199,7 @@ struct FreeFuncConverter { } }; -// used for formatting scheduling functions with variaous function signatures to +// used for formatting scheduling functions with various function signatures to // be uniformed form template struct ApplyFuncImpl; @@ -483,6 +485,7 @@ CINN_BUILD_STEP_KIND(Rfactor) CINN_BUILD_STEP_KIND(FactorizeReduction) .Inputs({"rf_loop"}) .Attrs({"rf_axis"}) + .Attrs({"with_write_back_block_init"}) .SetApplyFn(APPLY_FUNC_UNIFORM( FREE_FUNCTION_CONVERTER(&IRSchedule::FactorizeReduction))); @@ -600,7 +603,9 @@ void AttrVariantToProto(const utils::Attribute& attr, SET_DESC_REPEATED_ITEM(10, std::vector, LONGS, longs); SET_DESC_REPEATED_ITEM(11, std::vector, DOUBLES, doubles); default: - LOG(FATAL) << "Invalid index:" << attr.index(); + std::stringstream ss; + ss << "Invalid index:" << attr.index(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } #undef SET_DESC_SINGLE_ITEM @@ -634,7 +639,9 @@ utils::Attribute AttrProtoToVariant(const proto::ScheduleDesc_Attr& attr) { PARSE_DESC_REPEATED_ITEM(LONGS, longs, std::vector); PARSE_DESC_REPEATED_ITEM(DOUBLES, doubles, std::vector); default: - LOG(FATAL) << "Invalid type:" << attr.DebugString(); + std::stringstream ss; + ss << 
"Invalid type:" << attr.DebugString(); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } #undef PARSE_DESC_SINGLE_ITEM @@ -689,8 +696,8 @@ proto::ScheduleDesc ScheduleDesc::ToProto() const { } } - // each output Expr is represented by a formatted name, to be refered by - // suceeding steps + // each output Expr is represented by a formatted name, to be referred by + // succeeding steps for (auto&& expr : step.outputs) { std::string local_name = "e" + std::to_string(expr2name.size()); expr2name.emplace(expr, local_name); @@ -722,7 +729,7 @@ std::vector ScheduleDesc::ReplayWithProto( absl::flat_hash_map name2expr; std::vector last_outputs; - // resotre each scheduling step and apply to the new IRSchedule object + // restore each scheduling step and apply to the new IRSchedule object for (auto&& step_proto : desc_proto.steps()) { VLOG(4) << "Replay step:\n" << step_proto.DebugString(); ScheduleDesc::Step step; diff --git a/paddle/cinn/ir/tensor.cc b/paddle/cinn/ir/tensor.cc index 5224a2172ac5c..6c5ba14efe680 100644 --- a/paddle/cinn/ir/tensor.cc +++ b/paddle/cinn/ir/tensor.cc @@ -32,6 +32,8 @@ #include "paddle/cinn/poly/isl_utils.h" #include "paddle/cinn/poly/stage.h" +PD_DECLARE_bool(cinn_bucket_compile); + namespace cinn { namespace ir { @@ -359,7 +361,7 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, std::vector reduce_axis_input = stages[this]->origin_reduce_axis_names(); auto origin_domain = stages[this]->domain(); - auto reduce_axis_output = poly::GetRelatedOutputAxies( + auto reduce_axis_output = poly::GetRelatedOutputAxes( temp_transform, origin_domain, reduce_axis_input); std::set reduce_axis_output_set; for (auto &i : reduce_axis_output) { @@ -374,7 +376,7 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, } } - temp_transform = poly::RemoveAxiesByOutputNames( + temp_transform = poly::RemoveAxesByOutputNames( temp_transform, origin_domain, reduce_axis_output); //! When the first axis is not reduce axis, do ComputeAt. @@ -386,7 +388,7 @@ ir::Tensor _Tensor_::InitReduction(poly::StageMap stages, init_tensor->shape = shape; return init_tensor; } - //! When reduce axies are reordered to front, ComputeAt is illegal. + //! When reduce axes are reordered to front, ComputeAt is illegal. //! So we just copy transform and forloopInfo. isl_map_set_tuple_name( temp_transform.get(), isl_dim_in, init_reduce_tensor_name.c_str()); @@ -506,7 +508,9 @@ void _Tensor_::WithBuffer(const std::string &memory_type, } else if (memory_type == "global") { this->buffer->memory_type = MemoryType::Heap; } else { - LOG(FATAL) << "Not supported memory type " << memory_type; + std::stringstream ss; + ss << "Not supported memory type " << memory_type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } else { lang::Buffer buf(buf_type, buffer_name); @@ -520,7 +524,9 @@ void _Tensor_::WithBuffer(const std::string &memory_type, } else if (memory_type == "global") { buf->memory_type = MemoryType::Heap; } else { - LOG(FATAL) << "Not supported memory type " << memory_type; + std::stringstream ss; + ss << "Not supported memory type " << memory_type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } } @@ -689,7 +695,18 @@ ir::Tensor _Tensor_::ReshapeCopied(const std::vector &shape, } Shared CreateStage(Tensor tensor) { - auto isl_domain = tensor->GenerateIslDomain(); + isl::set isl_domain; + // We will remove isl, and the subsequent compilation process will no longer + // use it. But it has not been completely removed in the process. 
it cannot be + // supported here under dynamic shape. Therefore, we temporarily use fake + // domain. + if (FLAGS_cinn_bucket_compile) { + poly::Domain fake_domain(Context::isl_ctx(), "fake_domain", {}); + isl_domain = fake_domain.to_isl(); + } else { + isl_domain = tensor->GenerateIslDomain(); + } + return poly::Stage::New(isl_domain, tensor->body(), tensor.self()); } diff --git a/paddle/cinn/ir/test/tensor_test.cc b/paddle/cinn/ir/test/tensor_test.cc index cea1263f2aba3..4bf64f309735e 100644 --- a/paddle/cinn/ir/test/tensor_test.cc +++ b/paddle/cinn/ir/test/tensor_test.cc @@ -144,7 +144,7 @@ TEST(Tensor, ReshapeCopied) { stages->InsertLazily(B); - ir::Module::Builder builder("some_modue", cinn::common::DefaultHostTarget()); + ir::Module::Builder builder("some_module", cinn::common::DefaultHostTarget()); auto func = lang::Lower("fn", stages, {A, B}, {}, {}, &builder); backends::CodeGenC codegenc(cinn::common::DefaultHostTarget()); diff --git a/paddle/cinn/ir/utils/ir_copy.cc b/paddle/cinn/ir/utils/ir_copy.cc index c560652b5442b..e463df0fb067d 100644 --- a/paddle/cinn/ir/utils/ir_copy.cc +++ b/paddle/cinn/ir/utils/ir_copy.cc @@ -31,9 +31,15 @@ namespace ir { namespace ir_utils { namespace { struct IRCopyVisitor : public ir::IRVisitorRequireReImpl { + public: + explicit IRCopyVisitor(bool copy_buffer_node) + : copy_buffer_node(copy_buffer_node) {} + // Use maps to unify all the copied tensors and buffers. std::map tensor_map; std::map buffer_map; + // whether to deep copy Buffer node. + bool copy_buffer_node; Expr Visit(const Expr* op) override { return IRVisitorRequireReImpl::Visit(op); @@ -188,9 +194,14 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl { auto name = op->name; auto tensor = make_shared<_Tensor_>(); + // tensor->buffer = op->buffer; if (buffer_expr.defined()) { - auto buffer = Visit(&buffer_expr); - tensor->buffer = buffer.as_buffer_ref(); + if (copy_buffer_node) { + auto buffer = Visit(&buffer_expr); + tensor->buffer = buffer.as_buffer_ref(); + } else { + tensor->buffer = op->buffer; + } } tensor->domain = domain; tensor->shape = shape; @@ -405,6 +416,7 @@ struct IRCopyVisitor : public ir::IRVisitorRequireReImpl { Expr res = ir::ScheduleBlock::Make( iter_vars, read_buffers, write_buffers, op->name, Visit(&op->body)); res.As()->attrs = op->attrs; + res.As()->reduce_method = op->reduce_method; return res; } @@ -489,35 +501,36 @@ Expr IRCopyVisitor::Visit(const ir::intrinsics::BuiltinIntrin* op) { op->name, op->args, op->id, op->arg_nums, op->type()); } } // namespace -Expr IRCopy(Expr x) { - IRCopyVisitor visitor; +Expr IRCopy(Expr x, bool copy_buffer_node) { + IRCopyVisitor visitor(copy_buffer_node); auto copied = visitor.Visit(&x); return copied; } -std::vector IRCopy(const std::vector& x) { +std::vector IRCopy(const std::vector& x, bool copy_buffer_node) { std::vector res; for (auto& i : x) { - res.emplace_back(IRCopy(i)); + res.emplace_back(IRCopy(i, copy_buffer_node)); } return res; } -ir::ModuleExpr IRCopy(const ir::ModuleExpr& x) { - return ir::ModuleExpr(IRCopy(x.GetExprs())); +ir::ModuleExpr IRCopy(const ir::ModuleExpr& x, bool copy_buffer_node) { + return ir::ModuleExpr(IRCopy(x.GetExprs(), copy_buffer_node)); } -ir::LoweredFunc IRCopy(const ir::LoweredFunc& x) { - ir::Expr copy_func_expr = IRCopy(static_cast(x)); +ir::LoweredFunc IRCopy(const ir::LoweredFunc& x, bool copy_buffer_node) { + ir::Expr copy_func_expr = IRCopy(static_cast(x), copy_buffer_node); ir::_LoweredFunc_* copy_func_ptr = copy_func_expr.As(); return ir::LoweredFunc(copy_func_ptr); } // 
TODO(zhhsplendid): make IRCopy of std::vector a template function -std::vector IRCopy(const std::vector& x) { +std::vector IRCopy(const std::vector& x, + bool copy_buffer_node) { std::vector res; for (const auto& i : x) { - res.emplace_back(IRCopy(i)); + res.emplace_back(IRCopy(i, copy_buffer_node)); } return res; } diff --git a/paddle/cinn/ir/utils/ir_copy.h b/paddle/cinn/ir/utils/ir_copy.h index 594f07e91cfa0..69bcc16ab13dd 100644 --- a/paddle/cinn/ir/utils/ir_copy.h +++ b/paddle/cinn/ir/utils/ir_copy.h @@ -28,15 +28,17 @@ class ModuleExpr; namespace ir_utils { //! Shallow copy an expression. -Expr IRCopy(Expr x); +Expr IRCopy(Expr x, bool copy_buffer_node = true); -std::vector IRCopy(const std::vector& x); +std::vector IRCopy(const std::vector& x, + bool copy_buffer_node = true); -ir::ModuleExpr IRCopy(const ir::ModuleExpr& x); +ir::ModuleExpr IRCopy(const ir::ModuleExpr& x, bool copy_buffer_node = true); -ir::LoweredFunc IRCopy(const ir::LoweredFunc& x); +ir::LoweredFunc IRCopy(const ir::LoweredFunc& x, bool copy_buffer_node = true); -std::vector IRCopy(const std::vector& x); +std::vector IRCopy(const std::vector& x, + bool copy_buffer_node = true); } // namespace ir_utils } // namespace ir diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.cc b/paddle/cinn/ir/utils/ir_nodes_collector.cc index e4ebaca653bae..fc36e87cbfc31 100644 --- a/paddle/cinn/ir/utils/ir_nodes_collector.cc +++ b/paddle/cinn/ir/utils/ir_nodes_collector.cc @@ -59,7 +59,7 @@ struct IrNodesCollector : public IRVisitorRequireReImpl { NODETY_FORALL(__) default: - LOG(FATAL) << "not supported NodeTy"; + PADDLE_THROW(phi::errors::InvalidArgument("not supported NodeTy")); #undef __ } } diff --git a/paddle/cinn/ir/utils/ir_replace.cc b/paddle/cinn/ir/utils/ir_replace.cc index 7e64e7aaa7e7f..5e782536c1d3a 100644 --- a/paddle/cinn/ir/utils/ir_replace.cc +++ b/paddle/cinn/ir/utils/ir_replace.cc @@ -50,7 +50,7 @@ struct IrReplaceVarBroadcastMutator : ir::IRMutator { void Visit(const ir::Broadcast* op, Expr* expr) override { if (op->node_type() == from_->node_type() && from_repr_ == GetStreamCnt(*expr)) { - *expr = ir::ir_utils::IRCopy(to_); + *expr = ir::ir_utils::IRCopy(to_, /* copy_buffer_node = */ false); } } @@ -68,7 +68,7 @@ struct IrReplaceMutator : ir::IRMutator { void Visit(const Expr* op, Expr* expr) override { ir::IRMutator<>::Visit(expr, expr); if (from_repr_ == GetStreamCnt(*expr)) { - *expr = ir::ir_utils::IRCopy(to_); + *expr = ir::ir_utils::IRCopy(to_, /* copy_buffer_node = */ false); } } diff --git a/paddle/cinn/lang/builtin.cc b/paddle/cinn/lang/builtin.cc index b50a49096847b..fd5f63d13ed96 100644 --- a/paddle/cinn/lang/builtin.cc +++ b/paddle/cinn/lang/builtin.cc @@ -96,13 +96,17 @@ EXTERN_CALL_IMP(Popc, popc); #undef EXTERN_CALL_IMP #undef EXTERN_CALL_IMP_NO_VEC -#define EXTERN_BINARY_CALL_IMP(name__, target__) \ - Expr name__(Expr a, Expr b) { \ - CHECK_EQ(a.type(), b.type()) \ - << #name__ << "'s inputs type not equal, where a:" << a.type() \ - << " but b:" << b.type(); \ - return ir::Call::Make( \ - a->type(), #target__, {a, b}, {}, ir::CallType::Extern); \ +#define EXTERN_BINARY_CALL_IMP(name__, target__) \ + Expr name__(Expr a, Expr b) { \ + PADDLE_ENFORCE_EQ( \ + a.type(), \ + b.type(), \ + phi::errors::InvalidArgument(#name__ "'s inputs type not equal," \ + "where a:%s but b:%s.", \ + a.type(), \ + b.type())); \ + return ir::Call::Make( \ + a->type(), #target__, {a, b}, {}, ir::CallType::Extern); \ } EXTERN_BINARY_CALL_IMP(Remainder, mod) @@ -117,9 +121,13 @@ Expr Zero(const Type& type) { return 
ir::Zero(type); } Expr One(const Type& type) { return ir::One(type); } Expr FloorDivide(Expr a, Expr b) { - CHECK_EQ(a.type(), b.type()) - << "FloorDivide's inputs type not equal, where a:" << a.type() - << " but b:" << b.type(); + PADDLE_ENFORCE_EQ(a.type(), + b.type(), + phi::errors::InvalidArgument( + "FloorDivide's inputs type not equal, where a:%s " + " but b:%s.", + a.type(), + b.type())); if (a.type().is_float()) { return Floor(a / b); } else if (a.type().is_uint()) { @@ -136,7 +144,12 @@ Expr FloorDivide(Expr a, Expr b) { } Expr min_value(const Type& type) { - CHECK_EQ(type.lanes(), 1); + PADDLE_ENFORCE_EQ( + type.lanes(), + 1, + phi::errors::InvalidArgument("The value of min type's lanes is incorrect" + "Expected value is 1, but receive %d. ", + type.lanes())); #define FOR_CASE(type__) \ if (type == type_of()) { \ return Expr(static_cast(std::numeric_limits::lowest())); \ @@ -158,7 +171,12 @@ Expr min_value(const Type& type) { } Expr max_value(const Type& type) { - CHECK_EQ(type.lanes(), 1); + PADDLE_ENFORCE_EQ( + type.lanes(), + 1, + phi::errors::InvalidArgument("The value of max type's lanes is incorrect" + "Expected value is 1, but receive %d. ", + type.lanes())); #define FOR_CASE(type__) \ if (type == type_of()) { \ @@ -183,7 +201,12 @@ Expr max_value(const Type& type) { } Expr Epsilon(const Type& type) { - CHECK_EQ(type.lanes(), 1); + PADDLE_ENFORCE_EQ(type.lanes(), + 1, + phi::errors::InvalidArgument( + "The value of epsilon type's lanes is incorrect" + "Expected value is 1, but receive %d. ", + type.lanes())); #define FOR_CASE(type__) \ if (type == type_of()) { \ @@ -219,7 +242,9 @@ Expr Abs(Expr e) { } return ir::Select::Make(e > Zero(e->type()), e, -e); } else { - LOG(FATAL) << "Abs Not support data type " << type; + std::stringstream ss; + ss << "Abs Not support data type " << type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return e; } @@ -235,13 +260,20 @@ Expr IsNan(Expr e) { } return CallExtern("isnan", {e}, {{"vectorizable", false}}); } else { - LOG(FATAL) << type << "is not supported for isnan op."; + std::stringstream ss; + ss << type << "is not supported for isnan op."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return e; } } Expr Infinity(const Type& type) { - CHECK_EQ(type.lanes(), 1U); + PADDLE_ENFORCE_EQ(type.lanes(), + 1U, + phi::errors::InvalidArgument( + "The value of infinity type's lanes is incorrect" + "Expected value is 1, but receive %d. 
", + type.lanes())); if (type.is_float()) { if (type.bits() == 64) { return make_const(type, std::numeric_limits::infinity()); @@ -251,7 +283,9 @@ Expr Infinity(const Type& type) { return make_const(type, std::numeric_limits::infinity()); } } - LOG(FATAL) << "Cannot decide infinity for type " << type; + std::stringstream ss; + ss << "Cannot decide infinity for type " << type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return Expr(); } @@ -266,7 +300,9 @@ Expr IsInf(Expr e) { } return CallExtern("isinf", {e}, {{"vectorizable", false}}); } else { - LOG(FATAL) << type << "is not supported for isinf op."; + std::stringstream ss; + ss << type << "is not supported for isinf op."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return e; } } diff --git a/paddle/cinn/lang/compute.cc b/paddle/cinn/lang/compute.cc index 4828eaac64e13..946b87857f66f 100644 --- a/paddle/cinn/lang/compute.cc +++ b/paddle/cinn/lang/compute.cc @@ -47,7 +47,12 @@ ir::Tensor Compute(const std::vector &domain, return Compute( domain, [fn](const std::vector &axis) -> Expr { - CHECK_EQ(axis.size(), 1); + PADDLE_ENFORCE_EQ(axis.size(), + 1, + phi::errors::InvalidArgument( + "The size of axis vector is incorrect" + "Expected value is 1, but receive %d. ", + axis.size())); return fn(axis[0]); }, name, @@ -61,7 +66,12 @@ ir::Tensor Compute(const std::vector &domain, return Compute( domain, [fn](const std::vector &axis) -> Expr { - CHECK_EQ(axis.size(), 2); + PADDLE_ENFORCE_EQ(axis.size(), + 2, + phi::errors::InvalidArgument( + "The size of axis vector is incorrect" + "Expected value is 2, but receive %d. ", + axis.size())); return fn(axis[0], axis[1]); }, name, @@ -75,7 +85,12 @@ ir::Tensor Compute(const std::vector &domain, return Compute( domain, [fn](const std::vector &axis) -> Expr { - CHECK_EQ(axis.size(), 3); + PADDLE_ENFORCE_EQ(axis.size(), + 3, + phi::errors::InvalidArgument( + "The size of axis vector is incorrect" + "Expected value is 3, but receive %d. ", + axis.size())); return fn(axis[0], axis[1], axis[2]); }, name, @@ -89,7 +104,12 @@ ir::Tensor Compute(const std::vector &domain, return Compute( domain, [fn](const std::vector &axis) -> Expr { - CHECK_EQ(axis.size(), 4); + PADDLE_ENFORCE_EQ(axis.size(), + 4, + phi::errors::InvalidArgument( + "The size of axis vector is incorrect" + "Expected value is 4, but receive %d. ", + axis.size())); return fn(axis[0], axis[1], axis[2], axis[3]); }, name, @@ -103,7 +123,12 @@ ir::Tensor Compute(const std::vector &domain, return Compute( domain, [fn](const std::vector &axis) -> Expr { - CHECK_EQ(axis.size(), 5); + PADDLE_ENFORCE_EQ(axis.size(), + 5, + phi::errors::InvalidArgument( + "The size of axis vector is incorrect" + "Expected value is 5, but receive %d. ", + axis.size())); return fn(axis[0], axis[1], axis[2], axis[3], axis[4]); }, name, @@ -117,7 +142,12 @@ ir::Tensor Compute(const std::vector &domain, return Compute( domain, [fn](const std::vector &axis) -> Expr { - CHECK_EQ(axis.size(), 6); + PADDLE_ENFORCE_EQ(axis.size(), + 6, + phi::errors::InvalidArgument( + "The size of axis vector is incorrect" + "Expected value is 6, but receive %d. 
", + axis.size())); return fn(axis[0], axis[1], axis[2], axis[3], axis[4], axis[5]); }, name, @@ -187,6 +217,13 @@ ir::Tensor Compute(const std::vector &domain, domain_without_reduce_axis, op, reduce_axis); + const auto set_keep_dim_for_tensor = [&]() { + for (int i = 0; i < _axis.size(); ++i) { + const auto &axis_var = _axis.at(i); + tensor->axis_[i]->is_keepdim = axis_var.as_var_ref()->is_keepdim; + } + }; + set_keep_dim_for_tensor(); return tensor; } diff --git a/paddle/cinn/lang/lower.cc b/paddle/cinn/lang/lower.cc index ac94803a2128a..75be3ee619582 100644 --- a/paddle/cinn/lang/lower.cc +++ b/paddle/cinn/lang/lower.cc @@ -337,8 +337,11 @@ ir::LoweredFunc LowerToAst(const std::string& name, const Target& target) { std::vector result = LowerToAstVec(name, tensor_args, tensor_group, target); - CHECK_EQ(result.size(), 1UL) << "LowerToAst contains not only 1 LoweredFunc, " - "use LowerToAstVec instead."; + PADDLE_ENFORCE_EQ(result.size(), + 1UL, + phi::errors::InvalidArgument( + "LowerToAst contains not only 1 LoweredFunc, " + "use LowerToAstVec instead.")); return result[0]; } diff --git a/paddle/cinn/lang/lower_impl.cc b/paddle/cinn/lang/lower_impl.cc index 1b085c03e2240..f938d1712c92f 100644 --- a/paddle/cinn/lang/lower_impl.cc +++ b/paddle/cinn/lang/lower_impl.cc @@ -586,7 +586,7 @@ std::vector LowerImpl::operator()() { for (auto& i : tensor_args_) { LOG(INFO) << i->name; } - LOG(FATAL) << "Fatal Error!"; + PADDLE_THROW(phi::errors::InvalidArgument("Fatal Error!")); } Reference(&arg)->buffer = tensor_map.at(arg->name)->buffer; } @@ -718,7 +718,13 @@ std::vector LowerImpl::GenerateFunctionBody( std::unordered_map> resized_buffer_cache; for (auto& group : schedule->groups) { - CHECK_GT(group.nodes.size(), 0) << "group is empty"; + PADDLE_ENFORCE_GT( + group.nodes.size(), + 0, + phi::errors::InvalidArgument( + "Group is empty" + "Expected size of group is larger than 0, but receive %d. ", + group.nodes.size())); bool all_temp_tensor = true; for (auto& node : group.nodes) { if (!tensor_map.count(node->id())) { diff --git a/paddle/cinn/lang/lower_impl.h b/paddle/cinn/lang/lower_impl.h index b5f82ba7312e6..840fcfce860a0 100644 --- a/paddle/cinn/lang/lower_impl.h +++ b/paddle/cinn/lang/lower_impl.h @@ -150,8 +150,8 @@ class LowerImpl { std::vector CollectTemporaryTensors(); /** - * \brief Check both the tensor_args and sclar_args not contain duplication - * (different arguemnt with the same name). + * \brief Check both the tensor_args and scalar_args not contain duplication + * (different argument with the same name). 
*/ void CheckArgsUnique(); @@ -304,7 +304,7 @@ struct MarkParallelMutator : public ir::IRMutator { auto it = parallels.find(tensor_n->name); if (it != parallels.end()) { for (int level : it->second) { - VLOG(1) << "Mark " << level << " Paralled"; + VLOG(1) << "Mark " << level << " Parallelled"; CHECK_LT(level, stack.size()); stack[level]->set_parallel(); } diff --git a/paddle/cinn/lang/lower_tensor_group.cc b/paddle/cinn/lang/lower_tensor_group.cc index 93453621e1839..c6b3ba5173565 100644 --- a/paddle/cinn/lang/lower_tensor_group.cc +++ b/paddle/cinn/lang/lower_tensor_group.cc @@ -81,7 +81,7 @@ std::vector LowerTensorGroup::operator()() { for (auto& i : tensor_args_) { LOG(INFO) << i->name; } - LOG(FATAL) << "Fatal Error!"; + PADDLE_THROW(phi::errors::InvalidArgument("Fatal Error!")); } Reference(&arg)->buffer = tensor_map.at(arg->name)->buffer; } diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index d5f758623d628..e6f3aa2ee6c4f 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -29,7 +29,10 @@ gather_srcs( resize_buffer.cc update_buffer_axis_pass.cc trans_buffer_with_dynamic_shape.cc - schedule_block_dce.cc) + schedule_block_dce.cc + eliminate_common_factor_of_local_index.cc + if_fusion.cc + eliminate_common_global_memory_read.cc) if(WITH_CUDA) gather_srcs(cinnapi_src SRCS transform_gpu_forloop.cc) diff --git a/paddle/cinn/optim/compute_inline_expand.cc b/paddle/cinn/optim/compute_inline_expand.cc index f6b7c6f24e2b8..9c66064d2773d 100644 --- a/paddle/cinn/optim/compute_inline_expand.cc +++ b/paddle/cinn/optim/compute_inline_expand.cc @@ -113,7 +113,14 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> { CHECK(tensor); // fix computeAt case auto shapes = tensor->shape; - CHECK_EQ(shapes.size(), node->indices.size()); + PADDLE_ENFORCE_EQ( + shapes.size(), + node->indices.size(), + phi::errors::InvalidArgument( + "The size of tensor shape and node indices is not equal," + "where tensor shape:%d but node indices:%d.", + shapes.size(), + node->indices.size())); for (int i = 0; i < shapes.size(); i++) { if (cinn::common::is_zero(shapes[i] - 1)) { node->indices[i] = Expr(0); diff --git a/paddle/cinn/optim/eliminate_common_factor_of_local_index.cc b/paddle/cinn/optim/eliminate_common_factor_of_local_index.cc new file mode 100644 index 0000000000000..020c32b60845d --- /dev/null +++ b/paddle/cinn/optim/eliminate_common_factor_of_local_index.cc @@ -0,0 +1,305 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
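The hunks above apply a single conversion pattern throughout paddle/cinn: glog-style CHECK_EQ / CHECK_GT / CHECK_GE assertions become PADDLE_ENFORCE_* calls that carry a phi::errors payload, and LOG(FATAL) becomes PADDLE_THROW. A minimal sketch of that pattern, reusing only the macros and error types that already appear in this diff (the helper function itself is hypothetical):

// Old style, aborts the process with an unstructured glog message:
//   CHECK_EQ(a.type(), b.type()) << "types differ";
// New style, raises a structured, catchable Paddle error:
void CheckSameType(const ir::Expr& a, const ir::Expr& b) {  // hypothetical helper
  PADDLE_ENFORCE_EQ(
      a.type(),
      b.type(),
      phi::errors::InvalidArgument(
          "Inputs must have the same type, but got a:%s and b:%s.",
          a.type(),
          b.type()));
}
// For LOG(FATAL) call sites that stream a message, the diff first builds the
// text with a std::stringstream and then throws it:
//   std::stringstream ss;
//   ss << "Unsupported data type " << type;
//   PADDLE_THROW(phi::errors::InvalidArgument(ss.str()));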
+ +#include "paddle/cinn/optim/eliminate_common_factor_of_local_index.h" + +#include + +#include "paddle/cinn/common/cas.h" +#include "paddle/cinn/ir/ir.h" +#include "paddle/cinn/ir/ir_mutator.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/cinn/utils/external_func_names.h" +#include "paddle/cinn/utils/string.h" + +namespace cinn { +namespace optim { +namespace { + +class GatherLocalIndexVisitor : public ir::IRMutator<> { + public: + void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + const std::unordered_map>>& + local_var_to_indexes() const { + return local_var_to_indexes_; + } + + private: + void Visit(const ir::Store* op, Expr* expr) override { + auto store = expr->As(); + + ir::IRMutator<>::Visit(op, expr); + if (!store->tensor.as_tensor_ref()->buffer.defined()) { + return; + } + + if (store->tensor.as_tensor_ref()->buffer->memory_type == + ir::MemoryType::GPULocal) { + local_var_to_indexes_[store->tensor.as_tensor_ref()->buffer->name] + .push_back(store->indices); + } + } + + void Visit(const ir::Load* op, Expr* expr) override { + auto load = expr->As(); + + if (load->is_addr_scalar()) { + return; + } + if (!load->tensor.as_tensor_ref()->buffer.defined()) { + return; + } + + if (load->tensor.as_tensor_ref()->buffer->memory_type == + ir::MemoryType::GPULocal) { + local_var_to_indexes_[load->tensor.as_tensor_ref()->buffer->name] + .push_back(load->indices); + } + ir::IRMutator<>::Visit(op, expr); + } + + std::unordered_map>> + local_var_to_indexes_; +}; + +class GatherProhibitedLocalVarVisitor : public ir::IRMutator<> { + public: + void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + const std::unordered_set& prohibited_local_vars() const { + return prohibited_local_vars_; + } + + private: + void Visit(const ir::Store* op, Expr* expr) override { + auto store = expr->As(); + + ir::IRMutator<>::Visit(op, expr); + if (!store->tensor.as_tensor_ref()->buffer.defined()) { + return; + } + if (store->tensor.as_tensor_ref()->buffer->memory_type != + ir::MemoryType::GPULocal) { + return; + } + const auto& local_var_name = store->tensor.as_tensor_ref()->buffer->name; + if (store->value.As()) { + const auto& call_name = store->value.As()->name; + if (cinn::utils::GetProhibitScheduleExternalFuncNames().count(call_name) > + 0) { + prohibited_local_vars_.insert(local_var_name); + } + } + } + + std::unordered_set prohibited_local_vars_; +}; + +std::unordered_map>> +EraseProhibitedLocalVar( + const std::unordered_map>>& + local_var_to_indexes, + const std::unordered_set& prohibited_local_vars) { + std::unordered_map>> ret{}; + for (const auto& [local_var, indexes] : local_var_to_indexes) { + if (prohibited_local_vars.count(local_var) == 0) { + ret[local_var] = indexes; + } + } + return ret; +} + +std::unordered_map>> +CollectLocalVarToIndexes(ir::Expr* expr) { + GatherLocalIndexVisitor gather_local_index_visitor; + gather_local_index_visitor(expr); + + GatherProhibitedLocalVarVisitor gather_prohibited_local_var_visitor; + gather_prohibited_local_var_visitor(expr); + + return EraseProhibitedLocalVar( + gather_local_index_visitor.local_var_to_indexes(), + gather_prohibited_local_var_visitor.prohibited_local_vars()); +} + +template +void VisitEachRowExpr(const std::vector>& indexes, + std::size_t var_idx, + DoEachT&& DoEach) { + for (std::size_t i = 0; i < indexes.size(); ++i) { + DoEach(indexes[i][var_idx]); + } +} + +int ExtractNumberFromExpr(const ir::Expr& expr) { + ir::Expr simplied_expr = 
cinn::common::AutoSimplify(expr); + if (simplied_expr.is_constant()) { + return static_cast(simplied_expr.get_constant()); + } else if (expr.As()) { + auto mul = expr.As(); + return std::max(ExtractNumberFromExpr(mul->a()), + ExtractNumberFromExpr(mul->b())); + } else { + VLOG(6) << "Not supported for calculating gcd, expr = " << expr; + return 1; + } + PADDLE_THROW(phi::errors::Fatal("Dead code")); +} + +int gcd(int a, int b) { + if (b == 0) { + return a; + } + return gcd(b, a % b); +} + +// Note (Hongyu Jia): Currently, we only calculates gcd of int factors. +ir::Expr CalculateGcdForExprPair(const ir::Expr& expr1, const ir::Expr& expr2) { + return ir::Expr( + gcd(ExtractNumberFromExpr(expr1), ExtractNumberFromExpr(expr2))); +} + +std::vector CalculateIndexVectorGcd( + const std::string& local_var, + const std::vector>& indexes) { + CHECK_GE(indexes.size(), 2) + << "We should guarantee indexes.size() >= 2, because local variable " + << local_var << " should at least load and store once."; + for (std::size_t i = 1; i < indexes.size(); ++i) { + // NOTE(Hongyu Jia): Ideally, we can guarantee the size of indexes are equal + // under flags FLAGS_cinn_new_group_scheduler=1 and + // FLAGS_cinn_bucket_compile=1. However, some unit tests (e.g. + // test_resnet_cinn, test_instance_norm_op) are still running with the + // deprecated OpScheduler, and the ir::Expr will break this guarantee after + // IRCudaScheduleBlockReduce function. So we have to relax the restriction + // here. + if (indexes[i].size() != indexes[0].size()) { + LOG(WARNING) << "Not supported for calculating gcd, local var = " + << local_var; + return std::vector( + std::max(indexes[0].size(), indexes[i].size()), ir::Expr(1)); + } + } + std::size_t var_index_size = indexes[0].size(); + std::vector gcd_indexes; + for (std::size_t var_idx = 0; var_idx < var_index_size; ++var_idx) { + std::optional gcd_expr; + VisitEachRowExpr(indexes, var_idx, [&](const ir::Expr& expr) { + if (gcd_expr.has_value()) { + gcd_expr = CalculateGcdForExprPair(gcd_expr.value(), expr); + } else { + gcd_expr = expr; + } + }); + gcd_indexes.push_back(gcd_expr.value()); + } + return gcd_indexes; +} + +std::unordered_map> CalculateLocalIndexGcd( + const std::unordered_map>>& + local_var_to_indexes) { + std::unordered_map> + local_var_to_gcd_factor; + for (const auto& [local_var, indexes] : local_var_to_indexes) { + local_var_to_gcd_factor[local_var] = + CalculateIndexVectorGcd(local_var, indexes); + } + return local_var_to_gcd_factor; +} + +class DivideGcdForLocalIndexVisitor : public ir::IRMutator<> { + public: + DivideGcdForLocalIndexVisitor( + const std::unordered_map>& + local_var_to_gcd_factor) + : local_var_to_gcd_factor_(local_var_to_gcd_factor) {} + + void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + private: + void Visit(const ir::Store* op, Expr* expr) override { + auto store = expr->As(); + + ir::IRMutator<>::Visit(op, expr); + const auto& store_buffer = store->tensor.as_tensor_ref()->buffer; + if (!store_buffer.defined()) { + return; + } + + if (store_buffer->memory_type == ir::MemoryType::GPULocal) { + if (local_var_to_gcd_factor_.count(store_buffer->name) == 0) { + return; + } + const auto& gcd_factors = local_var_to_gcd_factor_.at(store_buffer->name); + for (std::size_t i = 0; i < store->indices.size(); ++i) { + if (gcd_factors[i] != ir::Expr(0)) { + store->indices[i] = cinn::common::AutoSimplify( + ir::Div::Make(store->indices[i], gcd_factors[i])); + } + } + } + } + + void Visit(const ir::Load* op, Expr* expr) override { 
+ auto load = expr->As(); + + if (load->is_addr_scalar()) { + return; + } + const auto& load_buffer = load->tensor.as_tensor_ref()->buffer; + if (!load_buffer.defined()) { + return; + } + + if (load_buffer->memory_type == ir::MemoryType::GPULocal) { + if (local_var_to_gcd_factor_.count(load_buffer->name) == 0) { + return; + } + const auto& gcd_factors = local_var_to_gcd_factor_.at(load_buffer->name); + for (std::size_t i = 0; i < load->indices.size(); ++i) { + if (gcd_factors[i] != ir::Expr(0)) { + load->indices[i] = cinn::common::AutoSimplify( + ir::Div::Make(load->indices[i], gcd_factors[i])); + } + } + } + ir::IRMutator<>::Visit(op, expr); + } + std::unordered_map> + local_var_to_gcd_factor_; +}; + +} // namespace + +void EliminateCommonFactorOfLocalIndex(ir::Expr* expr) { + VLOG(2) << "Before EliminateCommonFactorOfLocalIndex, Expr = \n" << *expr; + + std::unordered_map>> + local_var_to_indexes = CollectLocalVarToIndexes(expr); + + std::unordered_map> + local_var_to_gcd_factor = CalculateLocalIndexGcd(local_var_to_indexes); + + DivideGcdForLocalIndexVisitor divide_gcd_for_local_index_visitor( + local_var_to_gcd_factor); + divide_gcd_for_local_index_visitor(expr); + + VLOG(2) << "After EliminateCommonFactorOfLocalIndex, Expr = \n" << *expr; +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/eliminate_common_factor_of_local_index.h b/paddle/cinn/optim/eliminate_common_factor_of_local_index.h new file mode 100644 index 0000000000000..243f36490f31a --- /dev/null +++ b/paddle/cinn/optim/eliminate_common_factor_of_local_index.h @@ -0,0 +1,30 @@ +// Copyright (c) 2024 CINN Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +/** + * Given Expr AST, analyze the Greatest Common Divisor (GCD) of local variable + * indexes. Then each local index divides it's GCD value. This optimization + * could help analysising the space allocated for local variables. + */ +void EliminateCommonFactorOfLocalIndex(ir::Expr* expr); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/eliminate_common_global_memory_read.cc b/paddle/cinn/optim/eliminate_common_global_memory_read.cc new file mode 100644 index 0000000000000..d9fa523064e00 --- /dev/null +++ b/paddle/cinn/optim/eliminate_common_global_memory_read.cc @@ -0,0 +1,297 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
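The pass completed above gathers, for every GPU-local buffer, the index expressions of all of its loads and stores, computes a per-dimension greatest common divisor of their constant factors, and divides each index by that GCD so the local buffer can later be resized to a smaller extent. A self-contained toy that mirrors the arithmetic on plain integers (no CINN ir::Expr involved; the names below are invented for illustration):

#include <cstddef>
#include <iostream>
#include <numeric>  // std::gcd
#include <vector>

// Each row is one access to the same local buffer; each column is one index
// dimension reduced to its constant factor (e.g. 4*i -> 4, 8*j -> 8).
std::vector<int> PerDimensionGcd(const std::vector<std::vector<int>>& accesses) {
  std::vector<int> gcds(accesses.front().size(), 0);
  for (const auto& row : accesses) {
    for (std::size_t d = 0; d < row.size(); ++d) {
      gcds[d] = std::gcd(gcds[d], row[d]);  // std::gcd(0, x) == x
    }
  }
  return gcds;
}

int main() {
  // local_buf[4*i][6*j] is written and local_buf[8*i][9*j] is read.
  const std::vector<std::vector<int>> accesses = {{4, 6}, {8, 9}};
  const std::vector<int> gcds = PerDimensionGcd(accesses);  // {4, 3}
  for (const auto& row : accesses) {
    for (std::size_t d = 0; d < row.size(); ++d) {
      // Dividing every index by the common factor keeps the accesses distinct
      // while letting the buffer extent shrink by the same factor.
      std::cout << row[d] / gcds[d] << ' ';  // prints "1 2" then "2 3"
    }
    std::cout << '\n';
  }
  return 0;
}

The real pass does the same thing symbolically and skips a dimension whenever its GCD simplifies to 0, which is what the gcd_factors[i] != ir::Expr(0) guards above check.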
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/optim/eliminate_common_global_memory_read.h" + +#include "paddle/cinn/common/cas.h" +#include "paddle/cinn/ir/ir_mutator.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/utils/ir_compare.h" +#include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/cinn/optim/replace_var_with_expr.h" +#include "paddle/common/enforce.h" + +namespace cinn { +namespace optim { + +namespace { + +struct ForVarExtent { + ir::Var loop_var; + ir::Expr extent; +}; + +struct IndicesAndExtent { + std::vector indices; + std::vector for_var_extents; +}; + +std::unordered_map ConstructForVarReplaceMap( + const std::vector& lhs_extents, + const std::vector& rhs_extents) { + std::unordered_map ret; + std::unordered_set visited_rhs_index; + for (const auto& [lhs_var, lhs_extent] : lhs_extents) { + for (std::size_t i = 0; i < rhs_extents.size(); ++i) { + const auto& [rhs_var, rhs_extent] = rhs_extents[i]; + if (cinn::common::AutoSimplify(ir::Sub::Make(lhs_extent, rhs_extent)) == + ir::Expr(0) && + visited_rhs_index.count(i) == 0) { + ret[lhs_var] = rhs_var; + visited_rhs_index.insert(i); + break; + } + } + } + return ret; +} + +struct GlobalTensorInfoCollector : public ir::IRMutator { + public: + void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + std::unordered_set GetEliminateBufferNames() const { + auto IndiceToExprWithForVar = + [&](ir::Expr indice, + const std::unordered_map& for_var_map) + -> ir::Expr { + ir::Expr ret = ir::ir_utils::IRCopy(indice); + for (const auto& [lhs_var, rhs_var] : for_var_map) { + ReplaceVarWithExpr(&ret, lhs_var, ir::ir_utils::IRCopy(rhs_var)); + } + return ret; + }; + + auto IndiceAndExtentEqual = + [&](const IndicesAndExtent& indice_and_extent1, + const IndicesAndExtent& indice_and_extent2) -> bool { + const auto& indice1 = indice_and_extent1.indices; + const auto& indice2 = indice_and_extent2.indices; + if (indice1.size() != indice2.size()) return false; + + std::unordered_map for_var_map = + ConstructForVarReplaceMap(indice_and_extent1.for_var_extents, + indice_and_extent2.for_var_extents); + + for (size_t i = 0; i < indice1.size(); ++i) { + ir::Expr lhs = IndiceToExprWithForVar(indice1.at(i), for_var_map); + ir::Expr rhs = IndiceToExprWithForVar(indice2.at(i), for_var_map); + if (cinn::common::AutoSimplify(ir::Sub::Make(lhs, rhs)) != + ir::Expr(0)) { + return false; + } + } + return true; + }; + + auto AllIndiceAndExtentEqual = + [&](const std::vector& indice_and_extent) -> bool { + PADDLE_ENFORCE_GE( + indice_and_extent.size(), + 2, + ::common::errors::InvalidArgument( + "The size of indice_and_extent should greater_equal to 2")); + for (size_t i = 1; i < indice_and_extent.size(); ++i) { + if (!IndiceAndExtentEqual(indice_and_extent[0], indice_and_extent[i])) + return false; + } + return true; + }; + + auto IndiceContainsLoad = + [&](const IndicesAndExtent& indice_and_extent) -> bool { + for (const auto& index : indice_and_extent.indices) { + std::set load_tensors = ir::ir_utils::CollectLoadTensors( + index, /*teller=*/[&](const Expr*) -> bool { return true; }); + if (load_tensors.size() > 0) { + return true; + } + } + return false; + }; + + auto IsGlobalTensorNeedEliminate = + [&](const std::vector& indice_and_extent) -> bool { + if (indice_and_extent.size() <= 1) return false; + if (IndiceContainsLoad(indice_and_extent[0])) return false; + return AllIndiceAndExtentEqual(indice_and_extent); + }; + + 
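The collector above only nominates a global (Heap) buffer for elimination when IsGlobalTensorNeedEliminate holds: the buffer is read at least twice, every read uses the same indices once for-loop variables with equal extents have been unified through ConstructForVarReplaceMap, and none of the indices itself contains a load. The CommonGlobalMemoryEliminator defined next materializes such a value as a "<name>_local" tensor written once by an inserted schedule block and rewrites the remaining loads to use it. Hand-written C++ showing the intended effect (an analogy only, not code produced by the pass; the function and variable names are invented):

// Before: the same global element is fetched from device memory twice.
float ReadTwiceFromGlobal(const float* A, const float* B, int i, int k, int K) {
  return A[i * K + k] * B[k] + A[i * K + k];  // two reads of A[i*K + k]
}

// After: a single global read feeds a register-resident local copy.
float ReadOnceIntoLocal(const float* A, const float* B, int i, int k, int K) {
  const float a_local = A[i * K + k];  // one global read, cached locally
  return a_local * B[k] + a_local;
}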
std::unordered_set global_buffer_name; + for (const auto& [buffer_name, indice_and_extent] : + buffer_to_indice_and_extent_) { + if (IsGlobalTensorNeedEliminate(indice_and_extent)) { + global_buffer_name.insert(buffer_name); + } + } + return global_buffer_name; + } + + private: + void Visit(const ir::ScheduleBlockRealize* op, ir::Expr* expr) override { + const auto* sbr_node = expr->As(); + CHECK(sbr_node); + const auto& iter_values = sbr_node->iter_values; + const auto* sb_node = sbr_node->schedule_block.As(); + const auto& iter_vars = sb_node->iter_vars; + PADDLE_ENFORCE_EQ( + iter_values.size(), + iter_vars.size(), + ::common::errors::InvalidArgument( + "The size of iter_values should equal to the size of iter_vars, as " + "they comes from the same ScheduleBlockRealize")); + + for (std::size_t i = 0; i < iter_values.size(); ++i) { + var_to_sb_expr_[iter_vars[i]] = iter_values[i]; + } + ir::IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::For* op, ir::Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + for_var_extents_.push_back( + {node->loop_var, ir::ir_utils::IRCopy(node->extent)}); + ir::IRMutator<>::Visit(op, expr); + for_var_extents_.pop_back(); + } + + void Visit(const ir::Load* op, ir::Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + const auto& load_buffer = node->tensor.as_tensor_ref()->buffer; + if (load_buffer->memory_type == ir::MemoryType::Heap) { + std::vector tensor_indices; + for (const auto& indice : node->indices) { + ir::Expr new_indice = ir::ir_utils::IRCopy(indice); + for (const auto& [var, sb_expr] : var_to_sb_expr_) { + ReplaceVarWithExpr(&new_indice, var, ir::ir_utils::IRCopy(sb_expr)); + } + tensor_indices.push_back(new_indice); + } + buffer_to_indice_and_extent_[load_buffer->name].push_back( + {tensor_indices, for_var_extents_}); + } + } + + std::vector for_var_extents_; + std::unordered_map var_to_sb_expr_; + std::unordered_map> + buffer_to_indice_and_extent_; +}; + +struct CommonGlobalMemoryEliminator : public ir::IRMutator { + CommonGlobalMemoryEliminator( + const std::unordered_set& eliminate_buffer_names) + : eliminate_buffer_names_(eliminate_buffer_names) {} + + void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } + + private: + void Visit(const ir::Block* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + current_block_ = node; + IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::ScheduleBlockRealize* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + current_sbr_ = node; + IRMutator<>::Visit(op, expr); + } + + void Visit(const ir::Load* op, Expr* expr) override { + auto* node = expr->As(); + CHECK(node); + const auto& buffer_name = node->tensor.as_tensor_ref()->buffer->name; + if (eliminate_buffer_names_.count(buffer_name) == 0) { + return; + } + + if (global_buffer_to_local_buffer_.count(buffer_name) == 0) { + InsertLocalTensorBlock(node, buffer_name); + } + SubstituteGlobalTensor(node, buffer_name); + } + + void InsertLocalTensorBlock(ir::Load* load_node, + const std::string& buffer_name) { + ir::Expr sb = ir::ir_utils::IRCopy(current_sbr_->schedule_block); + ir::ScheduleBlock* sb_node = sb.As(); + CHECK(sb_node); + + const auto& old_tensor = load_node->tensor.as_tensor_ref(); + ir::Expr new_tensor = + ir::_Tensor_::Make(old_tensor->name + "_local", + old_tensor->type(), + ir::ir_utils::IRCopy(old_tensor->shape), + ir::ir_utils::IRCopy(old_tensor->domain), + old_tensor->reduce_axis); + new_tensor.as_tensor_ref()->WithBuffer( + "local", 
new_tensor.as_tensor_ref()->name + "_buffer"); + ir::Expr new_body = + ir::Store::Make(new_tensor, + ir::ir_utils::IRCopy(ir::Expr(load_node)), + ir::ir_utils::IRCopy(load_node->indices)); + ir::Expr new_sb = ir::ScheduleBlock::Make( + sb_node->iter_vars, {}, {}, sb_node->name + "_local", new_body); + + ir::Expr new_sbr = ir::ScheduleBlockRealize::Make( + ir::ir_utils::IRCopy(current_sbr_->iter_values), new_sb); + PADDLE_ENFORCE_EQ( + global_buffer_to_local_buffer_.count(buffer_name), + 0, + ::common::errors::InvalidArgument( + "buffer_name %s should not be in global_buffer_to_local_buffer_", + buffer_name)); + global_buffer_to_local_buffer_[buffer_name] = new_tensor; + current_block_->stmts.insert(current_block_->stmts.begin(), new_sbr); + } + + void SubstituteGlobalTensor(ir::Load* load_node, + const std::string& buffer_name) { + PADDLE_ENFORCE_GT( + global_buffer_to_local_buffer_.count(buffer_name), + 0, + ::common::errors::InvalidArgument( + "global_buffer_to_local_buffer_ should contain buffer_name %s", + buffer_name)); + load_node->tensor = global_buffer_to_local_buffer_[buffer_name]; + } + + std::unordered_set eliminate_buffer_names_; + std::unordered_map global_buffer_to_local_buffer_; + + ir::Block* current_block_; + ir::ScheduleBlockRealize* current_sbr_; +}; + +} // namespace + +void EliminateCommonGlobalMemoryRead(Expr* e) { + VLOG(4) << "Before EliminateCommonGlobalMemoryRead: \n" << *e; + GlobalTensorInfoCollector collector; + collector(e); + + const auto& eliminate_buffer_names = collector.GetEliminateBufferNames(); + + CommonGlobalMemoryEliminator eliminator(eliminate_buffer_names); + eliminator(e); + VLOG(4) << "After EliminateCommonGlobalMemoryRead: \n" << *e; +} + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/eliminate_common_global_memory_read.h b/paddle/cinn/optim/eliminate_common_global_memory_read.h new file mode 100644 index 0000000000000..0db44e2b25444 --- /dev/null +++ b/paddle/cinn/optim/eliminate_common_global_memory_read.h @@ -0,0 +1,28 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +/** + * Remove common global memory read and substitue them with local memory read. + */ +void EliminateCommonGlobalMemoryRead(Expr* e); + +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/if_fusion.cc b/paddle/cinn/optim/if_fusion.cc new file mode 100644 index 0000000000000..4e66748208a72 --- /dev/null +++ b/paddle/cinn/optim/if_fusion.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/cinn/optim/if_fusion.h" + +#include +#include "paddle/cinn/ir/ir_mutator.h" +#include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/utils/ir_compare.h" +#include "paddle/cinn/optim/ir_simplify.h" + +#define VisitImpl(_TYPE) \ + void Visit(const ir::_TYPE *op, Expr *expr) override { \ + last_op = Expr(const_cast(op)); \ + ir::IRMutator<>::Visit(op, expr); \ + } + +namespace cinn { +namespace optim { + +namespace { + +struct IfFusionMutator : public ir::IRMutator { + void operator()(Expr *expr) { Visit(expr, expr); } + + private: + void Visit(const ir::IfThenElse *op, Expr *expr) override { + // the implementation of ifFusion + // compare the last condition with current condition + // judge whether last_op is nullptr + if (!last_op.get()) { + last_op = Expr(const_cast(op)); + return; + } + + // judge whether last_op is IfThenElse + ir::IfThenElse *lop = last_op.As(); + if (!lop) { + last_op = Expr(const_cast(op)); + return; + } + + // judge whether condition is same + bool is_need_fuse = ir::ir_utils::IRCompare(op->condition, lop->condition); + if (is_need_fuse) { + // do fusion (cop.true_case <-> lop.true_case) + Fuse(op->true_case, lop->true_case); + + // support for recursive true case merge + Expr tmp = last_op; + Visit(&lop->true_case, &lop->true_case); + last_op = tmp; + + if (op->false_case.defined() && lop->false_case.defined()) { + Fuse(op->false_case, lop->false_case); + // support for recusive false case merge + tmp = last_op; + Visit(&lop->false_case, &lop->false_case); + last_op = tmp; + } + + // Remove the op which refers to current ir::IfThenElse block, + // because this block is merged with previous ir::IfThenElse block, + // so blank now. + // push the elements position which will be deleted after visit current + // block. 
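In the Visit above, an ir::IfThenElse is fused into the immediately preceding statement when that statement is also an ir::IfThenElse whose condition compares equal under ir::ir_utils::IRCompare: the true branches are merged, the false branches are merged when both exist, the merged bodies are revisited so nested identical conditions collapse too, and the emptied block is recorded for erasure. In plain C++ terms the rewrite is roughly the following (an illustration of the effect, not CINN IR):

// Before if-fusion: the same condition is tested twice in a row.
void BeforeFusion(bool cond, int& x, int& y) {
  if (cond) { x += 1; } else { x -= 1; }
  if (cond) { y += 1; } else { y -= 1; }
}

// After if-fusion: one test guards both merged bodies.
void AfterFusion(bool cond, int& x, int& y) {
  if (cond) {
    x += 1;
    y += 1;
  } else {
    x -= 1;
    y -= 1;
  }
}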
+ RecordIndexForErase(Expr(const_cast(op)), cur_block); + } + + if (!is_need_fuse) { + last_op = Expr(const_cast(op)); + } + } + + void Visit(const ir::Block *op, Expr *expr) override { + int element_num_before_visit = erase_elements_ind.size(); + ir::Block *last_block = (cur_block); + cur_block = const_cast(op); + ir::IRMutator<>::Visit(op, expr); + cur_block = last_block; + + EraseBlankElements(const_cast(op), element_num_before_visit); + } + + // Recode for the sequent Erasure + void RecordIndexForErase(Expr op, ir::Block *cur_block) { + for (int i = 0; i < cur_block->stmts.size(); i++) { + if (ir::ir_utils::IRCompare(cur_block->stmts[i], op)) { + erase_elements_ind.push(i); + return; + } + } + } + + // Erase the blank block + void EraseBlankElements(ir::Block *op, int stack_upper_bound) { + while (erase_elements_ind.size() > stack_upper_bound) { + int erase_pos = erase_elements_ind.top(); + erase_elements_ind.pop(); + op->stmts.erase(op->stmts.begin() + erase_pos); + } + } + + VisitImpl(Expr); + VisitImpl(ScheduleBlock); + VisitImpl(For); + VisitImpl(IntImm); + VisitImpl(UIntImm); + VisitImpl(FloatImm); + VisitImpl(StringImm); + VisitImpl(Cast); + VisitImpl(PolyFor); + VisitImpl(Select); + VisitImpl(Call); + VisitImpl(_Module_); + VisitImpl(_Var_); + VisitImpl(Load); + VisitImpl(Store); + VisitImpl(Alloc); + VisitImpl(Free); + VisitImpl(_Buffer_); + VisitImpl(_Tensor_); + VisitImpl(_LoweredFunc_); + VisitImpl(Let); + VisitImpl(Reduce); + VisitImpl(Ramp); + VisitImpl(Broadcast); + VisitImpl(FracOp); + VisitImpl(Product); + VisitImpl(Sum); + VisitImpl(PrimitiveNode); + VisitImpl(IntrinsicOp); + VisitImpl(_BufferRange_); + VisitImpl(_Dim_); + + void Fuse(Expr ne, Expr oe) { + // fuse old expr with new expr, merge the stmts in them. + ir::Block *neb = ne.As(); + ir::Block *oeb = oe.As(); + +#ifdef __cpp_lib_containers_range + oeb->stmts.append_range(neb->stmts); +#else + oeb->stmts.insert(oeb->stmts.end(), neb->stmts.cbegin(), neb->stmts.cend()); +#endif + + neb->stmts.clear(); + } + + std::stack erase_elements_ind; + + // record the condition of it if last block is if-block, nullptr otherwise. + Expr last_op = Expr(nullptr); + + ir::Block *cur_block; +}; // IfFusionMutator +} // namespace + +void IfFusion(Expr *expr) { IfFusionMutator()(expr); } +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/if_fusion.h b/paddle/cinn/optim/if_fusion.h new file mode 100644 index 0000000000000..abf7bb88b6593 --- /dev/null +++ b/paddle/cinn/optim/if_fusion.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/cinn/ir/ir.h" + +namespace cinn { +namespace optim { + +/* + * Do fusion with the adjaccnt if-block. 
+ */ +void IfFusion(Expr *expr); +} // namespace optim +} // namespace cinn diff --git a/paddle/cinn/optim/insert_debug_log_callee.cc b/paddle/cinn/optim/insert_debug_log_callee.cc index fdab377bc88cc..1bcfd34bbaf9c 100644 --- a/paddle/cinn/optim/insert_debug_log_callee.cc +++ b/paddle/cinn/optim/insert_debug_log_callee.cc @@ -139,7 +139,7 @@ struct InsertDebugLogCalleeMutator : public ir::IRMutator<> { ir::IRMutator<>::Visit(&node->body, &node->body); auto deal_with_exprs = - [&](std::vector *exprs) { // deal with op->argument_preapre_exprs + [&](std::vector *exprs) { // deal with op->argument_prepare_exprs std::vector new_stmts; for (auto &expr : *exprs) { auto msg = diff --git a/paddle/cinn/optim/map_extern_call.cc b/paddle/cinn/optim/map_extern_call.cc index c462fd1aa0f01..d260cea233dd4 100644 --- a/paddle/cinn/optim/map_extern_call.cc +++ b/paddle/cinn/optim/map_extern_call.cc @@ -65,7 +65,13 @@ void MapExternCall(Expr *e, Target target) { void DealWithCpuIntrinsics(ir::Call *node, Expr *expr) { if (kExternFp32CallsCPU.count(node->name)) { - CHECK_GE(node->read_args.size(), 1UL); + PADDLE_ENFORCE_GE( + node->read_args.size(), + 1UL, + phi::errors::InvalidArgument( + "The size of node's read args is incorrect." + "Expected size is greater than or equal to 1, but receive %d.", + node->read_args.size())); CHECK(node->read_args.front().type().is_float()) << "CPU extern call intrinsics only support float now! Please " "check."; diff --git a/paddle/cinn/optim/optimize.cc b/paddle/cinn/optim/optimize.cc index 567cb2e2b6021..bd6690838c09e 100644 --- a/paddle/cinn/optim/optimize.cc +++ b/paddle/cinn/optim/optimize.cc @@ -22,6 +22,7 @@ #include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h" #include "paddle/cinn/optim/extern_call_process.h" #include "paddle/cinn/optim/fold_cinn_call_arguments.h" +#include "paddle/cinn/optim/if_fusion.h" #include "paddle/cinn/optim/insert_debug_log_callee.h" #include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/lower_function_call_bind_vars.h" @@ -80,6 +81,9 @@ Expr Optimize(Expr e, Simplify(&copied); VLOG(10) << "After Optimize Simplify:" << copied; + IfFusion(&copied); + VLOG(10) << "After Optimize IfFusion" << copied; + if (runtime_debug_info) { LOG(WARNING) << "Turn on runtime debug information output"; InsertDebugLogCallee(&copied); diff --git a/paddle/cinn/optim/remove_schedule_block.cc b/paddle/cinn/optim/remove_schedule_block.cc index 007174801550d..404840b59aa9d 100644 --- a/paddle/cinn/optim/remove_schedule_block.cc +++ b/paddle/cinn/optim/remove_schedule_block.cc @@ -35,7 +35,13 @@ struct ScheduleBlockRemover : public ir::IRMutator { CHECK(schedule_block); auto& iter_vars = schedule_block->iter_vars; Expr body = schedule_block->body; - CHECK_EQ(iter_vars.size(), iter_values.size()); + PADDLE_ENFORCE_EQ(iter_vars.size(), + iter_values.size(), + phi::errors::InvalidArgument( + "The size of iter vars and iter values is not equal," + "where iter vars:%d but iter values:%d.", + iter_vars.size(), + iter_values.size())); for (int i = 0; i < iter_vars.size(); i++) { optim::ReplaceVarWithExpr(&body, iter_vars[i], iter_values[i]); } diff --git a/paddle/cinn/optim/replace_call_with_expr.cc b/paddle/cinn/optim/replace_call_with_expr.cc index 00fbca0fca623..d6ba57210ee45 100644 --- a/paddle/cinn/optim/replace_call_with_expr.cc +++ b/paddle/cinn/optim/replace_call_with_expr.cc @@ -36,7 +36,8 @@ struct ReplaceCallWithExprModifier : public ir::IRMutator<> { VLOG(3) << "Processing Call node " << *op; if (statement_ != node->name) return; - Expr 
expr_candidate = ir::ir_utils::IRCopy(candidate_); + Expr expr_candidate = + ir::ir_utils::IRCopy(candidate_, /* copy_buffer_node = */ false); VLOG(3) << "Original candidate expr: " << candidate_; VLOG(3) << "Copied candidate expr: " << expr_candidate; @@ -62,7 +63,7 @@ void ReplaceIslCallWithExpr(Expr *e, const Expr &candidate, const std::map &axis_map) { VLOG(3) << "ReplaceCallWithExpr, original expression: " << candidate; - Expr copied = ir::ir_utils::IRCopy(candidate); + Expr copied = ir::ir_utils::IRCopy(candidate, /* copy_buffer_node = */ false); // update the axis in the copied expression. // we treat the Store node as the normal statement, the others like Call node diff --git a/paddle/cinn/optim/replace_cross_thread_reduction.cc b/paddle/cinn/optim/replace_cross_thread_reduction.cc index 2524874bace60..56f1802dcd07e 100644 --- a/paddle/cinn/optim/replace_cross_thread_reduction.cc +++ b/paddle/cinn/optim/replace_cross_thread_reduction.cc @@ -19,6 +19,7 @@ #include "paddle/cinn/optim/replace_cross_thread_reduction.h" #include +#include "paddle/cinn/adt/adt.h" #include "paddle/cinn/common/common.h" #include "paddle/cinn/hlir/pe/reduction.h" #include "paddle/cinn/ir/ir.h" @@ -46,7 +47,11 @@ struct CrossThreadReductionReplacer : public ir::IRMutator<> { bool CanReplace(const ir::ScheduleBlockRealize* block_realize) { const ir::ScheduleBlock* schedule_block = block_realize->schedule_block.As(); - CHECK_NOTNULL(schedule_block); + + PADDLE_ENFORCE_NOT_NULL( + schedule_block, + phi::errors::PreconditionNotMet( + "The schedule block pointer in CanReplace must not be null.")); if (block_realize->schedule_block.As()->name.substr( 0, 4) == "root") { @@ -67,20 +72,27 @@ struct CrossThreadReductionReplacer : public ir::IRMutator<> { if (x->as_var()) { reduce_var_names.insert(x->as_var()->name); } + return false; }); } + auto IsThreadBindOnReduceAxis = [&](const ir::For* for_node) { + return reduce_var_names.count(for_node->loop_var->name) > 0 && + for_node->is_gpu_thread_binded(); + }; + std::vector thread_binded_reduce_loop_indices; + bool is_thread_binded_inner_loop = false; for (int i = 0; i < cur_loops_.size(); ++i) { - if (reduce_var_names.count(cur_loops_[i].As()->loop_var->name) > - 0) { - if (cur_loops_[i].As()->is_gpu_thread_binded()) { - if (ir::GetLoopExtent(cur_loops_[i]) > 1024) { - return false; - } - thread_binded_reduce_loop_indices.push_back(i); + if (is_thread_binded_inner_loop || + IsThreadBindOnReduceAxis(cur_loops_[i].As())) { + if (ir::GetLoopExtent(cur_loops_[i]) > 1024) { + return false; } + + is_thread_binded_inner_loop = true; + thread_binded_reduce_loop_indices.push_back(i); } } if (thread_binded_reduce_loop_indices.size() == 0 || @@ -126,18 +138,35 @@ struct CrossThreadReductionReplacer : public ir::IRMutator<> { const ir::ScheduleBlock* schedule_block = expr->schedule_block.As(); - CHECK_NOTNULL(schedule_block); + PADDLE_ENFORCE_NOT_NULL( + schedule_block, + phi::errors::PreconditionNotMet( + "The schedule block pointer in Visit must not be null.")); ir::Expr original_update_body = schedule_block->body; ir::Expr original_update_stmt; CHECK(original_update_body.As() || original_update_body.As()); if (original_update_body.As()) { - CHECK_EQ(original_update_body.As()->stmts.size(), 1); + PADDLE_ENFORCE_EQ( + original_update_body.As()->stmts.size(), + 1, + phi::errors::InvalidArgument( + "The size of stmts is incorrect." 
+ "Expected size is 1, but receive %d.", + original_update_body.As()->stmts.size())); original_update_stmt = original_update_body.As()->stmts[0]; } else if (original_update_body.As()) { original_update_stmt = original_update_body; } + const auto& IsWarpReduce = cinn::adt::match{ + [&](const ir::NoneReduceMethod&) { return ir::Expr(false); }, + [&](const ir::WarpReduceMethod&) { return ir::Expr(true); }, + [&](const ir::BlockReduceMethod&) { return ir::Expr(false); }, + }; + ir::Expr return_warp = + std::visit(IsWarpReduce, schedule_block->reduce_method); + #define REPLACE_TO_EXTERNAL_CALL(Op) \ if (original_update_stmt.As()->value.As()) { \ auto* node = original_update_stmt.As()->value.As(); \ @@ -154,8 +183,8 @@ struct CrossThreadReductionReplacer : public ir::IRMutator<> { tmp_buffer->dtype = tmp_dtype; \ tmp_buffer->memory_type = ir::MemoryType::GPUShared; \ shm_buffer_.insert(tmp_buffer); \ - original_update_stmt.As()->value = \ - lang::CallExtern(reduce_func_name, {node->b(), tmp_buffer}); \ + original_update_stmt.As()->value = lang::CallExtern( \ + reduce_func_name, {node->b(), tmp_buffer, return_warp}); \ } REPLACE_TO_EXTERNAL_CALL(ir::Add) diff --git a/paddle/cinn/optim/replace_cross_thread_reduction_test.cc b/paddle/cinn/optim/replace_cross_thread_reduction_test.cc index d7bd9f6defc49..9f616c7f8a5f2 100644 --- a/paddle/cinn/optim/replace_cross_thread_reduction_test.cc +++ b/paddle/cinn/optim/replace_cross_thread_reduction_test.cc @@ -71,7 +71,7 @@ TEST(CrossThreadReductionReplacer, basic) { ScheduleBlock(B) { i0_0, i1 = axis.bind(i, reduce_j) - B[i0_0] = cinn_block_reduce_sum_fp32_internal_shm(A[i0_0, i1], _Buffer_(shm32__fp32_reduce)) + B[i0_0] = cinn_block_reduce_sum_fp32_internal_shm(A[i0_0, i1], _Buffer_(shm32__fp32_reduce), false) } } } diff --git a/paddle/cinn/optim/resize_buffer.cc b/paddle/cinn/optim/resize_buffer.cc index e73929a97aa57..2ec4e172b3fc7 100644 --- a/paddle/cinn/optim/resize_buffer.cc +++ b/paddle/cinn/optim/resize_buffer.cc @@ -16,14 +16,17 @@ #include #include "paddle/cinn/common/cas.h" +#include "paddle/cinn/common/integer_set.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/ir_printer.h" +#include "paddle/cinn/ir/op/ir_operators.h" #include "paddle/cinn/ir/utils/ir_copy.h" #include "paddle/cinn/optim/replace_mod_to_max.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/utils/string.h" +PD_DECLARE_bool(group_schedule_tiling_first); namespace cinn { namespace optim { @@ -70,6 +73,7 @@ class AnalyzeLoopVarRange : public ir::IRMutator<> { ir::Store* store = expr->As(); ir::Tensor tensor = store->tensor.as_tensor_ref(); AnalyzeTensorRange(store->indices, tensor); + AnalyzeBufferSize(store->indices, tensor); ir::IRMutator<>::Visit(op, expr); } @@ -102,10 +106,8 @@ class AnalyzeLoopVarRange : public ir::IRMutator<> { private: void AnalyzeTensorRange(const std::vector& indices, const ir::Tensor& tensor) { - if (!tensor->buffer.defined() || - tensor->buffer->memory_type == ir::MemoryType::Heap) { - return; - } + if (!tensor->buffer.defined()) return; + if (tensor->buffer->memory_type == ir::MemoryType::Heap) return; std::vector indice_extent; for (int i = 0; i < indices.size(); ++i) { @@ -143,6 +145,45 @@ class AnalyzeLoopVarRange : public ir::IRMutator<> { << buffer_name_to_indice_extent[buffer_name]; } + void AnalyzeBufferSize(const std::vector& indices, + const ir::Tensor& tensor) { + if (!tensor->buffer.defined()) return; + if (tensor->buffer->memory_type == ir::MemoryType::Heap) 
return; + + const std::string& buffer_name = tensor->buffer->name; + buffer_name_to_size[buffer_name] = AnalyzeBufferSize(indices); + VLOG(6) << "buffer_name = " << buffer_name + << ", size = " << buffer_name_to_size[buffer_name]; + } + + ir::Expr AnalyzeBufferSize(const std::vector& indices) { + const auto GetIterVarNames = + [](const std::vector& indices) -> std::set { + std::set iter_var_names; + for (const ir::Expr& e : indices) { + ir::ir_utils::CollectIRNodes(e, [&](const ir::Expr* x) { + if (x->as_var() && !x->as_var()->is_symbolic_constant) { + iter_var_names.insert(x->as_var()->name); + } + return false; + }); + } + return iter_var_names; + }; + + std::set iter_var_names = GetIterVarNames(indices); + ir::Expr size(1); + for (const std::string& var_name : iter_var_names) { + PADDLE_ENFORCE_GT(var_name_to_extent_.count(var_name), + 0, + ::common::errors::PreconditionNotMet( + "Cannot find the extent of var %s", var_name)); + size = common::AutoSimplify(size * var_name_to_extent_.at(var_name)); + } + + return size; + } + // A recursion function to calculate the max index range // The index may contain some vars like index = 8 * i / j, where we know the // range of i, j, we search all values to get the max index range @@ -168,13 +209,26 @@ class AnalyzeLoopVarRange : public ir::IRMutator<> { } } ir::Expr tmp = ir::Add::Make(copy, ir::Expr(1)); - ir::Expr simplify = common::AutoSimplify(tmp); - return simplify; + ir::Expr simplified = common::AutoSimplify(tmp); + if (simplified.As()) { + ir::Expr lhs = simplified.As()->a(); + ir::Expr rhs = simplified.As()->b(); + common::cas_intervals_t var_intervals = + common::CollectVarIntervalsOfExprs({lhs, rhs}); + common::SymbolicExprAnalyzer analyzer(var_intervals); + if (analyzer.ProveLE(lhs, rhs)) { + return lhs; + } else if (analyzer.ProveGE(lhs, rhs)) { + return rhs; + } + } + return simplified; } public: std::unordered_map> buffer_name_to_indice_extent; + std::unordered_map buffer_name_to_size; private: std::unordered_map var_name_to_extent_; @@ -184,8 +238,10 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> { public: ResizeBufferFromAnalyzedRange( const std::unordered_map>& - buffer_name_to_shape) - : buffer_name_to_shape_(buffer_name_to_shape) {} + buffer_name_to_shape, + const std::unordered_map& buffer_name_to_size) + : buffer_name_to_shape_(buffer_name_to_shape), + buffer_name_to_size_(buffer_name_to_size) {} void operator()(ir::Expr* expr) { ir::IRMutator<>::Visit(expr, expr); } @@ -208,8 +264,11 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> { return; } - load->tensor.as_tensor_ref()->shape = - load->tensor.as_tensor_ref()->buffer->shape; + const std::string& buffer_name = load->tensor.as_tensor_ref()->buffer->name; + if (buffer_name_to_shape_.count(buffer_name) > 0) { + load->tensor.as_tensor_ref()->shape = + buffer_name_to_shape_.at(buffer_name); + } // For the moment, align the load tensor indices with the tensor shape using // the trick method. 
A better way would be to modify the FlattenLoop @@ -224,25 +283,31 @@ class ResizeBufferFromAnalyzedRange : public ir::IRMutator<> { private: void ResizeTensor(ir::Tensor* tensor_ptr) { ir::Buffer buffer = (*tensor_ptr)->buffer; - if (!buffer.defined() || buffer->memory_type == ir::MemoryType::Heap) { - return; - } + if (!buffer.defined()) return; + if (buffer->memory_type == ir::MemoryType::Heap) return; + const std::string& buffer_name = buffer->name; if (buffer_name_to_shape_.count(buffer_name)) { const std::vector& analyzed_shape = buffer_name_to_shape_.at(buffer_name); VLOG(6) << "Replacing shape of tensor " << (*tensor_ptr)->name - << ", buffer " << buffer->name << ", with shape " - << analyzed_shape; - + << " with shape " << analyzed_shape; (*tensor_ptr)->shape = analyzed_shape; buffer->shape = analyzed_shape; } + if (FLAGS_group_schedule_tiling_first && + buffer_name_to_size_.count(buffer_name) > 0) { + const ir::Expr& analyzed_size = buffer_name_to_size_.at(buffer_name); + VLOG(6) << "Replacing shape of buffer " << buffer->name << " with shape " + << analyzed_size; + buffer->shape = {analyzed_size}; + } } private: const std::unordered_map>& buffer_name_to_shape_; + const std::unordered_map& buffer_name_to_size_; }; void ResizeBufferToMaxVarRange(ir::Expr* expr) { @@ -250,7 +315,8 @@ void ResizeBufferToMaxVarRange(ir::Expr* expr) { AnalyzeLoopVarRange analyze_functor; analyze_functor(expr); ResizeBufferFromAnalyzedRange resize_functor( - analyze_functor.buffer_name_to_indice_extent); + analyze_functor.buffer_name_to_indice_extent, + analyze_functor.buffer_name_to_size); resize_functor(expr); VLOG(6) << "After ResizeBufferToMaxVarRange, Expr = \n" << *expr; } diff --git a/paddle/cinn/optim/transform_gpu_forloop.cc b/paddle/cinn/optim/transform_gpu_forloop.cc index 7f2cc54f352eb..4e5d5f4c5ae8e 100644 --- a/paddle/cinn/optim/transform_gpu_forloop.cc +++ b/paddle/cinn/optim/transform_gpu_forloop.cc @@ -27,6 +27,7 @@ #include "paddle/cinn/ir/ir_mutator.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/utils/ir_copy.h" +#include "paddle/cinn/optim/eliminate_common_factor_of_local_index.h" #include "paddle/cinn/optim/ir_simplify.h" #include "paddle/cinn/optim/replace_var_with_expr.h" #include "paddle/cinn/optim/resize_buffer.h" @@ -221,7 +222,13 @@ class ReplaceIndexToBindExpr : public ir::IRMutator<> { schedule_block_realize->schedule_block.As() ->iter_vars; - CHECK_EQ(iter_values.size(), iter_vars.size()); + PADDLE_ENFORCE_EQ(iter_values.size(), + iter_vars.size(), + phi::errors::InvalidArgument( + "The size of iter values and iter vars is not equal," + "where iter values:%d but iter vars:%d.", + iter_values.size(), + iter_vars.size())); for (int idx = 0; idx < iter_values.size(); ++idx) { ReplaceVarWithExpr(&body, iter_vars[idx], iter_values[idx]); } @@ -260,7 +267,7 @@ class ReplaceLoopVarToGpu : public ir::IRMutator<> { ir::IRMutator<>::Visit(&for_ir->body, &for_ir->body); } void Visit(const ir::PolyFor *op, Expr *expr) override { - LOG(FATAL) << "Unkown PolyFor!"; + PADDLE_THROW(phi::errors::InvalidArgument("Unkown PolyFor!")); } }; @@ -444,6 +451,8 @@ void OptimizeExprGPU(Expr *expr) { LocalAxisVisitor local_axis_visitor; local_axis_visitor(expr); + EliminateCommonFactorOfLocalIndex(expr); + ResizeBufferToMaxVarRange(expr); ReplaceVarToZero replace_var_to_zero; diff --git a/paddle/cinn/optim/transform_polyfor_to_for.cc b/paddle/cinn/optim/transform_polyfor_to_for.cc index ff29bb0058801..655619efe8cc9 100644 --- a/paddle/cinn/optim/transform_polyfor_to_for.cc +++ 
b/paddle/cinn/optim/transform_polyfor_to_for.cc @@ -99,17 +99,27 @@ struct PolyForWithSimpleConditionToForMutator : public ir::IRMutator { if (node->condition.As()) { auto le = node->condition.As(); CHECK(le->a().As()); - CHECK_EQ(le->b().As()->value, 0UL); + PADDLE_ENFORCE_EQ( + le->b().As()->value, + 0UL, + phi::errors::InvalidArgument("The value of le is incorrect." + "Expected value is 0, but receive %d.", + le->b().As()->value)); auto sub = le->a().As(); node->condition = ir::LE::Make(sub->a(), sub->b()); } else if (node->condition.As()) { auto lt = node->condition.As(); CHECK(lt->a().As()); - CHECK_EQ(lt->b().As()->value, 0UL); + PADDLE_ENFORCE_EQ( + lt->b().As()->value, + 0UL, + phi::errors::InvalidArgument("The value of lt is incorrect." + "Expected value is 0, but receive %d.", + lt->b().As()->value)); auto sub = lt->a().As(); node->condition = ir::LT::Make(sub->a(), sub->b()); } else { - LOG(FATAL) << "Unkown Type!"; + PADDLE_THROW(phi::errors::InvalidArgument("Unkown Type!")); } lt_n = node->condition.As(); diff --git a/paddle/cinn/optim/unroll_loops.cc b/paddle/cinn/optim/unroll_loops.cc index 9f2e8bf244e4c..276a633924991 100644 --- a/paddle/cinn/optim/unroll_loops.cc +++ b/paddle/cinn/optim/unroll_loops.cc @@ -62,7 +62,7 @@ struct UnrollMutator : public ir::IRMutator { void Visit(const ir::For* op, Expr* expr) override { IRMutator<>::Visit(op, expr); if (op->extent.As() == nullptr) { - VLOG(5) << "loop to be unrolled should have a contant extent"; + VLOG(5) << "loop to be unrolled should have a constant extent"; return; } int64_t extent = op->extent.as_int64(); @@ -94,7 +94,8 @@ struct UnrollMutator : public ir::IRMutator { for (int i = min->value; i < extent->value; i++) { Expr start = op->min + i; - body.push_back(ir::ir_utils::IRCopy(op->body)); + body.push_back( + ir::ir_utils::IRCopy(op->body, /* copy_buffer_node = */ false)); cinn::ir::ir_utils::IrReplaceVarBroadcast( &body.back(), op->loop_var, start); } diff --git a/paddle/cinn/optim/vectorize_loops.cc b/paddle/cinn/optim/vectorize_loops.cc index 67e309c73a6a0..c32991612e561 100644 --- a/paddle/cinn/optim/vectorize_loops.cc +++ b/paddle/cinn/optim/vectorize_loops.cc @@ -50,8 +50,11 @@ Expr Widen(Expr e, int lanes) { } } - CHECK_EQ(e.type().lanes(), 1) - << "Cannot broadcast lanes from " << e.type().lanes() << " to " << lanes; + PADDLE_ENFORCE_EQ( + e.type().lanes(), + 1, + phi::errors::InvalidArgument( + "Cannot broadcast lanes from %d to %d.", e.type().lanes(), lanes)); return ir::Broadcast::Make(e, lanes); } @@ -742,7 +745,13 @@ struct VectorizeLoops_ : public IRMutator { if (forloop->is_vectorized()) { Context::info_rgt().Get("vectorized_forloop_count")++; - CHECK_GT(forloop->vectorize_info().factor, 0); + PADDLE_ENFORCE_GT( + forloop->vectorize_info().factor, + 0, + phi::errors::InvalidArgument( + "The value of factor in forloop's vectorize_info is incorrect." + "Expected value is larger than 0, but receive %d. ", + forloop->vectorize_info().factor)); CHECK(is_zero(forloop->min)); Expr for_extent = cinn::common::AutoSimplify(forloop->extent); @@ -795,10 +804,14 @@ struct VectorizeLoops_ : public IRMutator { } int extent = extent_int->value; - CHECK_GT(extent, 0) - << "Loop over " << Expr(new_forloop->loop_var) << " has extent " - << new_forloop->extent - << ". Can only vectorize loops over a constant extent > 1"; + PADDLE_ENFORCE_GT( + extent, + 0, + phi::errors::InvalidArgument( + "Loop over %s has extent %d" + ". 
Can only vectorize loops over a constant extent > 1", + Expr(new_forloop->loop_var), + new_forloop->extent)); VLOG(2) << "Vectorizing " << new_forloop->loop_var << " extent " << extent; @@ -810,7 +823,8 @@ struct VectorizeLoops_ : public IRMutator { cuda_vectorizer.Visit(&new_forloop->body); // unroll the new forloop to compute each element of the vector // iteratively - auto copied_loop = ir::ir_utils::IRCopy(_new_forloop); + auto copied_loop = + ir::ir_utils::IRCopy(_new_forloop, /* copy_buffer_node = */ false); copied_loop.As()->set_unrolled(); optim::UnrollLoop(&copied_loop); // add cast exprs of vector type in the front of vectorized forloop, @@ -893,13 +907,14 @@ struct VectorizeLoops_ : public IRMutator { Var new_iterator_outer( cinn::common::UniqName(outer_for->loop_var->name + "_s")); - Expr inner_for_b = - Block::Make({For::Make(new_iterator_inner, - inner_for->min, - b, - ForType::Serial, - DeviceAPI::UNK, - ir::ir_utils::IRCopy(inner_for->body))}); + Expr inner_for_b = Block::Make({For::Make( + new_iterator_inner, + inner_for->min, + b, + ForType::Serial, + DeviceAPI::UNK, + ir::ir_utils::IRCopy(inner_for->body, + /* copy_buffer_node = */ false))}); cinn::ir::ir_utils::IrReplaceVarBroadcast( &inner_for_b, inner_for->loop_var, Expr(new_iterator_inner)); @@ -925,7 +940,12 @@ struct VectorizeLoops_ : public IRMutator { //! Split the forloop with size \p factor. //! @return The new forloop. Expr SplitForLoop(For *forloop, int factor) { - CHECK_GT(factor, 1); + PADDLE_ENFORCE_GT(factor, + 1, + phi::errors::InvalidArgument( + "The value of factor in SplitForLoop is incorrect." + "Expected value is larger than 1, but receive %d. ", + factor)); auto *for_min_i = forloop->min.As(); CHECK(forloop); if (!for_min_i) return Expr(); diff --git a/paddle/cinn/optim/vectorize_loops_test.cc b/paddle/cinn/optim/vectorize_loops_test.cc index 270e37f1dc46a..7f9abe1e2c512 100644 --- a/paddle/cinn/optim/vectorize_loops_test.cc +++ b/paddle/cinn/optim/vectorize_loops_test.cc @@ -80,7 +80,7 @@ void matmul(void* _args, int32_t num_args) float* C = ((float*)(_C->memory)); for (int32_t i = 0; i < 100; i += 1) { for (int32_t j = 0; j < 32; j += 1) { - C[StackVec<16,int32_t>::Ramp(((500 * i) + (16 * j)), 1, 16)] = (StackedVec::Load(A,((500 * i) + (16 * j))) * StackedVec::Load(B,((500 * i) + (16 * j)))); + C[StackVec<16,int32_t>::Ramp(((16 * j) + (i * 500)), 1, 16)] = (StackedVec::Load(A,((16 * j) + (i * 500))) * StackedVec::Load(B,((16 * j) + (i * 500)))); }; }; cinn_buffer_free((void*)(0), _C); diff --git a/paddle/cinn/poly/ast_gen.cc b/paddle/cinn/poly/ast_gen.cc index f71ec5fed9ed6..dad3f25fe1b4e 100644 --- a/paddle/cinn/poly/ast_gen.cc +++ b/paddle/cinn/poly/ast_gen.cc @@ -359,8 +359,9 @@ void IslAstNodeToCinnExpr(const isl::ast_node& node, ir::Expr* expr) { // EatMark(node, expr); } break; default: - LOG(FATAL) << "Unexpected ISL node type " - << isl_ast_node_get_type(node.get()); + std::stringstream ss; + ss << "Unexpected ISL node type " << isl_ast_node_get_type(node.get()); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); break; } } @@ -566,7 +567,9 @@ void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr) { *expr = ir::Select::Make(ops[0], ops[1], ops[2]); break; default: - LOG(FATAL) << "unsupported op " << op_type; + std::stringstream ss; + ss << "unsupported op " << op_type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } break; default: diff --git a/paddle/cinn/poly/isl_utils.cc b/paddle/cinn/poly/isl_utils.cc index ed3a9b7f86e15..8262db4f14e29 100644 --- 
a/paddle/cinn/poly/isl_utils.cc +++ b/paddle/cinn/poly/isl_utils.cc @@ -422,14 +422,14 @@ isl::set isl_set_dim_name_if_null( return isl::manage(set); } -isl::map RemoveAxiesByInputNames(const isl::map &x, - const isl::set &origin_domain, - const std::vector &dim_in_names) { +isl::map RemoveAxesByInputNames(const isl::map &x, + const isl::set &origin_domain, + const std::vector &dim_in_names) { std::string map_str = isl_map_to_str(x.get()); isl::ctx this_ctx = x.ctx(); isl::map temp_transform(this_ctx, map_str); auto related_output_names = - GetRelatedOutputAxies(x, origin_domain, dim_in_names); + GetRelatedOutputAxes(x, origin_domain, dim_in_names); if (dim_in_names.empty()) return temp_transform; for (auto &i : dim_in_names) { temp_transform = isl::manage(isl_remove_axis_by_name( @@ -442,7 +442,7 @@ isl::map RemoveAxiesByInputNames(const isl::map &x, return temp_transform; } -isl::map RemoveAxiesByOutputNames( +isl::map RemoveAxesByOutputNames( const isl::map &x, const isl::set &origin_domain, const std::vector &dim_out_names) { @@ -450,7 +450,7 @@ isl::map RemoveAxiesByOutputNames( isl::ctx this_ctx = x.ctx(); isl::map temp_transform(this_ctx, map_str); auto related_input_names = - GetRelatedInputAxies(x, origin_domain, dim_out_names); + GetRelatedInputAxes(x, origin_domain, dim_out_names); if (dim_out_names.empty()) return temp_transform; for (auto &i : dim_out_names) { temp_transform = isl::manage(isl_remove_axis_by_name( @@ -463,24 +463,24 @@ isl::map RemoveAxiesByOutputNames( return temp_transform; } -std::vector GetRelatedOutputAxies( +std::vector GetRelatedOutputAxes( const isl::map &x, const isl::set &origin_domain, const std::vector &dim_in_names) { std::string map_str = isl_map_to_str(x.get()); - VLOG(1) << "GetRelatedOutputAxies map_str is : " << map_str; + VLOG(1) << "GetRelatedOutputAxes map_str is : " << map_str; isl::ctx this_ctx = x.ctx(); isl::map temp_transform(this_ctx, map_str); auto dim_out_names = isl_get_dim_names(temp_transform, isl_dim_out); std::set dim_in_set; for (auto &i : dim_in_names) { - VLOG(1) << "GetRelatedOutputAxies dim_in_names is : " << i; + VLOG(1) << "GetRelatedOutputAxes dim_in_names is : " << i; dim_in_set.insert(i); } std::set res_set; for (auto &i : dim_out_names) { auto related_in_dim = - GetRelatedInputAxies(temp_transform, origin_domain, {i}); + GetRelatedInputAxes(temp_transform, origin_domain, {i}); for (auto &j : related_in_dim) { if (dim_in_set.count(j) > 0) { res_set.insert(i); @@ -489,24 +489,24 @@ std::vector GetRelatedOutputAxies( } std::vector res; for (auto &i : res_set) { - VLOG(1) << "GetRelatedOutputAxies res is : " << i; + VLOG(1) << "GetRelatedOutputAxes res is : " << i; res.push_back(i); } return res; } -std::vector GetRelatedInputAxies( +std::vector GetRelatedInputAxes( const isl::map &x, const isl::set &origin_domain, const std::vector &dim_out_names, bool strict) { std::string map_str = isl_map_to_str(x.get()); - VLOG(1) << "GetRelatedInputAxies map_str is : " << map_str; + VLOG(1) << "GetRelatedInputAxes map_str is : " << map_str; isl::ctx this_ctx = x.ctx(); isl::map temp_transform(this_ctx, map_str); auto dim_in_names = isl_get_dim_names(temp_transform, isl_dim_in); for (auto &i : dim_out_names) { - VLOG(1) << "GetRelatedInputAxies dim_out_names is : " << i; + VLOG(1) << "GetRelatedInputAxes dim_out_names is : " << i; temp_transform = isl::manage(isl_remove_axis_by_name( temp_transform.release(), isl_dim_out, i.c_str())); } @@ -526,10 +526,10 @@ std::vector GetRelatedInputAxies( } for (auto &i : dim_in_names) { if 
(utils::Count(&map_str, i) != utils::Count(&deleted_map, i)) { - VLOG(1) << "GetRelatedInputAxies res is : " << i; + VLOG(1) << "GetRelatedInputAxes res is : " << i; res.push_back(i); } else if (out_set_without_suffix.count(i) > 0 && !strict) { - VLOG(1) << "GetRelatedInputAxies res is : " << i; + VLOG(1) << "GetRelatedInputAxes res is : " << i; res.push_back(i); } else if (out_set.count(i) > 0) { auto range1 = isl_set_get_axis_range_by_name(origin_domain.get(), i); diff --git a/paddle/cinn/poly/isl_utils.h b/paddle/cinn/poly/isl_utils.h index d9ae0ca65de82..6b74aadc73816 100644 --- a/paddle/cinn/poly/isl_utils.h +++ b/paddle/cinn/poly/isl_utils.h @@ -122,9 +122,9 @@ isl::set SetGetDims(isl::set set, const std::vector& dims); * @param dim_in_names The names of input dims to remove. * @return The edited map. */ -isl::map RemoveAxiesByInputNames(const isl::map& x, - const isl::set& origin_domain, - const std::vector& dim_in_names); +isl::map RemoveAxesByInputNames(const isl::map& x, + const isl::set& origin_domain, + const std::vector& dim_in_names); /** * Given an isl::map and a vector of names of dim_out, @@ -133,22 +133,21 @@ isl::map RemoveAxiesByInputNames(const isl::map& x, * @param dim_in_names The names of output dims to remove. * @return The edited map. */ -isl::map RemoveAxiesByOutputNames( - const isl::map& x, - const isl::set& origin_domain, - const std::vector& dim_out_names); +isl::map RemoveAxesByOutputNames(const isl::map& x, + const isl::set& origin_domain, + const std::vector& dim_out_names); /** * Given an isl::map and a vector of names of dim_out, * get the names of related input dims. * @param x The input map. * @param dim_out_names The names of output dims. - * @param strict Indicates whether computes the strictly related input axies. + * @param strict Indicates whether computes the strictly related input axes. * For example, if strict == true, then input 'j' is related to output * 'j_outer_inner_outer' * @return The vector of names of related input dims. */ -std::vector GetRelatedInputAxies( +std::vector GetRelatedInputAxes( const isl::map& x, const isl::set& origin_domain, const std::vector& dim_out_names, @@ -161,7 +160,7 @@ std::vector GetRelatedInputAxies( * @param dim_in_names The names of input dims. * @return The vector of names of related output dims. */ -std::vector GetRelatedOutputAxies( +std::vector GetRelatedOutputAxes( const isl::map& x, const isl::set& origin_domain, const std::vector& dim_in_names); diff --git a/paddle/cinn/poly/poly_scheduler.cc b/paddle/cinn/poly/poly_scheduler.cc index 539be8221d8df..7cfc7851a145a 100644 --- a/paddle/cinn/poly/poly_scheduler.cc +++ b/paddle/cinn/poly/poly_scheduler.cc @@ -266,8 +266,9 @@ std::vector NaivePartitionGraph(cinn::common::Graph* graph) { auto* node0 = node; if (name2node.count(compute_at.stage->id()) == 0) { continue; - LOG(FATAL) << "Didn't find node with name " << compute_at.stage->id() - << " !"; + std::stringstream ss; + ss << "Didn't find node with name " << compute_at.stage->id() << " !"; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } auto* node1 = name2node[compute_at.stage->id()]; VLOG(3) << "a -> b: " << node0->id() << " -> " << node1->id(); diff --git a/paddle/cinn/poly/stage.cc b/paddle/cinn/poly/stage.cc index aca5e548f09fb..60ae01782770d 100644 --- a/paddle/cinn/poly/stage.cc +++ b/paddle/cinn/poly/stage.cc @@ -441,7 +441,7 @@ void Stage::EditTempTensor(Stage *other, int level) { } } // Iterators of loop within level will be erased. 
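The ast_gen.cc and poly_scheduler.cc hunks above replace streamed LOG(FATAL) messages with a std::stringstream that is then handed to PADDLE_THROW, since PADDLE_THROW takes a formatted error object rather than a stream. A hedged sketch of that idiom, assuming the Paddle enforce header; ReportMissingNode and its argument are placeholders.

#include <sstream>
#include <string>
#include "paddle/common/enforce.h"

// Build the message with a stream first, then throw it as a typed error.
void ReportMissingNode(const std::string& node_name) {
  std::stringstream ss;
  ss << "Didn't find node with name " << node_name << " !";
  PADDLE_THROW(phi::errors::InvalidArgument(ss.str()));
}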
- auto related_dim_in = GetRelatedInputAxies( + auto related_dim_in = GetRelatedInputAxes( this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { erase_var.insert(j); @@ -460,27 +460,27 @@ void Stage::EditTempTensor(Stage *other, int level) { if (bind_info[new_i].for_type == ir::ForType::GPUBlock && (this->scope() == ScopeKind::kShared || this->scope() == ScopeKind::kLocal)) { - auto related_dim_in = GetRelatedInputAxies( + auto related_dim_in = GetRelatedInputAxes( this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { erase_var.insert(j); } } else if (bind_info[new_i].for_type == ir::ForType::GPUThread && (this->scope() == ScopeKind::kLocal)) { - auto related_dim_in = GetRelatedInputAxies( + auto related_dim_in = GetRelatedInputAxes( this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { erase_var.insert(j); } } else { - auto related_dim_in = GetRelatedInputAxies( + auto related_dim_in = GetRelatedInputAxes( this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { undo_erase_var.insert(j); } } } else { - auto related_dim_in = GetRelatedInputAxies( + auto related_dim_in = GetRelatedInputAxes( this->transform(), this->domain(), {transform_domain_names[i]}); for (auto &j : related_dim_in) { undo_erase_var.insert(j); @@ -608,9 +608,9 @@ void Stage::ComputeAt(Stage *other, int level) { level_out_dims.push_back(target_map_dims[i]); related_output_dims_set.insert(target_map_dims[i]); } - auto related_input_dims = GetRelatedInputAxies( + auto related_input_dims = GetRelatedInputAxes( new_target_transform, other->domain(), level_out_dims); - auto related_output_dims = GetRelatedOutputAxies( + auto related_output_dims = GetRelatedOutputAxes( new_target_transform, other->domain(), related_input_dims); for (auto &i : related_output_dims) { related_output_dims_set.insert(i); @@ -708,7 +708,7 @@ void Stage::ComputeAt(Stage *other, int level) { int max_iv = maxv.get_num_si(); int min_iv = minv.get_num_si(); auto related_input_dims = - GetRelatedInputAxies(trans_res, domain_, {trans_dim_out[i]}, true); + GetRelatedInputAxes(trans_res, domain_, {trans_dim_out[i]}, true); if (max_iv != min_iv && related_input_dims.empty()) { trans_res = isl::manage(isl_remove_axis_by_name( trans_res.release(), isl_dim_out, trans_dim_out[i].c_str())); @@ -1627,7 +1627,7 @@ void Stage::AddForloopInfo(int level, const StageForloopInfo &info) { } void Stage::CopyTransform(Stage *other, int level) { - auto target_transform = RemoveAxiesByInputNames( + auto target_transform = RemoveAxesByInputNames( other->transform(), other->domain(), other->origin_reduce_axis_names()); isl::set target_origin_domain(other->domain().ctx(), isl_set_to_str(other->domain().get())); @@ -1654,9 +1654,9 @@ void Stage::CopyTransform(Stage *other, int level) { dim_out_level.push_back( isl_map_get_dim_name(temp_target_trans.get(), isl_dim_out, i)); } - auto related_dim_in = GetRelatedInputAxies( + auto related_dim_in = GetRelatedInputAxes( temp_target_trans, target_origin_domain, dim_out_level); - auto related_dim_out = GetRelatedOutputAxies( + auto related_dim_out = GetRelatedOutputAxes( temp_target_trans, target_origin_domain, related_dim_in); for (auto &i : related_dim_out) { if (i == pivot_dim_out) { diff --git a/paddle/cinn/pybind/framework.cc b/paddle/cinn/pybind/framework.cc index fde1f7dd8eba0..5122a61d9fc7b 100644 --- a/paddle/cinn/pybind/framework.cc +++ 
b/paddle/cinn/pybind/framework.cc @@ -131,7 +131,8 @@ void BindFramework(pybind11::module *m) { t->shape().numel() * t->type().bytes(), cudaMemcpyDeviceToHost)); #else - LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal("To use CUDA backends, " + "you need to set WITH_CUDA ON!")); #endif } else { CINN_NOT_IMPLEMENTED @@ -175,7 +176,8 @@ void BindFramework(pybind11::module *m) { self->shape().numel() * self->type().bytes(), cudaMemcpyDeviceToHost)); #else - LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal("To use CUDA backends, " + "you need to set WITH_CUDA ON!")); #endif } else { CINN_NOT_IMPLEMENTED @@ -210,7 +212,8 @@ void BindFramework(pybind11::module *m) { self->shape().numel() * self->type().bytes(), cudaMemcpyHostToDevice)); #else - LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal("To use CUDA backends, " + "you need to set WITH_CUDA ON!")); #endif } else { CINN_NOT_IMPLEMENTED diff --git a/paddle/cinn/pybind/frontend.cc b/paddle/cinn/pybind/frontend.cc index 05e814ce107f8..f7eaf01a59f07 100644 --- a/paddle/cinn/pybind/frontend.cc +++ b/paddle/cinn/pybind/frontend.cc @@ -229,7 +229,8 @@ void BindFrontend(pybind11::module *m) { in_tensor->shape().numel() * dtype.bytes(), cudaMemcpyHostToDevice)); #else - LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal("To use CUDA backends, " + "you need to set WITH_CUDA ON!")); #endif } else if (target.arch == Target::Arch::X86) { memcpy(data, @@ -323,7 +324,8 @@ void BindFrontend(pybind11::module *m) { in_tensor->shape().numel() * sizeof(float), cudaMemcpyHostToDevice)); #else - LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal("To use CUDA backends, " + "you need to set WITH_CUDA ON!")); #endif } else if (target.arch == Target::Arch::X86) { for (size_t j = 0; j < in_tensor->shape().numel(); j++) { @@ -373,7 +375,8 @@ void BindFrontend(pybind11::module *m) { in_tensor->shape().numel() * sizeof(float), cudaMemcpyHostToDevice)); #else - LOG(FATAL) <<"To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal("To use CUDA backends, " + "you need to set WITH_CUDA ON!")); #endif } else if (target.arch == Target::Arch::X86) { for (size_t j = 0; j < in_tensor->shape().numel(); j++) { diff --git a/paddle/cinn/pybind/ir/ir.cc b/paddle/cinn/pybind/ir/ir.cc index 6118f7c8a5e69..d9f9bd5fcdf7f 100644 --- a/paddle/cinn/pybind/ir/ir.cc +++ b/paddle/cinn/pybind/ir/ir.cc @@ -47,8 +47,8 @@ std::vector AxisMap(const std::string& kinds, } else if (c == 'R') { iter_var->is_reduce_axis = true; } else { - LOG(FATAL) - << "kind of axis setting error, must be R(Reduce) or S(Spatial)"; + PADDLE_THROW(phi::errors::InvalidArgument( + "kind of axis setting error, must be R(Reduce) or S(Spatial)")); } rets.push_back(SetScheduleBlockIterVar(iter_var, iter_expression[i])); } diff --git a/paddle/cinn/pybind/ir/ir_api.cc b/paddle/cinn/pybind/ir/ir_api.cc index 56dff498dd710..224bf87e09bfa 100644 --- a/paddle/cinn/pybind/ir/ir_api.cc +++ b/paddle/cinn/pybind/ir/ir_api.cc @@ -383,6 +383,7 @@ void BindIrIr(py::module *m) { ir::Expr, const std::string &, bool, + bool, bool>(&ir::_Var_::Make)) .def("copy", &ir::_Var_::Copy); @@ -747,8 +748,9 @@ auto PackedFuncCall(lang::PackedFunc &self, py::args args) { // NOLINT } else if (py::isinstance(handle)) { cinn_args.Append(CINNValue(py::cast(handle))); } 
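The pybind hunks above keep the #ifdef CINN_WITH_CUDA structure and only swap the #else branch from LOG(FATAL) to PADDLE_THROW(phi::errors::Fatal(...)). A minimal sketch of that guard, assuming the Paddle enforce header and the CINN_WITH_CUDA define used in this tree; CopyToHost and its arguments are illustrative only.

#include "paddle/common/enforce.h"
#ifdef CINN_WITH_CUDA
#include <cuda_runtime.h>
#endif

void CopyToHost(void* dst, const void* src, size_t bytes) {
#ifdef CINN_WITH_CUDA
  cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost);
#else
  PADDLE_THROW(phi::errors::Fatal(
      "To use CUDA backends, you need to set WITH_CUDA ON!"));
#endif
}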
else { - LOG(FATAL) << "unsupported type: " - << std::string(py::str(handle.get_type())); + std::stringstream ss; + ss << "unsupported type: " << std::string(py::str(handle.get_type())); + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } } lang::RetValue ret_value; diff --git a/paddle/cinn/pybind/ir/ir_context.cc b/paddle/cinn/pybind/ir/ir_context.cc index 8b4d0a4cf1e1d..14dad90d841b5 100644 --- a/paddle/cinn/pybind/ir/ir_context.cc +++ b/paddle/cinn/pybind/ir/ir_context.cc @@ -59,10 +59,12 @@ void LowerFuncContextNode::ExitWithContext() { void IfContextNode::ExitWithContext() { IRContextNode::ExitWithContext(); if (!exprs.empty()) { - LOG(FATAL) << "Expr not be either in ThenBlock or ElseBlock in if"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Expr not be either in ThenBlock or ElseBlock in if")); } if (!true_case.defined()) { - LOG(FATAL) << "Expr not be defined in ThenBlock"; + PADDLE_THROW( + phi::errors::InvalidArgument("Expr not be defined in ThenBlock")); } LinkToParentContext(ir::IfThenElse::Make(condition, true_case, false_case)); } diff --git a/paddle/cinn/pybind/ir/ir_context.h b/paddle/cinn/pybind/ir/ir_context.h index 8cdf0ed85c081..837d66e8c0760 100644 --- a/paddle/cinn/pybind/ir/ir_context.h +++ b/paddle/cinn/pybind/ir/ir_context.h @@ -21,7 +21,7 @@ #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/lowered_func.h" -#include "paddle/cinn/utils/error.h" +#include "paddle/common/enforce.h" namespace cinn { namespace pybind { @@ -73,7 +73,7 @@ class IRContext { err_msg << "TypeConvertError: convert " << data_.get()->type_info() << " to " << TIRContextNode::__type_info__; - CINN_THROW(err_msg.str()); + PADDLE_THROW(phi::errors::InvalidArgument(err_msg.str())); } return ctx_node; } @@ -82,8 +82,10 @@ class IRContext { CHECK(data_.get()) << "IrContext holds null"; auto* ctx_node = data_.get()->safe_as(); if (!ctx_node) { - LOG(FATAL) << "TypeConvertError: convert " << data_.get()->type_info() - << " to " << TIRContextNode::__type_info__; + std::stringstream ss; + ss << "TypeConvertError: convert " << data_.get()->type_info() << " to " + << TIRContextNode::__type_info__; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return ctx_node; } @@ -235,8 +237,10 @@ void LinkToParentContext(ir::Expr); template IRContext IRBuilderNode::GetLastContext() const { if (!(contexts.back().As())) { - LOG(FATAL) << "TypeError: The last context is not " - << TIRContextNode::__type_info__; + std::stringstream ss; + ss << "TypeError: The last context is not " + << TIRContextNode::__type_info__; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return contexts.back(); } diff --git a/paddle/cinn/pybind/optim.cc b/paddle/cinn/pybind/optim.cc index bb1a18a2c24fe..4f40ea660149c 100755 --- a/paddle/cinn/pybind/optim.cc +++ b/paddle/cinn/pybind/optim.cc @@ -42,7 +42,10 @@ void BindSimplify(py::module* m) { }, py::arg("expr")); - m->def("ir_copy", py::overload_cast(&ir::ir_utils::IRCopy)); + m->def("ir_copy", + py::overload_cast(&ir::ir_utils::IRCopy), + py::arg("x"), + py::arg("copy_buffer_node") = true); } } // namespace diff --git a/paddle/cinn/pybind/runtime.cc b/paddle/cinn/pybind/runtime.cc index 91db8af397ec2..0ef1ee542aa35 100644 --- a/paddle/cinn/pybind/runtime.cc +++ b/paddle/cinn/pybind/runtime.cc @@ -92,7 +92,8 @@ cinn_buffer_t *CreateBufferFromNumpy( buffer->memory, data.data(), data.nbytes(), cudaMemcpyHostToDevice)); return buffer; #else - LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; + 
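The optim.cc hunk above turns the bare ir_copy binding into one with named arguments and a default (copy_buffer_node = true), so existing Python callers keep their behavior while new callers can opt out of copying buffer nodes. A self-contained pybind11 sketch of the same binding pattern; the module, function, and parameter names here are made up, and py::overload_cast is only needed when the bound C++ function is actually overloaded, as ir_copy is.

#include <pybind11/pybind11.h>
namespace py = pybind11;

int Scale(int x, bool doubled) { return doubled ? 2 * x : x; }

PYBIND11_MODULE(example, m) {
  // Python: example.scale(3) == 6, example.scale(3, doubled=False) == 3
  m.def("scale", &Scale, py::arg("x"), py::arg("doubled") = true);
}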
PADDLE_THROW(phi::errors::Fatal( + "To use CUDA backends, you need to set WITH_CUDA ON!")); #endif } else { CINN_NOT_IMPLEMENTED @@ -108,7 +109,8 @@ void BufferCopyTo(const cinn_buffer_t &buffer, py::array array) { CUDA_CALL(cudaMemcpy( array_data, buffer.memory, array.nbytes(), cudaMemcpyDeviceToHost)); #else - LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!"; + PADDLE_THROW(phi::errors::Fatal( + "To use CUDA backends, you need to set WITH_CUDA ON!")); #endif } else { @@ -135,7 +137,7 @@ py::array BufferHostMemoryToNumpy(cinn_buffer_t &buffer) { // NOLINT } else if (buffer.type == cinn_bool_t()) { dt = py::dtype::of(); } else { - LOG(FATAL) << "Not supported type found"; + PADDLE_THROW(phi::errors::InvalidArgument("Not supported type found")); } py::array::ShapeContainer shape(buffer.dims, buffer.dims + buffer.dimensions); diff --git a/paddle/cinn/runtime/buffer.cc b/paddle/cinn/runtime/buffer.cc old mode 100755 new mode 100644 index 6f9e6d51ecaa8..9ab9d591c0a51 --- a/paddle/cinn/runtime/buffer.cc +++ b/paddle/cinn/runtime/buffer.cc @@ -25,21 +25,30 @@ Shape::Shape(const Shape &other) } void Shape::Resize(int ndim) { - CHECK_GT(ndim, 0); + PADDLE_ENFORCE_GT(ndim, + 0, + phi::errors::InvalidArgument( + "Target dimension to resize must be greater than 0.")); ndims_ = ndim; if (data_) delete data_; data_ = new value_type[ndim]; } Shape::value_type &Shape::operator[](int i) { - CHECK_GT(ndims_, 0) << "shape is empty"; - CHECK_LT(i, ndims_) << "index " << i << "out of range " << ndims_; + PADDLE_ENFORCE_GT(ndims_, 0, phi::errors::InvalidArgument("Shape is empty.")); + PADDLE_ENFORCE_LT( + i, + ndims_, + phi::errors::OutOfRange("Index %d out of range %d.", i, ndims_)); return data_[i]; } Shape::value_type Shape::operator[](int i) const { - CHECK_GT(ndims_, 0) << "shape is empty"; - CHECK_LT(i, ndims_) << "index " << i << "out of range " << ndims_; + PADDLE_ENFORCE_GT(ndims_, 0, phi::errors::InvalidArgument("Shape is empty.")); + PADDLE_ENFORCE_LT( + i, + ndims_, + phi::errors::OutOfRange("Index %d out of range %d.", i, ndims_)); return data_[i]; } diff --git a/paddle/cinn/runtime/buffer.h b/paddle/cinn/runtime/buffer.h old mode 100755 new mode 100644 index b211389c6dcce..f384d136fdafc --- a/paddle/cinn/runtime/buffer.h +++ b/paddle/cinn/runtime/buffer.h @@ -16,6 +16,7 @@ #include #include +#include "paddle/common/enforce.h" /** * runtime::Buffer is an encapsulation of memory operations. */ @@ -68,9 +69,13 @@ class Buffer { //! Allocate the memory in host device. void AllocHost() { - CHECK(shape_.defined()); + PADDLE_ENFORCE_EQ( + shape_.defined(), + true, + phi::errors::InvalidArgument("shape haven't been defined.")); data_ = new T[shape_.num_elements()]; - CHECK(data_) << "alloc buffer failed"; + PADDLE_ENFORCE_NOT_NULL(data_, + phi::errors::NotFound("alloc buffer failed.")); } //! Deallocate the memory in host device. 
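The runtime::Shape hunks above upgrade bare CHECK_GT/CHECK_LT into PADDLE_ENFORCE calls whose phi::errors category encodes the failure kind: InvalidArgument for an empty shape, OutOfRange for a bad index. A hedged sketch of the same bounds-checked accessor, assuming the Paddle enforce header; CheckedGet and its vector-backed storage are illustrative stand-ins for Shape::operator[].

#include <vector>
#include "paddle/common/enforce.h"

int CheckedGet(const std::vector<int>& dims, int i) {
  int ndims = static_cast<int>(dims.size());
  PADDLE_ENFORCE_GT(
      ndims, 0, phi::errors::InvalidArgument("Shape is empty."));
  PADDLE_ENFORCE_LT(
      i, ndims, phi::errors::OutOfRange("Index %d out of range %d.", i, ndims));
  return dims[i];
}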
void DeallocHost() { @@ -79,15 +84,27 @@ class Buffer { } T& operator()(int i0) { - CHECK_EQ(shape_.ndims(), 1); + PADDLE_ENFORCE_EQ(shape_.ndims(), + 1, + phi::errors::InvalidArgument( + "Expected shape has 1 dimension, but recevied %d.", + shape_.ndims())); return static_cast(data_)[i0]; } T& operator()(int i0, int i1) { - CHECK_EQ(shape_.ndims(), 2); + PADDLE_ENFORCE_EQ(shape_.ndims(), + 2, + phi::errors::InvalidArgument( + "Expected shape has 2 dimensions, but recevied %d.", + shape_.ndims())); return static_cast(data_)[i0 * shape_[0] + i1]; } T& operator()(int i0, int i1, int i2) { - CHECK_EQ(shape_.ndims(), 3); + PADDLE_ENFORCE_EQ(shape_.ndims(), + 3, + phi::errors::InvalidArgument( + "Expected shape has 3 dimensions, but recevied %d.", + shape_.ndims())); return static_cast( data_)[i0 * shape_[1] * shape_[2] + i1 * shape_[2] + i2]; } diff --git a/paddle/cinn/runtime/cpu/cblas.cc b/paddle/cinn/runtime/cpu/cblas.cc index 9e08c128cb66b..5c4887ab20973 100644 --- a/paddle/cinn/runtime/cpu/cblas.cc +++ b/paddle/cinn/runtime/cpu/cblas.cc @@ -18,6 +18,7 @@ #include "paddle/cinn/backends/extern_func_jit_register.h" #include "paddle/cinn/common/cas.h" +#include "paddle/common/enforce.h" namespace { @@ -117,8 +118,11 @@ void cinn_call_cholesky_host( memcpy(out->memory, x->memory, x->memory_size); uint8_t bits = x->type.bits; - CHECK(bits == 32 || bits == 64) - << "Unsupported bits = " << bits << " float data type for cholesky"; + PADDLE_ENFORCE_EQ( + bits == 32 || bits == 64, + true, + phi::errors::InvalidArgument( + "Unsupported bits = %d float data type for cholesky.", bits)); char uplo = upper ? 'U' : 'L'; for (int i = 0; i < batch_size; i++) { if (bits == 32) { @@ -141,8 +145,12 @@ CINN_REGISTER_HELPER(cinn_cpu_mkl) { FunctionProto::shape_inference_t inference_shape_gemm = [](const std::vector& args, int offset) { - CHECK_EQ(offset, 0UL) << "Only one output"; - CHECK_EQ(args.size(), 12UL) << "Wrong number of arguments passed in"; + PADDLE_ENFORCE_EQ( + offset, 0UL, phi::errors::InvalidArgument("Only one output.")); + PADDLE_ENFORCE_EQ(args.size(), + 12UL, + phi::errors::InvalidArgument( + "Wrong number of arguments passed in.")); auto M = cinn::common::AutoSimplify(args[1]); auto N = cinn::common::AutoSimplify(args[2]); std::vector shape; @@ -153,11 +161,16 @@ CINN_REGISTER_HELPER(cinn_cpu_mkl) { FunctionProto::shape_inference_t inference_shape_gemm_batch = [](const std::vector& args, int offset) { - CHECK_EQ(offset, 0UL) << "Only one output"; - CHECK_EQ(args.size(), 16UL) << "Wrong number of arguments passed in"; + PADDLE_ENFORCE_EQ( + offset, 0UL, phi::errors::InvalidArgument("Only one output.")); + PADDLE_ENFORCE_EQ(args.size(), + 16UL, + phi::errors::InvalidArgument( + "Wrong number of arguments passed in.")); auto& A = args[14]; auto A_tensor = A.as_tensor(); - CHECK(A_tensor); + PADDLE_ENFORCE_NOT_NULL( + A_tensor, phi::errors::InvalidArgument("expected type is tensor.")); auto batch_size = cinn::common::AutoSimplify(args[1]); int32_t batch_size_val = batch_size.as_int32(); @@ -169,7 +182,10 @@ CINN_REGISTER_HELPER(cinn_cpu_mkl) { int total = 1; for (auto& v : A_tensor->shape) { auto val = cinn::common::AutoSimplify(v); - CHECK(val.is_constant()); + PADDLE_ENFORCE_EQ( + val.is_constant(), + true, + phi::errors::InvalidArgument("expected type is constant.")); shape.push_back(val); total *= val.as_int32(); if (total >= batch_size_val) break; diff --git a/paddle/cinn/runtime/cpu/mkl_math.cc b/paddle/cinn/runtime/cpu/mkl_math.cc index f481ef072129d..0b2dc7aadd1b3 100644 --- 
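Because these hunks have no one-argument PADDLE_ENFORCE equivalent of a plain CHECK(cond), compound conditions such as the cholesky bit-width test are rewritten as PADDLE_ENFORCE_EQ(cond, true, ...). A short sketch of that idiom, assuming the Paddle enforce header; ValidateFloatBits is an invented wrapper around the same check.

#include <cstdint>
#include "paddle/common/enforce.h"

void ValidateFloatBits(uint8_t bits) {
  PADDLE_ENFORCE_EQ(
      bits == 32 || bits == 64,
      true,
      phi::errors::InvalidArgument(
          "Unsupported bits = %d float data type for cholesky.", bits));
}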
a/paddle/cinn/runtime/cpu/mkl_math.cc +++ b/paddle/cinn/runtime/cpu/mkl_math.cc @@ -23,19 +23,32 @@ #include "paddle/cinn/backends/extern_func_jit_register.h" #include "paddle/cinn/backends/function_prototype.h" #include "paddle/cinn/runtime/cpu/host_intrinsics.h" +#include "paddle/common/enforce.h" -#define CINN_MKL_VECTOR_MATH_FP(fn__, name__) \ - void cinn_mkl_##name__##_v_fp32(cinn_buffer_t *x, cinn_buffer_t *out) { \ - CHECK_EQ(x->num_elements(), out->num_elements()); \ - vs##fn__(x->num_elements(), \ - reinterpret_cast(x->memory), \ - reinterpret_cast(out->memory)); \ - } \ - void cinn_mkl_##name__##_v_fp64(cinn_buffer_t *x, cinn_buffer_t *out) { \ - CHECK_EQ(x->num_elements(), out->num_elements()); \ - vd##fn__(x->num_elements(), \ - reinterpret_cast(x->memory), \ - reinterpret_cast(out->memory)); \ +#define CINN_MKL_VECTOR_MATH_FP(fn__, name__) \ + void cinn_mkl_##name__##_v_fp32(cinn_buffer_t *x, cinn_buffer_t *out) { \ + PADDLE_ENFORCE_EQ( \ + x->num_elements(), \ + out->num_elements(), \ + phi::errors::InvalidArgument("X's number of elements (%d) should " \ + "be equal to output's (%d).", \ + x->num_elements(), \ + out->num_elements())); \ + vs##fn__(x->num_elements(), \ + reinterpret_cast(x->memory), \ + reinterpret_cast(out->memory)); \ + } \ + void cinn_mkl_##name__##_v_fp64(cinn_buffer_t *x, cinn_buffer_t *out) { \ + PADDLE_ENFORCE_EQ( \ + x->num_elements(), \ + out->num_elements(), \ + phi::errors::InvalidArgument("X's number of elements (%d) should " \ + "be equal to output's (%d).", \ + x->num_elements(), \ + out->num_elements())); \ + vd##fn__(x->num_elements(), \ + reinterpret_cast(x->memory), \ + reinterpret_cast(out->memory)); \ } CINN_MKL_VECTOR_MATH_FP(Exp, exp); diff --git a/paddle/cinn/runtime/cpu/mkl_math_test.cc b/paddle/cinn/runtime/cpu/mkl_math_test.cc index d064535d940c1..50798ebb39029 100644 --- a/paddle/cinn/runtime/cpu/mkl_math_test.cc +++ b/paddle/cinn/runtime/cpu/mkl_math_test.cc @@ -24,6 +24,7 @@ #include "paddle/cinn/common/test_helper.h" #include "paddle/cinn/runtime/cpu/host_intrinsics.h" #include "paddle/cinn/runtime/cpu/use_extern_funcs.h" +#include "paddle/common/enforce.h" namespace cinn { namespace runtime { @@ -89,11 +90,18 @@ void TestCallElementwise(const std::string &fn_name, jit->Link(module); auto fn = jit->Lookup("fn"); - CHECK(fn); + PADDLE_ENFORCE_NOT_NULL(fn, phi::errors::NotFound("fn is not found.")); auto fn_ = reinterpret_cast(fn); cinn_buffer_t *A_buf; if (set_value != 0) { + PADDLE_ENFORCE_EQ( + x->num_elements(), + out->num_elements(), + phi::errors::InvalidArgument("X's number of elements (%d) should " + "be equal to output's (%d).", + x->num_elements(), + out->num_elements())); A_buf = CreateBuffer({10, 10}, false, set_value); } else { A_buf = CreateBuffer({10, 10}); diff --git a/paddle/cinn/runtime/cpu/mkldnn_math.cc b/paddle/cinn/runtime/cpu/mkldnn_math.cc index b45ddedd2e890..f20e56e32f1e6 100644 --- a/paddle/cinn/runtime/cpu/mkldnn_math.cc +++ b/paddle/cinn/runtime/cpu/mkldnn_math.cc @@ -18,6 +18,7 @@ #include "paddle/cinn/backends/extern_func_jit_register.h" #include "paddle/cinn/common/cas.h" +#include "paddle/common/enforce.h" using dnnl::algorithm; using dnnl::memory; @@ -50,7 +51,9 @@ void cinn_cpu_mkldnn_softmax_fp32(int batch, format_tag = tag::abcd; break; default: - LOG(FATAL) << "wrong dim: " << size; + std::stringstream ss; + ss << "wrong dim: " << size; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); break; } @@ -161,7 +164,10 @@ CINN_REGISTER_HELPER(cinn_cpu_mkldnn) { FunctionProto::shape_inference_t 
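The CINN_MKL_VECTOR_MATH_FP rewrite above shows how the element-count check is kept inside a function-generating macro: every statement, including the multi-line enforce, must stay on one logical line via backslash continuations. A toy sketch of the same structure using a plain assert so it stays independent of the MKL and Paddle headers; the generated unary_* wrappers are illustrative only.

#include <cassert>
#include <cmath>
#include <cstddef>

// Generates unary_<name>(dst, src, dst_n, src_n) wrappers that check sizes first.
#define DEFINE_CHECKED_UNARY(fn__, name__)                            \
  inline void unary_##name__(double* dst, const double* src,          \
                             std::size_t dst_n, std::size_t src_n) {  \
    assert(dst_n == src_n && "element counts must match");            \
    for (std::size_t i = 0; i < src_n; ++i) dst[i] = fn__(src[i]);    \
  }

DEFINE_CHECKED_UNARY(std::exp, exp)
DEFINE_CHECKED_UNARY(std::tanh, tanh)
#undef DEFINE_CHECKED_UNARY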
inference_shape_conv2d_nchw = [](const std::vector& args, int offset) { - CHECK_EQ(args.size(), 16UL) << "Wrong number of arguments passed in"; + PADDLE_ENFORCE_EQ(args.size(), + 16UL, + phi::errors::InvalidArgument( + "Wrong number of arguments passed in.")); auto N = cinn::common::AutoSimplify(args[0]); int input_h = cinn::common::AutoSimplify(args[2]).as_int32(); int input_w = cinn::common::AutoSimplify(args[3]).as_int32(); diff --git a/paddle/cinn/runtime/cpu/thread_backend.cc b/paddle/cinn/runtime/cpu/thread_backend.cc index 43804e33b1e60..2bc67bd95e723 100644 --- a/paddle/cinn/runtime/cpu/thread_backend.cc +++ b/paddle/cinn/runtime/cpu/thread_backend.cc @@ -25,6 +25,7 @@ #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/runtime/intrinsic.h" +#include "paddle/common/enforce.h" int max_concurrency() { int max_concurrency = 1; @@ -56,7 +57,8 @@ int cinn_backend_parallel_launch(FCINNParallelLambda flambda, (*flambda)(thread_num, num_task, datas); } #else - LOG(FATAL) << "CINN host parallel launch need OpenMP! Please check."; + PADDLE_THROW(phi::errors::Fatal( + "CINN host parallel launch need OpenMP! Please check.")); #endif // CINN_USE_OPENMP return 0; } diff --git a/paddle/cinn/runtime/cuda/cublas_util.h b/paddle/cinn/runtime/cuda/cublas_util.h index bdd21dafed544..904678f2ce2e3 100644 --- a/paddle/cinn/runtime/cuda/cublas_util.h +++ b/paddle/cinn/runtime/cuda/cublas_util.h @@ -130,10 +130,12 @@ inline cublasStatus_t cublasGemm(cudaDataType_t dtype, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); #else - LOG(FATAL) << "cublasGemmEx with bfloat16 is not supported on cuda <= 11"; + PADDLE_THROW(phi::errors::Fatal( + "cublasGemmEx with bfloat16 is not supported on cuda <= 11")); #endif } - LOG(FATAL) << "Unsupported cublasGemm precision."; + PADDLE_THROW( + phi::errors::InvalidArgument("Unsupported cublasGemm precision.")); } inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype, @@ -269,11 +271,13 @@ inline cublasStatus_t cublasGemmStridedBatched(cudaDataType_t dtype, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); #else - LOG(FATAL) << "cublasGemmStridedBatched with bfloat16 is not supported on " - "cuda <= 11"; + PADDLE_THROW(phi::errors::InvalidArgument( + "cublasGemmStridedBatched with bfloat16 is not supported on " + "cuda <= 11")); #endif } - LOG(FATAL) << "Unsupported cublasGemmStridedBatched precision."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Unsupported cublasGemmStridedBatched precision.")); } inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype, @@ -390,11 +394,12 @@ inline cublasStatus_t cublasGemmBatched(cudaDataType_t dtype, CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); #else - LOG(FATAL) - << "cublasGemmBatched with bfloat16 is not supported on cuda <= 11"; + PADDLE_THROW(phi::errors::Fatal( + "cublasGemmBatched with bfloat16 is not supported on cuda <= 11")); #endif } - LOG(FATAL) << "Unsupported cublasGemmBatched precision."; + PADDLE_THROW( + phi::errors::InvalidArgument("Unsupported cublasGemmBatched precision.")); } } // namespace cuda diff --git a/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc b/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc index 15fcb4030e89b..685c466f7f9c9 100644 --- a/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc +++ b/paddle/cinn/runtime/cuda/cuda_intrinsics_reduce.cc @@ -146,22 +146,22 @@ CINN_REGISTER_HELPER(cuda_intrinsics_reduce) { #undef REGISTER_BLOCK_REDUCE_FUNC_IMPL -#define 
REGISTER_BLOCK_SHUFLLE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \ +#define REGISTER_BLOCK_SHUFFLE_FUNC_IMPL(REDUCE_TYPE, DTYPE) \ REGISTER_FACKED_EXTERN_FUNC_HELPER(block_shuffle_##REDUCE_TYPE, target) \ .SetRetType() \ .AddInputType() \ .AddInputType() \ .End(); - EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_SHUFLLE_FUNC_IMPL) - EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_SHUFLLE_FUNC_IMPL) - EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_SHUFLLE_FUNC_IMPL) - EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_SHUFLLE_FUNC_IMPL) - EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_SHUFLLE_FUNC_IMPL) - EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_SHUFLLE_FUNC_IMPL) - EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_SHUFLLE_FUNC_IMPL) + EXPAND_REDUCE_INT32_REGISTER_MARCO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_INT64_REGISTER_MARCO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_BF16_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_FP16_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_FP32_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_FP64_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) + EXPAND_REDUCE_BOOL_REGISTER_MACRO(REGISTER_BLOCK_SHUFFLE_FUNC_IMPL) -#undef REGISTER_BLOCK_SHUFLLE_FUNC_IMPL +#undef REGISTER_BLOCK_SHUFFLE_FUNC_IMPL #undef EXPAND_REDUCE_INT32_REGISTER_MARCO #undef EXPAND_REDUCE_INT64_REGISTER_MARCO diff --git a/paddle/cinn/runtime/cuda/cuda_module.cc b/paddle/cinn/runtime/cuda/cuda_module.cc index 430516d9168d3..2cc1701d774fa 100644 --- a/paddle/cinn/runtime/cuda/cuda_module.cc +++ b/paddle/cinn/runtime/cuda/cuda_module.cc @@ -27,6 +27,7 @@ #include "paddle/cinn/runtime/cuda/cuda_util.h" #include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/utils/profiler.h" +#include "paddle/common/enforce.h" namespace cinn { namespace runtime { @@ -34,10 +35,12 @@ namespace cuda { CUDAModule::CUDAModule(const std::string& data, Kind kind) : data_(data), kind_(kind) { - CHECK(!data.empty()); + PADDLE_ENFORCE_NE( + data.empty(), true, phi::errors::PreconditionNotMet("data is is empty!")); cudaGetDeviceCount(&num_devices_); - CHECK_GT(num_devices_, 0) << "No available devices"; + PADDLE_ENFORCE_GT( + num_devices_, 0, phi::errors::ResourceExhausted("No available devices!")); // TODO(Superjomn) Determine whether to initialize all the devices. 
int current_device_id; @@ -61,7 +64,10 @@ void CUDAModule::LaunchKernel(int device_id, << ", blockDim.y:" << blockDim.y << ", blockDim.z:" << blockDim.z << ", share_memory_size:" << share_memory_size; auto function = GetFunction(device_id, func_name); - CHECK(function); + PADDLE_ENFORCE_NOT_NULL( + function, + phi::errors::NotFound( + "%s function not found on device %d.", func_name, device_id)); cinn::utils::RecordEvent record_run("cuLaunchKernel", cinn::utils::EventType::kInstruction); CUDA_DRIVER_CALL(cuLaunchKernel(function, diff --git a/paddle/cinn/runtime/cuda/cuda_module_test.cc b/paddle/cinn/runtime/cuda/cuda_module_test.cc index fe41a1ed0ca2e..9a0ac3c8b29f3 100644 --- a/paddle/cinn/runtime/cuda/cuda_module_test.cc +++ b/paddle/cinn/runtime/cuda/cuda_module_test.cc @@ -23,6 +23,7 @@ #include "paddle/cinn/runtime/cuda/cuda_util.h" #include "paddle/cinn/runtime/cuda/test_util.h" #include "paddle/cinn/runtime/cuda/use_extern_funcs.h" +#include "paddle/common/enforce.h" namespace cinn { namespace runtime { @@ -43,7 +44,7 @@ void saxpy(float a, float *x, float *y, float *out, size_t n) )ROC"; auto ptx = compiler(source_code); - CHECK(!ptx.empty()); + PADDLE_ENFORCE_NE(ptx.empty(), true, phi::errors::NotFound("ptx is empty!")); CUDAModule module(ptx, CUDAModule::Kind::PTX); auto func = module.GetFunction(0, "saxpy"); @@ -73,7 +74,8 @@ TEST(CUDAModule, float16) { )"; auto ptx = compiler(source_code); - CHECK(!ptx.empty()); + PADDLE_ENFORCE_NE( + ptx.empty(), true, phi::errors::NotFound("ptx is empty!")); return ptx; }; @@ -116,7 +118,11 @@ TEST(CUDAModule, float16) { [](float x, float16 y) -> bool { return std::abs(x - static_cast(y)) < 1e-2f; }); - CHECK(res) << "The difference between two arrays exceeds the bound."; + PADDLE_ENFORCE_EQ( + res, + true, + phi::errors::PreconditionNotMet( + "The difference between two arrays exceeds the bound.")); } TEST(CUDAModule, bfloat16) { @@ -142,7 +148,8 @@ TEST(CUDAModule, bfloat16) { )"; auto ptx = compiler(source_code); - CHECK(!ptx.empty()); + PADDLE_ENFORCE_NE( + ptx.empty(), true, phi::errors::NotFound("ptx is empty!")); return ptx; }; @@ -185,7 +192,11 @@ TEST(CUDAModule, bfloat16) { [](float x, bfloat16 y) -> bool { return std::abs(x - static_cast(y)) < 1e-2f; }); - CHECK(res) << "The difference between two arrays exceeds the bound."; + PADDLE_ENFORCE_EQ( + res, + true, + phi::errors::PreconditionNotMet( + "The difference between two arrays exceeds the bound.")); } } // namespace cuda diff --git a/paddle/cinn/runtime/cuda/cuda_util.cc b/paddle/cinn/runtime/cuda/cuda_util.cc index 18c277339ddaf..9a565ba072a28 100644 --- a/paddle/cinn/runtime/cuda/cuda_util.cc +++ b/paddle/cinn/runtime/cuda/cuda_util.cc @@ -37,6 +37,7 @@ #include "paddle/cinn/runtime/flags.h" #include "paddle/cinn/utils/profiler.h" #include "paddle/cinn/utils/timer.h" +#include "paddle/common/enforce.h" namespace cinn { namespace runtime { @@ -151,7 +152,11 @@ void cinn_call_cublas(void *v_args, void *stream) { cinn::utils::RecordEvent record_run("cinn_call_cublas", cinn::utils::EventType::kInstruction); - CHECK_EQ(num_args, 3); + PADDLE_ENFORCE_EQ( + num_args, + 3, + phi::errors::InvalidArgument( + "Expected number of arguments is 3, but received %d.", num_args)); cublasHandle_t &cuhandle = CublasHandle::GetInstance().GetCublasHandle(); cinn_pod_value_t *args = static_cast(v_args); cudaStream_t custream = static_cast(stream); @@ -202,8 +207,10 @@ void cinn_call_cublas(void *v_args, } else if (is_bfloat16) { cuda_dtype = CUDA_R_16BF; } else { - LOG(FATAL) << "unsupported 
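LaunchKernel above replaces CHECK(function) with PADDLE_ENFORCE_NOT_NULL plus a phi::errors::NotFound message naming the missing kernel and device, mirroring how the diff formats the message. A minimal sketch of that check, assuming the Paddle enforce header; RequireKernel and its opaque handle are placeholders for the CUfunction lookup.

#include <string>
#include "paddle/common/enforce.h"

void RequireKernel(const void* function,
                   const std::string& func_name,
                   int device_id) {
  PADDLE_ENFORCE_NOT_NULL(
      function,
      phi::errors::NotFound(
          "%s function not found on device %d.", func_name, device_id));
}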
cublas data type: " - << static_cast(type_code) << ", bytes = " << bytes; + std::stringstream ss; + ss << "unsupported cublas data type: " << static_cast(type_code) + << ", bytes = " << bytes; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } if (a1 * a2 * b1 * b2 == 1) { @@ -404,7 +411,10 @@ void cinn_call_batched_cublas(void *v_args, int b4, void *stream) { // A * [B, C, D, ...] or [B, C, D, ...] * A - CHECK_EQ((num_args - 1) % 2, 0); + PADDLE_ENFORCE_EQ((num_args - 1) % 2, + 0, + phi::errors::PreconditionNotMet( + "(num_args - 1) should be divided by 2.")); cublasHandle_t &cuhandle = CublasHandle::GetInstance().GetCublasHandle(); cinn_pod_value_t *args = static_cast(v_args); cudaStream_t custream = static_cast(stream); @@ -424,8 +434,10 @@ void cinn_call_batched_cublas(void *v_args, } else if (is_bfloat16) { cuda_dtype = CUDA_R_16BF; } else { - LOG(FATAL) << "unsupported cublas data type: " - << static_cast(type_code) << ", bytes = " << bytes; + std::stringstream ss; + ss << "unsupported cublas data type: " << static_cast(type_code) + << ", bytes = " << bytes; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } int m = trans_o ? (trans_a ? a4 : a3) : (trans_b ? b3 : b4); @@ -481,7 +493,7 @@ void cinn_call_batched_cublas(void *v_args, void *B = args[1 + g].operator cinn_buffer_t *()->memory; void *C = args[1 + num_gemm + g].operator cinn_buffer_t *()->memory; - // if opside is 1, exhange A,B. + // if opside is 1, exchange A,B. if (opside) { auto tmp = A; A = B; @@ -533,7 +545,10 @@ void cinn_call_batched_cublas(void *v_args, void cinn_call_cuda_memset( void *v_args, int num_args, int value, size_t count, void *stream) { - CHECK_EQ(num_args, 1) << "The cinn_call_cuda_memset only accept a output"; + PADDLE_ENFORCE_EQ(num_args, + 1, + phi::errors::PreconditionNotMet( + "The cinn_call_cuda_memset only accept a output.")); VLOG(4) << "call cinn_call_cuda_memset with value=" << value << ", count=" << count; @@ -549,8 +564,11 @@ void cinn_call_cuda_memcpy(void *v_args, int num_args, size_t count, void *stream) { - CHECK_EQ(num_args, 2) - << "The cinn_call_cuda_memcpy only accept a input and a output"; + PADDLE_ENFORCE_EQ( + num_args, + 2, + phi::errors::PreconditionNotMet( + "The cinn_call_cuda_memset only accept a input and a output.")); VLOG(4) << "call cinn_call_cuda_memcpy with count=" << count; cinn_pod_value_t *args = static_cast(v_args); @@ -622,7 +640,10 @@ class ConvAlgoMap { }; cudnnDataType_t convert_to_cudnn_dtype(void *v_args, int num_args) { - CHECK_GT(num_args, 0) << "the number of arguments must larger than zero"; + PADDLE_ENFORCE_GT(num_args, + 0, + phi::errors::PreconditionNotMet( + "the number of arguments must larger than zero")); cinn_pod_value_t *args = static_cast(v_args); auto type_code = args[0].operator cinn_buffer_t *()->type.code; int bits = args[0].operator cinn_buffer_t *()->type.bits; @@ -630,7 +651,8 @@ cudnnDataType_t convert_to_cudnn_dtype(void *v_args, int num_args) { auto t = args[i].operator cinn_buffer_t *()->type.code; int b = args[0].operator cinn_buffer_t *()->type.bits; if (t != type_code || bits != b) { - LOG(FATAL) << "The types of all arguments need to be consistent."; + PADDLE_THROW(phi::errors::InvalidArgument( + "The types of all arguments need to be consistent.")); } } cudnnDataType_t data_type; @@ -645,8 +667,10 @@ cudnnDataType_t convert_to_cudnn_dtype(void *v_args, int num_args) { } else if (is_float && bits == 64) { data_type = CUDNN_DATA_DOUBLE; } else { - LOG(FATAL) << "unsupported cudnn data type: " << static_cast(type_code) 
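The batched cuBLAS wrapper above checks (num_args - 1) % 2 == 0 because, for A * [B, C, D, ...], the packed arguments are one shared operand followed by num_gemm inputs and num_gemm outputs. A framework-free sketch of that count logic, assuming the Paddle enforce header; BatchedGemmCount is an invented helper.

#include "paddle/common/enforce.h"

// One shared operand plus paired (input, output) buffers per GEMM.
int BatchedGemmCount(int num_args) {
  PADDLE_ENFORCE_EQ((num_args - 1) % 2,
                    0,
                    phi::errors::PreconditionNotMet(
                        "(num_args - 1) should be divided by 2."));
  return (num_args - 1) / 2;
}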
- << ", bits = " << bits; + std::stringstream ss; + ss << "unsupported cudnn data type: " << static_cast(type_code) + << ", bits = " << bits; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return data_type; } @@ -660,8 +684,9 @@ cudnnDataType_t get_cudnn_compute_dtype(cudnnDataType_t data_type) { case CUDNN_DATA_DOUBLE: return CUDNN_DATA_DOUBLE; default: - LOG(FATAL) << "unsupported cudnn data type, only support " - "float16/bfloat16/float32/float64 now!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "unsupported cudnn data type, only support " + "float16/bfloat16/float32/float64 now!")); } return CUDNN_DATA_FLOAT; } @@ -673,7 +698,8 @@ std::string debug_cudnn_tensor_format(cudnnTensorFormat_t tensor_format) { case CUDNN_TENSOR_NHWC: return "NHWC"; default: - LOG(FATAL) << "Only support NCHW and NHWC data layout\n"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Only support NCHW and NHWC data layout\n")); } return ""; } @@ -689,7 +715,8 @@ std::string debug_cudnn_tensor_dtype(cudnnDataType_t tensor_dtype) { case CUDNN_DATA_DOUBLE: return "float64"; default: - LOG(FATAL) << "Only support float16/bfloat16/float32/float64 now!"; + PADDLE_THROW(phi::errors::InvalidArgument( + "Only support float16/bfloat16/float32/float64 now!")); } return ""; } @@ -703,9 +730,10 @@ std::string debug_cudnn_pool_mode(cudnnPoolingMode_t pool_mode) { case CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING: return "avg_include_padding"; case CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING: - return "avg_exclulude_padding"; + return "avg_exclude_padding"; default: - LOG(FATAL) << "Pool only support max and avg now!"; + PADDLE_THROW( + phi::errors::InvalidArgument("Pool only support max and avg now!")); } return ""; } @@ -735,7 +763,11 @@ void cinn_call_cudnn_conv2d_forward(void *v_args, int output_h, int output_w, void *stream) { - CHECK_EQ(num_args, 3); + PADDLE_ENFORCE_EQ( + num_args, + 3, + phi::errors::InvalidArgument( + "Expected number of argruments is 3, but recived %d.", num_args)); cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); @@ -885,7 +917,11 @@ void cinn_call_cudnn_conv2d_backward_data(void *v_args, int output_h, int output_w, void *stream) { - CHECK_EQ(num_args, 3); + PADDLE_ENFORCE_EQ( + num_args, + 3, + phi::errors::InvalidArgument( + "Expected number of argruments is 3, but recived %d.", num_args)); cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); @@ -1038,7 +1074,11 @@ void cinn_call_cudnn_conv2d_backward_filter(void *v_args, int output_h, int output_w, void *stream) { - CHECK_EQ(num_args, 3); + PADDLE_ENFORCE_EQ( + num_args, + 3, + phi::errors::InvalidArgument( + "Expected number of argruments is 3, but recived %d.", num_args)); cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); @@ -1188,7 +1228,11 @@ void cinn_call_cudnn_pool2d_forward(void *v_args, int output_h, int output_w, void *stream) { - CHECK_EQ(num_args, 2); + PADDLE_ENFORCE_EQ( + num_args, + 2, + phi::errors::InvalidArgument( + "Expected number of argruments is 2, but recived %d.", num_args)); cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); @@ 
-1282,7 +1326,11 @@ void cinn_call_cudnn_pool2d_backward(void *v_args, int output_h, int output_w, void *stream) { - CHECK_EQ(num_args, 4); + PADDLE_ENFORCE_EQ( + num_args, + 4, + phi::errors::InvalidArgument( + "Expected number of argruments is 4, but recived %d.", num_args)); cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); @@ -1392,7 +1440,11 @@ void cinn_call_cudnn_softmax_forward(void *v_args, int output_h, int output_w, void *stream) { - CHECK_EQ(num_args, 2); + PADDLE_ENFORCE_EQ( + num_args, + 2, + phi::errors::InvalidArgument( + "Expected number of argruments is 2, but recived %d.", num_args)); cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); @@ -1462,7 +1514,11 @@ void cinn_call_cudnn_softmax_backward(void *v_args, int output_h, int output_w, void *stream) { - CHECK_EQ(num_args, 3); + PADDLE_ENFORCE_EQ( + num_args, + 3, + phi::errors::InvalidArgument( + "Expected number of argruments is 3, but recived %d.", num_args)); cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); cinn_pod_value_t *args = static_cast(v_args); @@ -1558,9 +1614,12 @@ void Gemm(const cublasHandle_t &cublas, } int contracting_size = lhs_trans ? lhs_row : lhs_col; - CHECK_EQ(contracting_size, (rhs_trans ? rhs_col : rhs_row)) - << "The contracting dimension value of lhs matrix should be equal to the " - "one of rhs matrix."; + PADDLE_ENFORCE_EQ( + contracting_size, + (rhs_trans ? rhs_col : rhs_row), + phi::errors::PreconditionNotMet("The contracting dimension value of lhs " + "matrix should be equal to the " + "one of rhs matrix.")); auto trans_a = rhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N; auto trans_b = lhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N; cublasSgemm(cublas, @@ -1601,8 +1660,14 @@ void GemmStridedBatched(const cublasHandle_t &cublas, int output_bs = output_shape[0]; int output_row = output_shape[1]; int output_col = output_shape[2]; - CHECK_EQ(lhs_bs, rhs_bs); - CHECK_EQ(lhs_bs, output_bs); + PADDLE_ENFORCE_EQ( + lhs_bs, + rhs_bs, + phi::errors::InvalidArgument("bs of lhs and rhs dismatch.")); + PADDLE_ENFORCE_EQ( + lhs_bs, + output_bs, + phi::errors::InvalidArgument("bs of lhs and output dismatch.")); // copy values of bias_data to the output_data if (bias_data != nullptr) { @@ -1614,9 +1679,12 @@ void GemmStridedBatched(const cublasHandle_t &cublas, } int contracting_size = lhs_trans ? lhs_row : lhs_col; - CHECK_EQ(contracting_size, (rhs_trans ? rhs_col : rhs_row)) - << "The contracting dimension value of lhs matrix should be equal to the " - "one of rhs matrix."; + PADDLE_ENFORCE_EQ( + contracting_size, + (rhs_trans ? rhs_col : rhs_row), + phi::errors::PreconditionNotMet("The contracting dimension value of lhs " + "matrix should be equal to the " + "one of rhs matrix.")); auto trans_a = rhs_trans ? CUBLAS_OP_T : CUBLAS_OP_N; auto trans_b = lhs_trans ? 
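The Gemm and GemmStridedBatched hunks above enforce that the contracting dimensions agree: for C = A x B with A viewed as lhs_row x lhs_col and B as rhs_row x rhs_col, the shared K is lhs_col (or lhs_row if A is transposed) and must equal rhs_row (or rhs_col if B is transposed). A small standalone sketch of that rule, assuming the Paddle enforce header; ContractingSize is an invented helper.

#include "paddle/common/enforce.h"

// Returns the shared K dimension of lhs (m x k) and rhs (k x n), honoring
// the transpose flags the same way the Gemm helpers above do.
int ContractingSize(int lhs_row, int lhs_col, bool lhs_trans,
                    int rhs_row, int rhs_col, bool rhs_trans) {
  int lhs_k = lhs_trans ? lhs_row : lhs_col;
  int rhs_k = rhs_trans ? rhs_col : rhs_row;
  PADDLE_ENFORCE_EQ(lhs_k,
                    rhs_k,
                    phi::errors::PreconditionNotMet(
                        "The contracting dimension value of lhs matrix should "
                        "be equal to the one of rhs matrix."));
  return lhs_k;
}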
CUBLAS_OP_T : CUBLAS_OP_N; int64_t lhs_stride = lhs_row * lhs_col; @@ -1677,9 +1745,17 @@ void cinn_call_cholesky_nvgpu(void *v_args, size_t numel = x->num_elements(); uint8_t bits = x->type.bits; uint8_t bytes = bits / 8; - CHECK_EQ(x->type.code, cinn_type_code_t::cinn_type_float); - CHECK(bits == 32 || bits == 64) - << "Unsupported bits = " << bits << " float data type for cholesky"; + PADDLE_ENFORCE_EQ( + x->type.code, + cinn_type_code_t::cinn_type_float, + phi::errors::InvalidArgument("x's type code (%d) is inequal to %d.", + x->type.code, + cinn_type_code_t::cinn_type_float)); + PADDLE_ENFORCE_EQ( + bits == 32 || bits == 64, + true, + phi::errors::InvalidArgument( + "Unsupported bits = %d float data type for cholesky", bits)); auto cuda_stream = static_cast(stream); @@ -1724,9 +1800,12 @@ void cinn_call_cholesky_nvgpu(void *v_args, // Check result thrust::copy(dev_info.begin(), dev_info.end(), host_info.begin()); for (int i = 0; i < host_info.size(); i++) { - CHECK_EQ(host_info[i], 0) - << "Cholesky decomposition fail, please check the " << i + 1 - << "th input matrix."; + PADDLE_ENFORCE_EQ(host_info[i], + 0, + phi::errors::PreconditionNotMet( + "Cholesky decomposition fail, please check the %d" + "th input matrix.", + i + 1)); } } @@ -1760,13 +1839,29 @@ void cinn_call_triangular_solve_nvgpu(void *v_args, cinn_buffer_t *input2 = args[1].operator cinn_buffer_t *(); cinn_buffer_t *output = args[2].operator cinn_buffer_t *(); - CHECK_EQ(input1->type.code, cinn_type_code_t::cinn_type_float); - CHECK_EQ(input2->type.code, cinn_type_code_t::cinn_type_float); - CHECK_EQ(input1->type.bits, input2->type.bits); + PADDLE_ENFORCE_EQ( + input1->type.code, + cinn_type_code_t::cinn_type_float, + phi::errors::InvalidArgument("input1's type code (%d) is inequal to %d.", + input1->type.code, + cinn_type_code_t::cinn_type_float)); + PADDLE_ENFORCE_EQ( + input2->type.code, + cinn_type_code_t::cinn_type_float, + phi::errors::InvalidArgument("input1's type code (%d) is inequal to %d.", + input2->type.code, + cinn_type_code_t::cinn_type_float)); + PADDLE_ENFORCE_EQ(input1->type.bits, + input2->type.bits, + phi::errors::InvalidArgument( + "input1 and ipnput2's type bits is dismatch.")); uint8_t bits = input1->type.bits; uint8_t bytes = bits / 8; - CHECK(bits == 32 || bits == 64) << "unsupported bits = " << bits - << " float data type for triangular solve"; + PADDLE_ENFORCE_EQ( + bits == 32 || bits == 64, + true, + phi::errors::InvalidArgument( + "Unsupported bits = %d float data type for triangular solve", bits)); std::string debug_info = "triangular solve op: left_side=" + std::to_string(left_side) + @@ -1852,14 +1947,23 @@ void cinn_gpu_cublas_mul(const std::vector &attrs, cinn_buffer_t *output, cudaStream_t stream) { cublasHandle_t &handle = CublasHandle::GetInstance().GetCublasHandle(); - CHECK_EQ(input1->type.code, cinn_type_code_t::cinn_type_float); + PADDLE_ENFORCE_EQ(input1->type.code, + cinn_type_code_t::cinn_type_float, + phi::errors::InvalidArgument( + "Expected type code of input is %d, but received %d.", + cinn_type_code_t::cinn_type_float, + input1->type.code)); cudaStream_t custream = static_cast(stream); CUBLAS_CALL(cublasSetStream(handle, custream)); float *x_data = reinterpret_cast(input1->memory); float *y_data = reinterpret_cast(input2->memory); float *out_data = reinterpret_cast(output->memory); int M = 1; - CHECK_GE(attrs.size(), 6); + PADDLE_ENFORCE_GE(attrs.size(), + 6, + phi::errors::InvalidArgument( + "Expected size of attributions is 6, but received %d.", + attrs.size())); for 
(int i = 0; i < attrs[attrs.size() - 2]; i++) { M *= attrs[i]; } @@ -1894,14 +1998,24 @@ void cinn_gpu_cublas_gemm(const std::vector &attrs, cudaStream_t custream = static_cast(stream); CUBLAS_CALL(cublasSetStream(handle, custream)); - CHECK_EQ(lhs->type.code, cinn_type_code_t::cinn_type_float); + PADDLE_ENFORCE_EQ( + lhs->type.code, + cinn_type_code_t::cinn_type_float, + phi::errors::InvalidArgument("lhs's type code (%d) is inequal to %d.", + lhs->type.code, + cinn_type_code_t::cinn_type_float)); const float *lhs_data = reinterpret_cast(lhs->memory); const float *rhs_data = reinterpret_cast(rhs->memory); const float *bias_data = bias ? reinterpret_cast(bias->memory) : nullptr; float *output_data = reinterpret_cast(output->memory); - CHECK_GE(attrs.size(), 13); + PADDLE_ENFORCE_GE(attrs.size(), + 13, + phi::errors::InvalidArgument( + "Expected size of attributions is greater or " + "qeual to 13, but received %d.", + attrs.size())); int lhs_dim_size = attrs[attrs.size() - 7]; int rhs_dim_size = attrs[attrs.size() - 6]; int out_dim_size = attrs[attrs.size() - 5]; @@ -1924,9 +2038,18 @@ void cinn_gpu_cublas_gemm(const std::vector &attrs, VLOG(4) << "The out_trans value used by cinn_gpu_cublas_gemm: " << out_trans; VLOG(4) << "The alpha value used by cinn_gpu_cublas_gemm: " << alpha; VLOG(4) << "The beta value used by cinn_gpu_cublas_gemm: " << beta; - CHECK_EQ(lhs_dim_size, rhs_dim_size); - CHECK_EQ(lhs_dim_size, out_dim_size); - CHECK((lhs_dim_size == 2 || lhs_dim_size == 3)); + PADDLE_ENFORCE_EQ( + lhs_dim_size, + rhs_dim_size, + phi::errors::InvalidArgument("dimension dismatch between lhs and rhs.")); + PADDLE_ENFORCE_EQ( + lhs_dim_size, + out_dim_size, + phi::errors::InvalidArgument("dimension dismatch between lhs and out.")); + PADDLE_ENFORCE_EQ( + (lhs_dim_size == 2 || lhs_dim_size == 3), + true, + phi::errors::InvalidArgument("left operand has 2 or 3 dimension.")); if (lhs_dim_size == 2) { // [row, col] @@ -2076,8 +2199,8 @@ void cinn_call_gaussian_random( double *ptr = reinterpret_cast(output->memory); CURAND_CALL(curandGenerateNormalDouble(generator, ptr, numel, mean, std)); } else { - LOG(FATAL) - << "gaussian_random only support float32 and float64! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "gaussian_random only support float32 and float64! Please check.")); } } @@ -2105,8 +2228,8 @@ void cinn_call_uniform_random( double *ptr = reinterpret_cast(output->memory); CURAND_CALL(curandGenerateUniformDouble(generator, ptr, numel)); } else { - LOG(FATAL) - << "uniform_random only support float32 and float64! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "uniform_random only support float32 and float64! Please check.")); } } @@ -2129,7 +2252,8 @@ void cinn_call_randint(void *v_args, int num_args, int seed, void *stream) { unsigned int *ptr = reinterpret_cast(output->memory); CURAND_CALL(curandGenerate(generator, ptr, numel)); } else { - LOG(FATAL) << "randint only support int32! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "randint only support int32! 
Please check.")); } } @@ -2137,7 +2261,8 @@ void cinn_call_randint(void *v_args, int num_args, int seed, void *stream) { namespace { cudnnDataType_t convert_to_cudnn_dtype(cinn_buffer_t *input) { - CHECK(input) << "the pointer of input is null"; + PADDLE_ENFORCE_NOT_NULL( + input, phi::errors::NotFound("the pointer of input is null")); auto type_code = input->type.code; int bits = input->type.bits; cudnnDataType_t data_type; @@ -2152,21 +2277,25 @@ cudnnDataType_t convert_to_cudnn_dtype(cinn_buffer_t *input) { } else if (is_float && bits == 64) { data_type = CUDNN_DATA_DOUBLE; } else { - LOG(FATAL) << "unsupported cudnn data type: " << static_cast(type_code) - << ", bits = " << bits; + std::stringstream ss; + ss << "unsupported cudnn data type: " << static_cast(type_code) + << ", bits = " << bits; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } return data_type; } } // namespace -#define GetAttrValue(attr_map, key_name, default_value) \ - int key_name = 0; \ - if (attr_map.count(#key_name) != 0) { \ - key_name = attr_map.find(#key_name)->second; \ - } else if (default_value >= 0) { \ - key_name = default_value; \ - } else { \ - LOG(FATAL) << #key_name << " is not exist in attr_map!"; \ +#define GetAttrValue(attr_map, key_name, default_value) \ + int key_name = 0; \ + if (attr_map.count(#key_name) != 0) { \ + key_name = attr_map.find(#key_name)->second; \ + } else if (default_value >= 0) { \ + key_name = default_value; \ + } else { \ + std::stringstream ss; \ + ss << #key_name << " is not exist in attr_map!"; \ + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); \ } void cinn_gpu_cudnn_conv2d(const absl::flat_hash_map &attr, @@ -2645,7 +2774,11 @@ void cinn_gpu_cudnn_pool2d(const std::vector &attrs, cudaStream_t stream) { cudnnHandle_t &handle = CudnnHandle::GetInstance().GetCudnnHandle(); CUDNN_CALL(cudnnSetStream(handle, static_cast(stream))); - CHECK_EQ(attrs.size(), 17); + PADDLE_ENFORCE_EQ(attrs.size(), + 17, + phi::errors::InvalidArgument( + "Expected size of attributions is 17, but received %d.", + attrs.size())); // Here the input paddings are pad_top, pad_bottom, pad_left, pad_right. // Since pad_top==pad_bottom and pad_left==pad_rifht, we only take pad_top and // pad_left. 
diff --git a/paddle/cinn/runtime/custom_function.cc b/paddle/cinn/runtime/custom_function.cc index 08fe5c1bd7f35..d424755d56b49 100644 --- a/paddle/cinn/runtime/custom_function.cc +++ b/paddle/cinn/runtime/custom_function.cc @@ -37,8 +37,10 @@ void AssertTrueMsgTool::SetMsg(int key, const std::string& msg) { } const std::string& AssertTrueMsgTool::GetMsg(int key) { - CHECK(global_msg_.find(key) != global_msg_.end()) - << "Cannot find assert_true message key " << key; + PADDLE_ENFORCE_NE( + global_msg_.find(key), + global_msg_.end(), + phi::errors::NotFound("Cannot find assert_true message key (%d).", key)); return global_msg_[key]; } @@ -69,9 +71,12 @@ void AssertTrueMsgTool::InitFlagInfo() { continue; } const auto& flag_arg = cinn::utils::Split(str, "="); - CHECK_EQ(flag_arg.size(), 2UL) - << "The FLAGS_cinn_check_fusion_accuracy_pass must be the format of " - "\"only_warning=false;rtol=1e-5;atol=1e-8;equal_nan=false\""; + PADDLE_ENFORCE_EQ( + flag_arg.size(), + 2UL, + phi::errors::InvalidArgument( + "The FLAGS_cinn_check_fusion_accuracy_pass must be the format of " + "\"only_warning=false;rtol=1e-5;atol=1e-8;equal_nan=false\".")); if (flag_arg[0] == "only_warning" || flag_arg[0] == "equal_nan") { // bool type parameter @@ -80,9 +85,9 @@ void AssertTrueMsgTool::InitFlagInfo() { // string type parameter flag_values_[flag_arg[0]] = std::stof(flag_arg[1]); } else { - LOG(FATAL) - << "The FLAGS_cinn_check_fusion_accuracy_pass only support parameter " - "\"only_warning/rtol/atol/equal_nan\" now"; + PADDLE_THROW(phi::errors::InvalidArgument( + "The FLAGS_cinn_check_fusion_accuracy_pass only support parameter " + "\"only_warning/rtol/atol/equal_nan\" now")); } } @@ -111,8 +116,8 @@ bool MemcpyToHost(void* dst, cudaStreamSynchronize(cuda_stream); return true; #else - LOG(FATAL) - << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + PADDLE_THROW(phi::errors::Fatal( + "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check.")); return false; #endif } @@ -120,9 +125,11 @@ bool MemcpyToHost(void* dst, memcpy(dst, src, bytes); return true; } - LOG(FATAL) << "MemcpyToHost Only support cpu or nvgpu -> cpu, but here the " - "input target is " - << input_target << "! Please check."; + std::stringstream ss; + ss << "MemcpyToHost Only support cpu or nvgpu -> cpu, but here the " + "input target is " + << input_target << "! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return false; } @@ -147,14 +154,17 @@ bool MemcpyToDevice(void* dst, static_cast(stream)); return true; } else { - LOG(FATAL) << "MemcpyToDevice only support cpu or nvgpu -> nvgpu, but here " - "the input target is " - << input_target << "! Please check."; + std::stringstream ss; + ss << "MemcpyToDevice only support cpu or nvgpu -> nvgpu, but here " + "the input target is " + << input_target << "! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return false; } #else - LOG(FATAL) << "MemcpyToDevice only support nvgpu, and NVGPU Target only " - "support when flag CINN_WITH_CUDA ON! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "MemcpyToDevice only support nvgpu, and NVGPU Target only " + "support when flag CINN_WITH_CUDA ON! 
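InitFlagInfo above expects FLAGS_cinn_check_fusion_accuracy_pass entries of the form key=value separated by ';', e.g. "only_warning=false;rtol=1e-5;atol=1e-8;equal_nan=false", and now rejects malformed pieces with a typed error. A standalone sketch of that parsing contract using only the standard library; the real code splits with cinn::utils::Split, so this parser is illustrative rather than a drop-in.

#include <map>
#include <sstream>
#include <stdexcept>
#include <string>

std::map<std::string, std::string> ParseFlagString(const std::string& flags) {
  std::map<std::string, std::string> result;
  std::stringstream ss(flags);
  std::string item;
  while (std::getline(ss, item, ';')) {
    if (item.empty()) continue;
    auto pos = item.find('=');
    if (pos == std::string::npos || pos == 0 || pos + 1 == item.size()) {
      throw std::invalid_argument(
          "expected \"key=value\" pieces such as "
          "\"only_warning=false;rtol=1e-5;atol=1e-8;equal_nan=false\"");
    }
    result[item.substr(0, pos)] = item.substr(pos + 1);
  }
  return result;
}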
Please check.")); return false; #endif } @@ -187,7 +197,7 @@ void CheckAssertTrue(const bool* x, if (only_warning) { LOG(WARNING) << error_info; } else { - LOG(FATAL) << error_info; + PADDLE_THROW(phi::errors::InvalidArgument(error_info)); } } else { VLOG(1) << "[AssertTrue] Check succeed!\n" diff --git a/paddle/cinn/runtime/custom_function.h b/paddle/cinn/runtime/custom_function.h index 103da8b5eba89..7fa669a8037ec 100644 --- a/paddle/cinn/runtime/custom_function.h +++ b/paddle/cinn/runtime/custom_function.h @@ -22,6 +22,7 @@ #include "paddle/cinn/hlir/framework/tensor.h" #include "paddle/cinn/runtime/cinn_runtime.h" #include "paddle/cinn/utils/type_defs.h" +#include "paddle/common/enforce.h" namespace cinn { namespace runtime { @@ -42,11 +43,16 @@ class AssertTrueMsgTool { template const T& GetFlagValue(const std::string& param) { InitFlagInfo(); - CHECK(flag_values_.count(param)) - << "The FLAGS_cinn_check_fusion_accuracy_pass only support parameter " - "\"only_warning/rtol/atol/equal_nan\" now"; - CHECK(absl::holds_alternative(flag_values_.at(param))) - << "Try get value from a error type!"; + PADDLE_ENFORCE_GT( + flag_values_.count(param), + 0, + phi::errors::InvalidArgument( + "The FLAGS_cinn_check_fusion_accuracy_pass only support parameter " + "\"only_warning/rtol/atol/equal_nan\" now.")); + PADDLE_ENFORCE_GT( + absl::holds_alternative(flag_values_.at(param)), + 0, + phi::errors::InvalidArgument("Try get value from a error type!")); return absl::get(flag_values_.at(param)); } diff --git a/paddle/cinn/runtime/custom_function_test.cc b/paddle/cinn/runtime/custom_function_test.cc index b2dc09b1862f0..2ec40f110966f 100644 --- a/paddle/cinn/runtime/custom_function_test.cc +++ b/paddle/cinn/runtime/custom_function_test.cc @@ -46,9 +46,12 @@ class CinnBufferAllocHelper { template T* mutable_data(const Target& target) { if (target_ != cinn::common::UnkTarget()) { - CHECK_EQ(target, target_) - << "Cannot alloc twice, the memory had alloced at " << target_ - << "! Please check."; + PADDLE_ENFORCE_EQ( + target, + target_, + phi::errors::AlreadyExists( + "Cannot alloc twice, the memory had alloced at %d! Please check.", + target_)); return reinterpret_cast(buffer_->memory); } @@ -59,12 +62,15 @@ class CinnBufferAllocHelper { #ifdef CINN_WITH_CUDA cudaMalloc(&buffer_->memory, buffer_->num_elements() * sizeof(T)); #else - LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! " - "Please check."; + PADDLE_THROW(phi::errors::Fatal( + "NVGPU Target only support on flag CINN_WITH_CUDA ON! " + "Please check.")); #endif } else { - LOG(FATAL) << "Only support nvgpu and cpu, but here " << target - << "! Please check."; + std::stringstream ss; + ss << "Only support nvgpu and cpu, but here " << target + << "! Please check."; + PADDLE_THROW(phi::errors::Fatal(ss.str())); } return reinterpret_cast(buffer_->memory); @@ -73,7 +79,7 @@ class CinnBufferAllocHelper { template const T* data() { if (target_ == cinn::common::UnkTarget()) { - LOG(FATAL) << "No memory had alloced! Please check."; + PADDLE_THROW(phi::errors::Fatal("No memory had alloced! Please check.")); } return reinterpret_cast(buffer_->memory); } @@ -88,12 +94,15 @@ class CinnBufferAllocHelper { #ifdef CINN_WITH_CUDA cudaFree(buffer_->memory); #else - LOG(FATAL) << "NVGPU Target only support on flag CINN_WITH_CUDA ON! " - "Please check."; + PADDLE_THROW(phi::errors::Fatal( + "NVGPU Target only support on flag CINN_WITH_CUDA ON! 
" + "Please check.")); #endif } else { - LOG(FATAL) << "Only support nvgpu and cpu, but here " << target_ - << "! Please check."; + std::stringstream ss; + ss << "Only support nvgpu and cpu, but here " << target_ + << "! Please check."; + PADDLE_THROW(phi::errors::Fatal(ss.str())); } delete buffer_; } @@ -121,8 +130,8 @@ void SetInputValue(T* input, #ifdef CINN_WITH_CUDA cudaMemcpy(input, input_h, num * sizeof(T), cudaMemcpyHostToDevice); #else - LOG(FATAL) - << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + PADDLE_THROW(phi::errors::Fatal( + "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check.")); #endif } } @@ -233,8 +242,8 @@ TEST(CustomCallGaussianRandom, test_target_nvgpu) { VLOG(6) << output_data[i]; } #else - LOG(FATAL) - << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + PADDLE_THROW(phi::errors::Fatal( + "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check.")); #endif } } @@ -269,8 +278,8 @@ TEST(CustomCallUniformRandom, test_target_nvgpu) { VLOG(6) << output_data[i]; } #else - LOG(FATAL) - << "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check."; + PADDLE_THROW(phi::errors::Fatal( + "NVGPU Target only support on flag CINN_WITH_CUDA ON! Please check.")); #endif } } diff --git a/paddle/cinn/runtime/flags.cc b/paddle/cinn/runtime/flags.cc index 89512913e8fa9..e4fd6e31f665a 100644 --- a/paddle/cinn/runtime/flags.cc +++ b/paddle/cinn/runtime/flags.cc @@ -22,6 +22,7 @@ #include #include "paddle/cinn/common/target.h" +#include "paddle/common/enforce.h" #include "paddle/common/flags.h" #ifdef CINN_WITH_CUDNN @@ -69,6 +70,19 @@ PD_DEFINE_bool(cinn_bucket_compile, BoolFromEnv("FLAGS_cinn_bucket_compile", false), "Whether to enable bucket compile for dynamic shape."); +PD_DEFINE_bool(group_schedule_tiling_first, + BoolFromEnv("FLAGS_group_schedule_tiling_first", false), + "Whether to enable new group scheduler tiling first strategy."); + +PD_DEFINE_bool(cinn_new_cluster_op_method, + BoolFromEnv("FLAGS_cinn_new_cluster_op_method", true), + "Whether to enable newly developed clustering method of group " + "op for cinn."); + +PD_DEFINE_bool(support_reduce_stride_read, + BoolFromEnv("FLAGS_support_reduce_stride_read", false), + "Whether to enable new group scheduler tiling first strategy."); + PD_DEFINE_bool(cinn_use_common_subexpression_elimination, BoolFromEnv("FLAGS_cinn_use_common_subexpression_elimination", false), @@ -128,7 +142,7 @@ PD_DEFINE_bool(cinn_use_dense_merge_pass, PD_DEFINE_bool( nvrtc_compile_to_cubin, - BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", false), + BoolFromEnv("FLAGS_nvrtc_compile_to_cubin", true), "Whether nvrtc compile cuda source into cubin instead of ptx (only " "works after cuda-11.1)."); @@ -286,7 +300,8 @@ bool GetCinnCudnnDeterministic() { #ifdef CINN_WITH_CUDNN return FLAGS_cinn_cudnn_deterministic; #else - LOG(FATAL) << "CINN is compiled without cuDNN, this api is invalid!"; + PADDLE_THROW(phi::errors::Fatal( + "CINN is compiled without cuDNN, this api is invalid!")); return false; #endif } @@ -333,8 +348,9 @@ cinn::common::Target CurrentTarget::target_ = cinn::common::DefaultTarget(); void CurrentTarget::SetCurrentTarget(const cinn::common::Target& target) { if (!IsCompiledWithCUDA() && target.arch == cinn::common::Target::Arch::NVGPU) { - LOG(FATAL) << "Current CINN version does not support NVGPU, please try to " - "recompile with -DWITH_CUDA."; + PADDLE_THROW(phi::errors::Fatal( + "Current CINN version does not support NVGPU, please try to " + "recompile with 
-DWITH_CUDA.")); } else { target_ = target; } diff --git a/paddle/cinn/runtime/intrinsic.cc b/paddle/cinn/runtime/intrinsic.cc index eb68cb5637cf3..6bf5ac17c506e 100644 --- a/paddle/cinn/runtime/intrinsic.cc +++ b/paddle/cinn/runtime/intrinsic.cc @@ -51,7 +51,9 @@ cinn_type_t ToRuntimeType(Type type) { SET_TYPE_CASE_ITEM(Float16().PointerOf, cinn_type_of); SET_TYPE_CASE_ITEM(BFloat16().PointerOf, cinn_type_of); - LOG(FATAL) << "Not supported type " << type; + std::stringstream ss; + ss << "Not supported type " << type; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); return cinn_unk_t(); #undef SET_TYPE_CASE_ITEM } diff --git a/paddle/cinn/runtime/intrinsic_types.h b/paddle/cinn/runtime/intrinsic_types.h index 6a6c460e6323c..2e547ca1e3875 100644 --- a/paddle/cinn/runtime/intrinsic_types.h +++ b/paddle/cinn/runtime/intrinsic_types.h @@ -18,6 +18,7 @@ */ #include "paddle/cinn/common/common.h" +#include "paddle/common/enforce.h" namespace cinn { namespace runtime { @@ -35,8 +36,10 @@ struct BufferType { private: explicit BufferType(const Type& primitive_type) : primitive_type(primitive_type) { - CHECK(primitive_type.valid()); - CHECK(primitive_type.is_primitive()); + PADDLE_ENFORCE_EQ(primitive_type.valid() && primitive_type.is_primitive(), + true, + phi::errors::InvalidArgument( + "primitive type should be valid and primitive.")); } //! Determine the primitive of cinn_buffer_t. @@ -45,8 +48,10 @@ struct BufferType { }; static Type make_intrinsic_buffer_type(Type primitive_type) { - CHECK(primitive_type.is_primitive()); - CHECK(primitive_type.valid()); + PADDLE_ENFORCE_EQ(primitive_type.valid() && primitive_type.is_primitive(), + true, + phi::errors::InvalidArgument( + "primitive type should be valid and primitive.")); Type res = BufferType::cinn_type(); return res; } diff --git a/paddle/cinn/utils/CMakeLists.txt b/paddle/cinn/utils/CMakeLists.txt index 39e37b5a3471b..afcad3e82f381 100755 --- a/paddle/cinn/utils/CMakeLists.txt +++ b/paddle/cinn/utils/CMakeLists.txt @@ -14,7 +14,8 @@ gather_srcs( event.cc multi_threading.cc data_util.cc - random_engine.cc) + random_engine.cc + external_func_names.cc) cinn_cc_test(test_string SRCS string_test.cc DEPS cinncore) cinn_cc_test(test_sized_multi_set SRCS sized_multi_set_test.cc DEPS cinncore) diff --git a/paddle/cinn/utils/error.h b/paddle/cinn/utils/error.h index 7b5af324d7081..2b6795571c509 100644 --- a/paddle/cinn/utils/error.h +++ b/paddle/cinn/utils/error.h @@ -113,15 +113,6 @@ struct EnforceNotMet : public std::exception { std::string err_str_; }; -#define CINN_THROW(...) 
\ - do { \ - try { \ - throw utils::enforce::EnforceNotMet(__VA_ARGS__, __FILE__, __LINE__); \ - } catch (const std::exception& e) { \ - std::cout << e.what() << std::endl; \ - throw; \ - } \ - } while (0) } // namespace enforce /** diff --git a/paddle/cinn/utils/event.cc b/paddle/cinn/utils/event.cc index ca06ae73c6766..7ec7769c99230 100644 --- a/paddle/cinn/utils/event.cc +++ b/paddle/cinn/utils/event.cc @@ -15,9 +15,9 @@ #include "paddle/cinn/utils/event.h" #include // for GLog - #include +#include "paddle/common/enforce.h" namespace cinn { namespace utils { inline std::string EventTypeToString(const EventType &type) { @@ -43,7 +43,7 @@ inline std::string EventTypeToString(const EventType &type) { case EventType::kInstruction: return "Instruction"; default: - LOG(FATAL) << "Unknown event type"; + PADDLE_THROW(phi::errors::InvalidArgument("Unknown event type")); } } diff --git a/paddle/cinn/utils/external_func_names.cc b/paddle/cinn/utils/external_func_names.cc new file mode 100644 index 0000000000000..ee0ad4e112d9d --- /dev/null +++ b/paddle/cinn/utils/external_func_names.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/cinn/utils/external_func_names.h" + +namespace cinn::utils { + +const std::unordered_set& GetProhibitScheduleExternalFuncNames() { + static const std::unordered_set + prohibit_schedule_external_func_names = { +#define CINN_FUNC2STRING(str) #str +#define CINN_NVGPU_FUNC_TYPE(FUNC, TYPE) \ + CINN_FUNC2STRING(cinn_nvgpu_##FUNC##TYPE), \ + CINN_FUNC2STRING(cinn_host_##FUNC##TYPE) + +#define GEN_FUNC_NAME(_, impl) \ + _(impl, gt_num) \ + _(impl, lt_num) \ + _(impl, index_add) \ + _(impl, next_smallest) + +#define GEN_FUNC_NAME_WITH_TYPE(_, ...) \ + _(__VA_ARGS__, _bool), _(__VA_ARGS__, _fp16), _(__VA_ARGS__, _fp32), \ + _(__VA_ARGS__, _fp64), _(__VA_ARGS__, _uint8), _(__VA_ARGS__, _int8), \ + _(__VA_ARGS__, _int16), _(__VA_ARGS__, _int32), _(__VA_ARGS__, _int64), + + GEN_FUNC_NAME(GEN_FUNC_NAME_WITH_TYPE, CINN_NVGPU_FUNC_TYPE) +#undef GEN_FUNC_NAME +#undef GEN_FUNC_NAME_WITH_TYPE +#undef CINN_NVGPU_FUNC_TYPE +#undef CINN_FUNC2STRING + }; + return prohibit_schedule_external_func_names; +} + +} // namespace cinn::utils diff --git a/paddle/fluid/string/pretty_log.h b/paddle/cinn/utils/external_func_names.h similarity index 72% rename from paddle/fluid/string/pretty_log.h rename to paddle/cinn/utils/external_func_names.h index dc80e59d613e3..47585c218e64c 100644 --- a/paddle/fluid/string/pretty_log.h +++ b/paddle/cinn/utils/external_func_names.h @@ -1,4 +1,4 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -11,12 +11,14 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
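For readers following the new paddle/cinn/utils/external_func_names.{h,cc} pair above: the nested macros expand into plain string literals such as "cinn_nvgpu_gt_num_fp32" and "cinn_host_index_add_int64". A minimal usage sketch, not part of this patch; the element type of the returned set is assumed to be std::string, which the flattened listing above does not show:

#include <string>

#include "paddle/cinn/utils/external_func_names.h"

// Returns true if the external function must not be scheduled, e.g. for
// fn == "cinn_nvgpu_gt_num_fp32" (one of the macro-generated entries).
bool IsScheduleProhibited(const std::string& fn) {
  return cinn::utils::GetProhibitScheduleExternalFuncNames().count(fn) > 0;
}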
// See the License for the specific language governing permissions and // limitations under the License. + #pragma once -#include -#include #include -#include +#include + +namespace cinn::utils { + +const std::unordered_set& GetProhibitScheduleExternalFuncNames(); -#include "paddle/common/flags.h" -#include "paddle/utils/string/pretty_log.h" +} // namespace cinn::utils diff --git a/paddle/cinn/utils/multi_threading.cc b/paddle/cinn/utils/multi_threading.cc index d4031431d0e34..27aed61186b77 100644 --- a/paddle/cinn/utils/multi_threading.cc +++ b/paddle/cinn/utils/multi_threading.cc @@ -20,16 +20,20 @@ #include #include #include - #include "paddle/cinn/utils/string.h" +#include "paddle/common/enforce.h" namespace cinn { namespace utils { SequenceDispatcher::SequenceDispatcher(int begin, int end, int step) : end_(end), step_(step), index_(begin) { - CHECK_LE(begin, end) << StringFormat("begin[%d] > end[%d]", begin, end); - CHECK_GT(step, 0) << "step is less than 0"; + PADDLE_ENFORCE_LE( + begin, + end, + phi::errors::InvalidArgument("begin[%d] > end[%d]", begin, end)); + PADDLE_ENFORCE_GT( + step, 0, phi::errors::InvalidArgument("step is less than 0.")); } int SequenceDispatcher::Next() const { @@ -47,7 +51,10 @@ void parallel_run(const WorkerFuncType& fn, if (num_threads == -1 || num_threads > std::thread::hardware_concurrency()) { num_threads = std::thread::hardware_concurrency(); } - CHECK_GT(num_threads, 0) << "num_threads should be greater than 0"; + PADDLE_ENFORCE_GT( + num_threads, + 0, + phi::errors::PreconditionNotMet("num_threads should be greater than 0")); // worker function of a thread auto worker = [&fn, &dispatcher](int tid) -> int { @@ -86,7 +93,9 @@ void parallel_run(const WorkerFuncType& fn, VLOG(4) << "Thread-" << tid << " process " << counter << " tasks."; } } catch (const std::exception& e) { - LOG(FATAL) << "parallel_run incurs error: " << e.what(); + std::stringstream ss; + ss << "parallel_run incurs error: " << e.what(); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } // join threads diff --git a/paddle/cinn/utils/multi_threading_test.cc b/paddle/cinn/utils/multi_threading_test.cc index bd081fea2b56c..2abf7111c3488 100644 --- a/paddle/cinn/utils/multi_threading_test.cc +++ b/paddle/cinn/utils/multi_threading_test.cc @@ -20,6 +20,8 @@ #include #include +#include "paddle/common/enforce.h" + namespace cinn { namespace utils { @@ -35,7 +37,8 @@ TEST(JobDispatcher, SequenceDispatcher) { TEST(parallel_run, Basic) { std::vector results(100, -1); auto worker_fn = [&results](int index) { - CHECK_LT(index, results.size()) << "index invalid"; + PADDLE_ENFORCE_LT( + index, results.size(), phi::errors::InvalidArgument("invalid index!")); results[index] = index; }; // check process every index in the extent of [0, 100) with step 1 diff --git a/paddle/cinn/utils/random_engine.h b/paddle/cinn/utils/random_engine.h index 49e8e6ecfd2a2..c0afc2dd36941 100644 --- a/paddle/cinn/utils/random_engine.h +++ b/paddle/cinn/utils/random_engine.h @@ -18,6 +18,7 @@ #include #include +#include "paddle/common/enforce.h" namespace cinn { namespace utils { @@ -69,7 +70,10 @@ class LinearRandomEngine { if (state == 0) { state = 1; } - CHECK_GE(state, 0) << "Random seed must be greater than 0"; + PADDLE_ENFORCE_GE( + state, + 0, + phi::errors::PreconditionNotMet("Random seed must be greater than 0")); return state; } @@ -109,7 +113,10 @@ double SampleUniformDouble(double min, template int SampleDiscreteFromDistribution(const std::vector& weights, LinearRandomEngine::StateType* rand_seed) { - 
CHECK_GT(weights.size(), 0); + PADDLE_ENFORCE_GT( + weights.size(), + 0, + phi::errors::PreconditionNotMet("Size of target weights is empty.")); LinearRandomEngine engine(rand_seed); std::discrete_distribution dist(weights.begin(), weights.end()); return dist(engine); diff --git a/paddle/cinn/utils/sized_multi_set.h b/paddle/cinn/utils/sized_multi_set.h index d36fb7a01920b..96e32ab32f58c 100644 --- a/paddle/cinn/utils/sized_multi_set.h +++ b/paddle/cinn/utils/sized_multi_set.h @@ -19,6 +19,7 @@ #include #include #include +#include "paddle/common/enforce.h" namespace cinn { namespace utils { @@ -55,7 +56,10 @@ class SizedMultiSet { } void Pop() { - CHECK_GE(multi_set_.size(), 1UL) << "Call Pop on empty SizedMultiSet"; + PADDLE_ENFORCE_GE( + multi_set_.size(), + 1UL, + phi::errors::PreconditionNotMet("Call Pop on empty SizedMultiSet.")); if (pop_max_when_full_) { multi_set_.erase(--multi_set_.end()); } else { diff --git a/paddle/cinn/utils/string.cc b/paddle/cinn/utils/string.cc index 5e6560551c068..51813f2fcaf48 100644 --- a/paddle/cinn/utils/string.cc +++ b/paddle/cinn/utils/string.cc @@ -20,6 +20,7 @@ #include #include "glog/logging.h" +#include "paddle/common/enforce.h" namespace cinn { namespace utils { @@ -174,7 +175,8 @@ std::string Attribute2String(const utils::Attribute &attr) { } ss << "[" + cinn::utils::Join(attrs, ", ") + "]"; } else { - LOG(FATAL) << "Unkown attribute data type! Please check."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Unkown attribute data type! Please check.")); } return ss.str(); } diff --git a/paddle/common/array.h b/paddle/common/array.h index d389b4d2288ca..0c90f6ae9f985 100644 --- a/paddle/common/array.h +++ b/paddle/common/array.h @@ -109,7 +109,7 @@ class Array { static T obj{}; return obj; #else - COMMON_THROW(common::errors::Unavailable("Array has no element.")); + PADDLE_THROW(common::errors::Unavailable("Array has no element.")); #endif } @@ -120,7 +120,7 @@ class Array { static const T obj{}; return obj; #else - COMMON_THROW(common::errors::Unavailable("Array has no element.")); + PADDLE_THROW(common::errors::Unavailable("Array has no element.")); #endif } diff --git a/paddle/common/enforce.cc b/paddle/common/enforce.cc index c2ef8308e8cd9..6dd4f0372e2b3 100644 --- a/paddle/common/enforce.cc +++ b/paddle/common/enforce.cc @@ -14,6 +14,7 @@ limitations under the License. 
*/ #include "paddle/common/enforce.h" #include +#include #include #include #include @@ -48,21 +49,31 @@ std::string SimplifyDemangleStr(std::string str) { } return str; } + +std::atomic_bool paddle_fatal_skip{false}; + } // namespace namespace common { namespace enforce { -TEST_API int GetCallStackLevel() { return FLAGS_call_stack_level; } +void SkipPaddleFatal(bool skip) { paddle_fatal_skip.store(skip); } +bool IsPaddleFatalSkip() { return paddle_fatal_skip.load(); } -TEST_API std::string SimplifyErrorTypeFormat(const std::string& str) { +int GetCallStackLevel() { return FLAGS_call_stack_level; } + +std::string SimplifyErrorTypeFormat(const std::string& str) { std::ostringstream sout; size_t type_end_pos = str.find(':', 0); - if (type_end_pos == std::string::npos) { - sout << str; - } else { - // Remove "Error:", add "()"" + if (type_end_pos != str.npos && type_end_pos >= 5 && + str.substr(type_end_pos - 5, 6) == "Error:") { + // Remove "Error:", add "()" + // Examples: + // InvalidArgumentError: xxx -> (InvalidArgument) xxx sout << "(" << str.substr(0, type_end_pos - 5) << ")" << str.substr(type_end_pos + 1); + } else { + // type_end_pos == std::string::npos + sout << str; } return sout.str(); } diff --git a/paddle/common/enforce.h b/paddle/common/enforce.h index 856cf28d0221a..6076e9089df83 100644 --- a/paddle/common/enforce.h +++ b/paddle/common/enforce.h @@ -55,18 +55,25 @@ inline std::string demangle(std::string name) { inline std::string demangle(std::string name) { return name; } #endif -class CommonNotMetException : public std::exception { - public: - explicit CommonNotMetException(const std::string& str) : err_str_(str) {} +namespace enforce { - const char* what() const noexcept override { return err_str_.c_str(); } +TEST_API void SkipPaddleFatal(bool skip = true); +TEST_API bool IsPaddleFatalSkip(); + +namespace details { + +class PaddleFatalGuard { + public: + PaddleFatalGuard() : skip_paddle_fatal_(IsPaddleFatalSkip()) { + if (!skip_paddle_fatal_) SkipPaddleFatal(true); + } + ~PaddleFatalGuard() { + if (!skip_paddle_fatal_) SkipPaddleFatal(false); + } private: - std::string err_str_; + bool skip_paddle_fatal_; }; - -namespace enforce { -namespace details { template struct CanToString { private: @@ -204,6 +211,8 @@ struct EnforceNotMet : public std::exception { // Simple error message used when no C++ stack and python compile stack // e.g. (InvalidArgument) *** std::string simple_err_str_; + + details::PaddleFatalGuard paddle_fatal_guard_; }; /** HELPER MACROS AND FUNCTIONS **/ #ifndef PADDLE_MAY_THROW @@ -255,17 +264,22 @@ template using CommonType2 = typename std::add_lvalue_reference< typename std::add_const::Type2>::type>::type; -#define COMMON_THROW(...) \ - do { \ - HANDLE_THE_ERROR \ - throw common::CommonNotMetException( \ - paddle::string::Sprintf("Error occurred at: %s:%d :\n%s", \ - __FILE__, \ - __LINE__, \ - paddle::string::Sprintf(__VA_ARGS__))); \ - END_HANDLE_THE_ERROR \ +#define PADDLE_THROW(...) \ + do { \ + HANDLE_THE_ERROR \ + throw ::common::enforce::EnforceNotMet( \ + ::common::ErrorSummary(__VA_ARGS__), __FILE__, __LINE__); \ + END_HANDLE_THE_ERROR \ } while (0) +#define PADDLE_FATAL(...) \ + if (!::common::enforce::IsPaddleFatalSkip()) { \ + auto info = ::common::enforce::EnforceNotMet( \ + paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \ + std::cerr << info.what() << std::endl; \ + std::abort(); \ + } + #define __PADDLE_BINARY_COMPARE(__VAL1, __VAL2, __CMP, __INV_CMP, ...) 
\ do { \ auto __val1 = (__VAL1); \ @@ -357,6 +371,7 @@ class IrNotMetException : public std::exception { private: std::string err_str_; + ::common::enforce::details::PaddleFatalGuard paddle_fatal_guard_; }; #define IR_THROW(...) \ diff --git a/paddle/common/errors.cc b/paddle/common/errors.cc index c0541edb7a0c3..05f5c4e9d3703 100644 --- a/paddle/common/errors.cc +++ b/paddle/common/errors.cc @@ -21,49 +21,34 @@ std::string error_name(ErrorCode code) { switch (code) { case ErrorCode::LEGACY: return "Error"; - break; case ErrorCode::INVALID_ARGUMENT: return "InvalidArgumentError"; - break; case ErrorCode::NOT_FOUND: return "NotFoundError"; - break; case ErrorCode::OUT_OF_RANGE: return "OutOfRangeError"; - break; case ErrorCode::ALREADY_EXISTS: return "AlreadyExistsError"; - break; case ErrorCode::RESOURCE_EXHAUSTED: return "ResourceExhaustedError"; - break; case ErrorCode::PRECONDITION_NOT_MET: return "PreconditionNotMetError"; - break; case ErrorCode::PERMISSION_DENIED: return "PermissionDeniedError"; - break; case ErrorCode::EXECUTION_TIMEOUT: return "ExecutionTimeoutError"; - break; case ErrorCode::UNIMPLEMENTED: return "UnimplementedError"; - break; case ErrorCode::UNAVAILABLE: return "UnavailableError"; - break; case ErrorCode::FATAL: return "FatalError"; - break; case ErrorCode::EXTERNAL: return "ExternalError"; - break; case ErrorCode::INVALID_TYPE: return "InvalidTypeError"; - break; default: throw std::invalid_argument("The error type is undefined."); - break; } } diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc index e09c7c0e8316e..35237b3a2f51f 100644 --- a/paddle/common/flags.cc +++ b/paddle/common/flags.cc @@ -629,6 +629,10 @@ PHI_DEFINE_EXPORTED_uint64( "The real chunk size is max(request_size, " "FLAGS_auto_growth_chunk_size_in_mb)."); +PHI_DEFINE_EXPORTED_bool(custom_device_mem_record, + false, + "Enable mem record event on custom device"); + #endif /** @@ -1345,6 +1349,19 @@ PHI_DEFINE_EXPORTED_bool(use_shm_cache, false, "Use shm cache in mmap_allocator."); +/** + * mmap_allocator related FLAG + * Name: dataloader_use_file_descriptor + * Since Version: 2.6.2 + * Value Range: bool, default=true + * Example: + * Note: . If True, mmap_allocator will use file descripor to open shared memory + * operation. + */ +PHI_DEFINE_EXPORTED_bool(dataloader_use_file_descriptor, + true, + "Use file descriptor in mmap_allocator."); + /** * Tensor operants related FLAG * Name: tensor_operants_mode @@ -1470,6 +1487,14 @@ PHI_DEFINE_EXPORTED_bool(prim_check_ops, "Whether to check the decomposed program, to ensure " "that only the primitive operator is present."); +// PIR and prim related FLAG +// Example: FLAGS_prim_forward_blacklist="pd_op.relu;pd_op.mean" would block +// `relu` and `mean` two ops in decompsition. +PHI_DEFINE_EXPORTED_string( + prim_forward_blacklist, + "", + "It controls the forward blacklist ops not to be decomposed."); + #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ defined(PADDLE_WITH_XPU_BKCL) /** diff --git a/paddle/common/flags.h b/paddle/common/flags.h index b9ca1a52c4c63..006f2fea5355d 100644 --- a/paddle/common/flags.h +++ b/paddle/common/flags.h @@ -122,19 +122,6 @@ PADDLE_API void ParseCommandLineFlags(int* argc, char*** argv); */ PADDLE_API void AllowUndefinedFlags(); -/** - * @brief Set flags from environment variables. - * - * It recieves a list of flags name, and will find the corresponding environment - * variables named "FLAGS_name", if found, it will set the environment variable - * values to the flags. 
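A condensed sketch of the call-site conversion this patch applies throughout CINN, using the PADDLE_ENFORCE_* / PADDLE_THROW macros from paddle/common/enforce.h shown above. Illustrative only: the function and messages below are made up, and phi::errors is assumed to be visible as it is in the files above.

#include "paddle/common/enforce.h"

// Before this patch: CHECK_EQ(x, y) << "mismatch";  LOG(FATAL) << "bad state";
void IllustrateEnforce(int x, int y) {
  PADDLE_ENFORCE_EQ(
      x,
      y,
      phi::errors::InvalidArgument("Expected x == y, but got %d vs %d.", x, y));
  if (x < 0) {
    // Throws common::enforce::EnforceNotMet instead of aborting the process.
    PADDLE_THROW(phi::errors::Fatal("x must be non-negative."));
  }
}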
If error_fatal is true, the program will exit when the - * environment variable is not set or the flag is not defined, that is the same - * effect as using commandline argument "--fromenv=var_name1,var_name2,...". - * Otherwise, the errors above will be ignored, that is the same effect as using - * commandline argument "--tryfromenv=var_name1,var_name2,...". - */ -void SetFlagsFromEnv(const std::vector& flags, bool error_fatal); - /** * @brief Set Single flag value, return true if success. */ diff --git a/paddle/common/flags_native.cc b/paddle/common/flags_native.cc index 8229c6b0f0b1d..706419721d96f 100644 --- a/paddle/common/flags_native.cc +++ b/paddle/common/flags_native.cc @@ -362,6 +362,18 @@ bool GetValueFromEnv(const std::string& name, std::string* value) { return true; } +/** + * @brief Set flags from environment variables. + * + * It recieves a list of flags name, and will find the corresponding environment + * variables named "FLAGS_name", if found, it will set the environment variable + * values to the flags. If error_fatal is true, the program will exit when the + * environment variable is not set or the flag is not defined, that is the same + * effect as using commandline argument "--fromenv=var_name1,var_name2,...". + * Otherwise, the errors above will be ignored, that is the same effect as using + * commandline argument "--tryfromenv=var_name1,var_name2,...". + */ + void SetFlagsFromEnv(const std::vector& flags, bool error_fatal) { bool success = true; for (const std::string& flag_name : flags) { diff --git a/paddle/extension.h b/paddle/extension.h index 3c79adcde5d69..5c309a20b0065 100644 --- a/paddle/extension.h +++ b/paddle/extension.h @@ -14,12 +14,37 @@ limitations under the License. */ #pragma once +#if defined(__clang__) || defined(__GNUC__) +#define CPP_STANDARD __cplusplus +#elif defined(_MSC_VER) +#define CPP_STANDARD _MSVC_LANG +#endif + #ifndef CUSTOM_OP_WITH_SPMD #define CUSTOM_OP_WITH_SPMD #endif // All paddle apis in C++ frontend +// phi headers #include "paddle/phi/api/all.h" +// common headers +#include "paddle/common/ddim.h" +#include "paddle/common/exception.h" +#include "paddle/common/layout.h" + +#if CPP_STANDARD >= 201703L && !defined(__clang__) +// pir&pass headers +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/utils/general_functions.h" +#include "paddle/pir/include/core/operation.h" +#include "paddle/pir/include/core/type.h" +#include "paddle/pir/include/core/value.h" +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_manager.h" +#include "paddle/pir/include/pass/pass_registry.h" +#include "paddle/pir/include/pattern_rewrite/pattern_match.h" +#endif + #if !defined(PADDLE_ON_INFERENCE) && !defined(PADDLE_NO_PYTHON) // Python bindings for the C++ frontend (includes Python.h) #include "paddle/utils/pybind.h" diff --git a/paddle/fluid/distributed/auto_parallel/CMakeLists.txt b/paddle/fluid/distributed/auto_parallel/CMakeLists.txt index d1eae7f599549..0fd2d6e884d1e 100644 --- a/paddle/fluid/distributed/auto_parallel/CMakeLists.txt +++ b/paddle/fluid/distributed/auto_parallel/CMakeLists.txt @@ -5,4 +5,4 @@ cc_library( SRCS dist_attr.cc DEPS phi common auto_parallel_proto proto_desc) -cc_library(auto_parallel DEPS op_dist_attr spmd_rules) +cc_library(auto_parallel DEPS op_dist_attr dist_tensor_spec) diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt b/paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt index f16c155890579..38aecc5b39b3b 100644 
--- a/paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/CMakeLists.txt @@ -1,6 +1,6 @@ -file(GLOB spmd_srcs *.cc) +file(GLOB dist_tensor_spec_srcs *.cc) cc_library( - spmd_rules - SRCS ${spmd_srcs} + dist_tensor_spec + SRCS ${dist_tensor_spec_srcs} DEPS phi common) diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc deleted file mode 100644 index d38de8d90e2e4..0000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.cc +++ /dev/null @@ -1,297 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" - -#include - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h" -#include "paddle/phi/core/distributed/auto_parallel/utils.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { - -using phi::distributed::auto_parallel::str_join; - -std::pair, std::vector> -SPMDRuleBase::InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) { - PADDLE_THROW( - phi::errors::Unimplemented("InferForward should be called from a " - "derived class of SPMDRuleBase !")); -} - -std::pair, std::vector> -SPMDRuleBase::InferBackward(const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) { - PADDLE_THROW( - phi::errors::Unimplemented("InferBackward should be called from a " - "derived class of SPMDRuleBase !")); -} - -// deprecated -std::pair, std::vector> -SPMDRuleBase::InferBackward(const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) { - PADDLE_THROW( - phi::errors::Unimplemented("InferBackward should be called from a " - "derived class of SPMDRuleBase !")); -} - -std::unordered_map ShardingMergeForTensors( - const std::vector>>& - tensor_axes_to_dim_pairs, - const bool merge_conflicts) { - std::unordered_map axis_to_dim_map; - std::unordered_map dim_to_axis_map; - int64_t merge_dim = 0; - - for (auto& pair : tensor_axes_to_dim_pairs) { - for (size_t i = 0; i < pair.second.size(); ++i) { - auto tensor_axis = pair.first.substr(i, 1); - auto mesh_dim = pair.second[i]; - - if (axis_to_dim_map.count(tensor_axis) == 0) { - merge_dim = mesh_dim; - } else { - merge_dim = ShardingMergeForAxis( - tensor_axis, mesh_dim, axis_to_dim_map[tensor_axis]); - } - axis_to_dim_map[tensor_axis] = merge_dim; - if (merge_dim != -1) { - if (dim_to_axis_map.count(merge_dim) == 0) { - dim_to_axis_map.insert({merge_dim, tensor_axis}); - } else if (dim_to_axis_map[merge_dim].find(tensor_axis) == - std::string::npos) { - dim_to_axis_map[merge_dim] += tensor_axis; - } - } - } - } - - // Resolute "mesh_dim shard by more than one axis" conflict. - // Now we just naive pick the first axis naively. 
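The per-axis merge performed by the sharding-merge helpers in this deleted rule base reduces to a few comparisons; a standalone restatement for reference, illustrative and outside the Paddle tree:

#include <cstdint>
#include <stdexcept>

// -1 means "replicated". Rule 1: a replicated mapping yields to any sharded
// mapping. Rule 2: a tensor axis may be sharded by at most one mesh dimension.
int64_t MergeAxisMapping(int64_t mesh_dim1, int64_t mesh_dim2) {
  if (mesh_dim1 == mesh_dim2) return mesh_dim1;
  if (mesh_dim1 == -1) return mesh_dim2;
  if (mesh_dim2 == -1) return mesh_dim1;
  throw std::runtime_error("tensor axis sharded by two different mesh dims");
}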
- // (TODO) use local cost model to pick the axis with lowest cost(in concern of - // memory or communication or computation). - for (auto& it : dim_to_axis_map) { - if (it.second.size() > 1) { - if (merge_conflicts) { - VLOG(4) << "Sharding Conflict: Mesh_Dim [" << it.first - << "] are Sharding Multiple Tensor Axis: [" << it.second - << "]. The Axis: [" << it.second[0] << "] is Picked."; - for (size_t i = 1; i < it.second.size(); ++i) { - axis_to_dim_map[it.second.substr(i, 1)] = -1; - } - } else { - PADDLE_THROW(phi::errors::PreconditionNotMet( - "Multiple Tensor Axes [%s] is sharded by same mesh dimension [%d].", - str_join(it.second), - it.first)); - } - } - } - - return axis_to_dim_map; -} - -// Rule1: A replicated dimension could be merged by any sharded dimension. -// Rule2: A tensor axis could at most be sharded by one mesh dimension. -// (TODO trigger heuristics cost model and reshard to handle axis sharded by -// multiple dimension case.) -int64_t ShardingMergeForAxis(const std::string& axis, - const int64_t& mesh_dim1, - const int64_t& mesh_dim2) { - if (mesh_dim1 != mesh_dim2) { - if (mesh_dim1 == -1) { - return mesh_dim2; - } else if (mesh_dim2 == -1) { - return mesh_dim1; - } else { - // (TODO) local cost model here. - PADDLE_THROW( - phi::errors::Unimplemented("Tensor Axis[%s] is Sharded by two " - "different mesh dimension [%d] and [%d].", - axis, - mesh_dim1, - mesh_dim2)); - } - - } else { - return mesh_dim1; - } -} - -TensorDistAttr CopyTensorDistAttrForOutput( - const TensorDistAttr& src_dist_attr) { - TensorDistAttr new_dist_attr = TensorDistAttr(); - new_dist_attr.set_process_mesh(src_dist_attr.process_mesh()); - new_dist_attr.set_batch_dim(src_dist_attr.batch_dim()); - new_dist_attr.set_dynamic_dims(src_dist_attr.dynamic_dims()); - // new_dist_attr.set_annotated(false); TODO unset field is false by default. 
- return new_dist_attr; -} - -std::vector ResoluteOutputPartialDimension( - const std::unordered_map& axis_to_dim_map, - const std::string& tensor_axes) { - std::vector partial_on_dims; - - for (auto& it : axis_to_dim_map) { - if (tensor_axes.find(it.first) == std::string::npos) { - if (it.second > -1) { - partial_on_dims.push_back(it.second); - } - } - } - return partial_on_dims; -} - -std::string GetBroadcastAxes(const int64_t& tensor_ndim, - const int64_t& broadcast_ndim, - const std::string& alphabet) { - PADDLE_ENFORCE_GE( - alphabet.size(), - broadcast_ndim, - phi::errors::InvalidArgument( - "size of alphabet [%d] is less than broadcast ndim [%d]", - alphabet.size(), - broadcast_ndim)); - PADDLE_ENFORCE_GE(broadcast_ndim, - tensor_ndim, - phi::errors::InvalidArgument( - "broadcast ndim [%d] is less than tensor ndim [%d]", - broadcast_ndim, - tensor_ndim)); - if (tensor_ndim <= 0) { - return std::string(); - } - return alphabet.substr(broadcast_ndim - tensor_ndim, tensor_ndim); -} - -TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr) { - TensorDistAttr replicated_dist_attr = src_dist_attr; - replicated_dist_attr.clear_annotated(); - size_t tensor_ndim = replicated_dist_attr.dims_mapping().size(); - replicated_dist_attr.set_dims_mapping(std::vector(tensor_ndim, -1)); - return replicated_dist_attr; -} - -void VerifySpecs(const std::vector& specs, - const std::string& op_name) { - for (size_t i = 0, n = specs.size(); i < n; ++i) { - const std::vector& shape = specs[i].shape(); - const std::vector& dims_mapping = specs[i].dims_mapping(); - PADDLE_ENFORCE_EQ(shape.size(), - dims_mapping.size(), - phi::errors::InvalidArgument( - "Mismatch in %s, spec[%d]'s tensor size: [%d] and " - "spec[%d]'s dims_mapping size [%d].", - op_name, - i, - shape.size(), - i, - dims_mapping.size())); - } -} - -std::vector>> -GetAxesDimsMappingPair(const std::vector& tensor_axes, - const std::vector& specs) { - std::vector>> res; - size_t ntensor = specs.size(); - for (size_t i = 0; i < ntensor; ++i) { - res.emplace_back(tensor_axes[i], specs[i].dims_mapping()); - } - return res; -} - -std::vector GetDimsMappingForAxes( - const std::string& axes, - const std::unordered_map& axis_to_dim_map, - const bool unsharded_miss_axis) { - std::vector dims_mapping; - for (int64_t i = 0, n = static_cast(axes.size()); i < n; i++) { - std::string axis = axes.substr(i, 1); - if (axis == "1") { - dims_mapping.emplace_back(-1); - } else { - auto iter = axis_to_dim_map.find(axis); - if (iter == axis_to_dim_map.end()) { - if (unsharded_miss_axis) { - dims_mapping.emplace_back(-1); - } else { - phi::errors::InvalidArgument( - "Tensor axis [%s] of not in axis_to_dim_map.", axis); - } - } else { - dims_mapping.emplace_back(iter->second); - } - } - } - return dims_mapping; -} - -// SPMDRuleMap -SPMDRuleMap& SPMDRuleMap::Instance() { - static SPMDRuleMap g_spmd_rule_map; - return g_spmd_rule_map; -} - -// To enable default replicated spmd rule for op that are NOT registered -// which all tensors of inputs and outputs will be replicated in all ranks of -// the mesh. 
-SPMDRuleBase* SPMDRuleMap::Get(const std::string& op_type) const { - auto rule_ptr = GetNullable(op_type); - if (rule_ptr == nullptr) { - std::string str; - for (const auto& item : map_) { - str += item.first + ", "; - } - VLOG(4) << "Size of current map [" << map_.size() << "]"; - VLOG(4) << "Keys are [" << str << "]"; - } - PADDLE_ENFORCE_NOT_NULL( - rule_ptr, - platform::errors::NotFound( - "NO SPMD Rule has been registered for Operator [%s].", op_type)); - return rule_ptr; -} - -SPMDRuleBase* SPMDRuleMap::GetNullable(const std::string& op_type) const { - auto it = map_.find(op_type); - if (it == map_.end()) { - return nullptr; - } else { - return it->second.get(); - } -} - -int SPMDRuleMap::Insert(const std::string& op_type, - std::unique_ptr rule) { - VLOG(4) << "Call SPMDRuleMap::Insert!"; - PADDLE_ENFORCE_NE( - Has(op_type), - true, - platform::errors::AlreadyExists( - "SPMD Rule for Operator [%s] has been registered.", op_type)); - map_.insert({op_type, std::move(rule)}); - - return 1; -} - -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h deleted file mode 100644 index 9f6a52750580b..0000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/common.h +++ /dev/null @@ -1,191 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include -#include - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" -#include "paddle/fluid/framework/attribute.h" -#include "paddle/fluid/framework/type_defs.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" -#include "paddle/utils/flat_hash_map.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { - -using paddle::framework::Attribute; - -class SPMDRuleBase { - public: - virtual ~SPMDRuleBase() {} - - // Based on the information of Input Tensors and Op Attribute: - // 1. Merge the Sharding (dims_mapping) among Input Tensors. - // 2. Infer the Sharding (dims_mapping) for Output Tensors. - // The Info of input tensors (Shape and DistAttr) are wrapped as - // DistTensorSpec, and op attribute should be given as AttributeMap. The - // Output is a pair consist of two vectors: - // 1. The first vector: the merged DistAttr of input tensors. - // 2. The inferred DistAttr of output tensors. - // The Merged DistAttr might be different from the original Intput DistAttrs, - // which means that the corresponding input tensor need to be reshard. - virtual std::pair, std::vector> - InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs); - - // Based on the information of Input & Output Tensors and Op Attribute: - // 1. Merge the Sharding (dims_mapping) among Output Tensors. - // 2. Infer the Sharding (dims_mapping) for Input Tensors. 
- // The Info of output tensors (Shape and DistAttr) are wrapped as - // DistTensorSpec, and op attribute should be given as AttributeMap. The - // Output is a pair consist of two vectors: - // 1. The first vector: the merged DistAttr of output tensors. - // 2. The inferred DistAttr of Input tensors. - virtual std::pair, std::vector> - InferBackward(const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs); - - // deprecated, to be remove in future - virtual std::pair, std::vector> - InferBackward(const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs); - - template - inline const T ExtractAttr( - const std::string& name, - const paddle::framework::AttributeMap& attrs) const { - auto attr = GetAttr(name, attrs); - return *paddle::framework::ExtractAttribute(name)(attr); - } - - Attribute GetAttr(const std::string& name, - const paddle::framework::AttributeMap& attrs) const { - auto iter = attrs.find(name); - PADDLE_ENFORCE_NE(iter, - attrs.end(), - paddle::platform::errors::NotFound( - "(%s) is not found in AttributeMap.", name)); - return iter->second; - } -}; - -// Merge sharding specification (dims mapping) of given tensors. -// The same axes of different tensors will be merged. -std::unordered_map ShardingMergeForTensors( - const std::vector>>& - tensor_axes_to_dim_pairs, - const bool merge_conflicts = true); - -// Merge the sharding specification (dims mapping) for one tensor Axis. -// Rule1: A replicated dimension could be merged by any sharded dimension. -// Rule2: A tensor axis could at most be sharded by one mesh dimension. -// (TODO trigger heuristics cost model and reshard to handle axis sharded by -// multiple dimension case.) -int64_t ShardingMergeForAxis(const std::string& axis, - const int64_t& mesh_dim1, - const int64_t& mesh_dim2); - -// Intend to use for generating the TensorDistAttr of output based on the input -// activation TensorDistAttr. The process_mesh, batch_dim, dynamic_dim are -// copied with annotated is forced to False, and dims_mapping is leave to be -// null. -TensorDistAttr CopyTensorDistAttrForOutput(const TensorDistAttr& src_dist_attr); - -// Resolute the partial mesh dimension of a output tensor, giving the -// merged sharding specification of input tensors and the axis names of output -// tensor. Input are -std::vector ResoluteOutputPartialDimension( - const std::unordered_map& axis_to_dim_map, - const std::string& tensor_axes); - -// Generate the axis notation of tensor for the einsum notation of a broadcast -// operation(alignment star from the rightmost axis). tensor_ndim: the size of -// the tensor. broadcast_ndim: the maximum size of tensors in this broadcast -// operation. alphabet: the characters used to represent the axes of tensor. -// length of alphabet should >= broadcast_ndim. -std::string GetBroadcastAxes(const int64_t& tensor_ndim, - const int64_t& broadcast_ndim, - const std::string& alphabet); - -// Return a NEW TensorDistAttr whose dims mapping is consist of "-1" -// (unsharded). -TensorDistAttr ReplicatedOnMesh(const TensorDistAttr& src_dist_attr); - -// Check whether the given DistTensorSpec objects are valid. For each -// DistTensorSpec, the rank of its dims mapping must be equal to the rank of its -// corresponding tensor shape. the parameter op_name is used for logging error -// message. -void VerifySpecs(const std::vector& specs, - const std::string& op_name); - -// Get dims mapping for the given tensors. 
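The broadcast-axes helper documented just above maps a tensor's dimensions onto the rightmost letters of the einsum alphabet; a self-contained restatement of the deleted implementation, for illustration only:

#include <cassert>
#include <cstdint>
#include <string>

std::string BroadcastAxes(int64_t tensor_ndim,
                          int64_t broadcast_ndim,
                          const std::string& alphabet) {
  assert(static_cast<int64_t>(alphabet.size()) >= broadcast_ndim);
  assert(broadcast_ndim >= tensor_ndim);
  if (tensor_ndim <= 0) return std::string();
  // Align from the rightmost axis: a rank-2 tensor in a rank-4 broadcast
  // with alphabet "ijkl" is labeled "kl".
  return alphabet.substr(broadcast_ndim - tensor_ndim, tensor_ndim);
}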
Return the pair of each -// tensor's einsum notation and the corresponding dims mapping. -std::vector>> -GetAxesDimsMappingPair(const std::vector& tensor_axes, - const std::vector& specs); - -// Get dims mapping for the given axes according to sharding information of -// the annotated axes after inferring forward or backward. The parameter axis -// stores the axes of the tensor. "1" is a special axis, for the axis "1", set -// its dims mapping to -1. -// if unsharded_miss_axis, "-1" is assigned to axes that has no key in -// axis_to_dim_map. -std::vector GetDimsMappingForAxes( - const std::string& axes, - const std::unordered_map& axis_to_dim_map, - const bool unsharded_miss_axis = false); - -// The static map that stores and initializes all the registered SPMD rules. -class SPMDRuleMap { - public: - ~SPMDRuleMap() = default; - - // A singleton - static SPMDRuleMap& Instance(); - - // Returns the spmd rule for the given op_type - SPMDRuleBase* Get(const std::string& op_type) const; - - // Returns the spmd by name or nullptr if not registered - SPMDRuleBase* GetNullable(const std::string& op_type) const; - - // Register a spmd for an op_type. - int Insert(const std::string& op_type, std::unique_ptr rule); - - bool Has(const std::string& op_type) const { - return map_.find(op_type) != map_.end(); - } - - private: - SPMDRuleMap() = default; - paddle::flat_hash_map> map_; - DISABLE_COPY_AND_ASSIGN(SPMDRuleMap); -}; - -#define REGISTER_SPMD_RULE(op_type, rule_class, ...) \ - UNUSED static int __spmd_rule_holder_##op_type = \ - ::paddle::distributed::auto_parallel::SPMDRuleMap::Instance().Insert( \ - #op_type, std::make_unique(__VA_ARGS__)) - -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h deleted file mode 100644 index 70d603e509c43..0000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/matmul_spmd_rule.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { - -TensorDistAttr GetInferedDistAttr( - const TensorDistAttr& origin_dist_attr, - const std::vector& shape, - const std::string& tensor_axes, - const std::unordered_map& axis_to_dim_map, - const bool trans_axis); - -void FillMatmulOperandNotation(const int x_ndim, - const int y_ndim, - std::string* x_axes, - std::string* y_axes, - std::string* out_axes); - -class MatmulSPMDRule : public SPMDRuleBase { - public: - std::pair, std::vector> - InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) override; - - std::pair, std::vector> - InferBackward(const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) override; -}; -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.cc deleted file mode 100644 index 5227a82a4b8b5..0000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.cc +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { - -std::pair, std::vector> -ReplicatedSPMDRule::InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) { - std::vector intput_dist_attrs; - std::vector output_dist_attrs; - intput_dist_attrs.reserve(input_specs.size()); - - for (auto& input_spec : input_specs) { - intput_dist_attrs.push_back(ReplicatedOnMesh(input_spec.dist_attr())); - } - - // TODO(ljz): we need to know num of output and size of each output before - // generate the exact replicated dist tensor attr for the current op. - // here we just assume that only one output tensor and has the same size as - // the first input tensor. - return {intput_dist_attrs, {ReplicatedOnMesh(input_specs[0].dist_attr())}}; -} - -std::pair, std::vector> -ReplicatedSPMDRule::InferBackward( - const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) { - PADDLE_THROW(phi::errors::Unimplemented( - "InferBackward of ReplicatedSPMDRule is NOT implemented yet.")); -} - -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h deleted file mode 100644 index bcca646d351d5..0000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { - -// A Bottom Line Rule that enforces input(s) and output(s) of the Op to be -// replicated among the given mesh. -class ReplicatedSPMDRule : public SPMDRuleBase { - public: - // The dims_mapping of ALL TensorDistAttrs would be repeat of "-1" - // (unsharded). - std::pair, std::vector> - InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) override; - - // The dims_mapping of ALL TensorDistAttrs would be repeat of "-1" - // (unsharded). - std::pair, std::vector> - InferBackward(const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) override; -}; -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt b/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt deleted file mode 100644 index 449ee65ccc751..0000000000000 --- a/paddle/fluid/distributed/auto_parallel/test/CMakeLists.txt +++ /dev/null @@ -1,9 +0,0 @@ -paddle_test(device_mesh_test SRCS device_mesh_test.cc) - -paddle_test(process_mesh_test SRCS process_mesh_test.cc) - -paddle_test(dist_attr_test SRCS dist_attr_test.cc) - -paddle_test(dist_mapper_test SRCS dist_mapper_test.cc) - -paddle_test(spmd_rule_test SRCS spmd_rule_test.cc) diff --git a/paddle/fluid/distributed/collective/mpi_tools.h b/paddle/fluid/distributed/collective/mpi_tools.h index 7f86409c036eb..be2838ffffa83 100644 --- a/paddle/fluid/distributed/collective/mpi_tools.h +++ b/paddle/fluid/distributed/collective/mpi_tools.h @@ -32,14 +32,16 @@ namespace paddle { namespace distributed { namespace mpi { -#define MPI_CHECK(cmd) \ - do { \ - int r = cmd; \ - if (r != MPI_SUCCESS) { \ - LOG(FATAL) << "Failed, MPI error in" << __FILE__ << ":" << __LINE__ \ - << "with error code: " << std::to_string(r) << std::endl; \ - exit(EXIT_FAILURE); \ - } \ +#define MPI_CHECK(cmd) \ + do { \ + int r = cmd; \ + if (r != MPI_SUCCESS) { \ + std::stringstream ss; \ + ss << "Failed, MPI error in" << __FILE__ << ":" << __LINE__ \ + << "with error code: " << std::to_string(r) << std::endl; \ + PADDLE_THROW(phi::errors::Fatal(ss.str())); \ + exit(EXIT_FAILURE); \ + } \ } while (0) MPI_Op ToMPIType(ReduceOp reduction); diff --git a/paddle/fluid/distributed/collective/process_group_custom.cc b/paddle/fluid/distributed/collective/process_group_custom.cc index 33b2728bdc288..715d4d692ea5a 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.cc +++ b/paddle/fluid/distributed/collective/process_group_custom.cc @@ -161,6 +161,32 @@ phi::ccl::CCLComm ProcessGroupCustom::XCCLComm(const Place& place) const { return iter->second->xccl_comm(); } +std::string ProcessGroupCustom::GetCommName(int rank) { + PADDLE_ENFORCE_GE(rank, + 0, + phi::errors::PreconditionNotMet( + "The rank must greater or equal than 0!")); + auto 
num_devices = phi::DeviceManager::GetDeviceCount(device_type_); + PADDLE_ENFORCE_GT( + num_devices, + 0, + phi::errors::InvalidArgument("The num_devices must greater than 0!")); + + auto place_id = rank % num_devices; + platform::CustomPlace place(device_type_, place_id); + const auto& key = GetKeyFromPlace(place); + phi::DeviceGuard guard(place); + if (place_to_comm_ctx_.find(key) == place_to_comm_ctx_.end()) { + CreateXCCLEnvCache(place, key); + } + + char comm_name[128]; + phi::DeviceManager::CCLCommName( + device_type_, this->GetCommContext()->GetXcclComm(), comm_name); + std::string name_str(comm_name); + return name_str; +} + std::shared_ptr ProcessGroupCustom::AllGather( phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, @@ -236,7 +262,7 @@ std::shared_ptr ProcessGroupCustom::AllToAll( std::vector send_buf, recv_buf; std::vector send_count, recv_count; - std::vector send_dtype, recv_dtype; + std::vector send_dtype, recv_dtype; for (auto i = 0; i < size_; i++) { in_numel = in_size_each_rank[i] * in_row_size; input_partial = GetPartialTensor(tensor_tmp, in_offset, in_numel); @@ -248,8 +274,8 @@ std::shared_ptr ProcessGroupCustom::AllToAll( recv_buf.push_back(output_partial.data()); send_count.push_back(in_numel); recv_count.push_back(out_numel); - send_dtype.push_back(phi::ccl::ToCCLDataType(input_partial.dtype())); - recv_dtype.push_back(phi::ccl::ToCCLDataType(output_partial.dtype())); + send_dtype.push_back(input_partial.dtype()); + recv_dtype.push_back(output_partial.dtype()); } phi::DeviceManager::CCLAllToAll( @@ -992,9 +1018,8 @@ std::shared_ptr ProcessGroupCustom::AllToAll( std::vector send_buf, recv_buf; std::vector send_count(size_, input.numel() / size_), recv_count(size_, input.numel() / size_); - std::vector send_dtype( - size_, phi::ccl::ToCCLDataType(input.dtype())), - recv_dtype(size_, phi::ccl::ToCCLDataType(input.dtype())); + std::vector send_dtype(size_, input.dtype()), + recv_dtype(size_, input.dtype()); for (auto i = 0; i < size_; i++) { send_buf.push_back( GetPointerByOffset(input.data(), offset, input.dtype())); diff --git a/paddle/fluid/distributed/collective/process_group_custom.h b/paddle/fluid/distributed/collective/process_group_custom.h index a3fb060376597..0bb1c402a181e 100644 --- a/paddle/fluid/distributed/collective/process_group_custom.h +++ b/paddle/fluid/distributed/collective/process_group_custom.h @@ -82,6 +82,8 @@ class ProcessGroupCustom final : public ProcessGroupWithStream { std::string GetBackendName() const override { return "XCCL"; } + std::string GetCommName(int rank); + phi::DeviceContext* GetDeviceContext(const Place& place) const override; phi::DeviceContext* GetDeviceContext(const Place& place, diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index 82e95204590bd..d2e75768b95cb 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -123,11 +123,15 @@ ProcessGroupNCCL::ProcessGroupNCCL( int rank, int size, int gid, - int64_t timeout) + int64_t timeout, + int nccl_comm_init_option) : ProcessGroupWithStream(rank, size, gid), store_(store), - pg_timeout_(timeout) { + pg_timeout_(timeout), + nccl_comm_init_option_(nccl_comm_init_option) { LOG(INFO) << "ProcessGroupNCCL pg_timeout_ " << pg_timeout_; + LOG(INFO) << "ProcessGroupNCCL nccl_comm_init_option_ " + << nccl_comm_init_option_; } ProcessGroupNCCL::~ProcessGroupNCCL() { LOG(INFO) << "ProcessGroupNCCL destruct "; 
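A hypothetical call site, not part of this patch, showing how the new nccl_comm_init_option argument is threaded through the NCCL process-group factory; passing 0 keeps the previous default behaviour. The helper name, the Store type, and the caller-supplied rank/world_size are assumptions for illustration.

#include <memory>

#include "paddle/fluid/distributed/collective/process_group_nccl.h"

// Placeholder helper: 'store', 'rank' and 'world_size' come from the caller.
std::shared_ptr<paddle::distributed::ProcessGroupNCCL> MakeDefaultNcclGroup(
    const std::shared_ptr<phi::distributed::Store>& store,
    int rank,
    int world_size) {
  return paddle::distributed::ProcessGroupNCCL::CreateProcessGroupNCCL(
      store,
      rank,
      world_size,
      /*gid=*/0,
      /*timeout=*/30 * 60 * 1000,
      /*nccl_comm_init_option=*/0);
}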
@@ -528,7 +532,9 @@ std::shared_ptr ProcessGroupNCCL::Gather( size_t offset = 0; size_t numel = out_tensor->numel() / size_; for (auto i = 0; i < size_; i++) { - partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel)); + partial_tensors.push_back(GetPartialTensor(*out_tensor, + static_cast(offset), + static_cast(numel))); offset += numel; } } @@ -718,7 +724,7 @@ void ProcessGroupNCCL::CreateNCCLEnvCache(const Place& place, phi::distributed::P2POption p2p_opts({is_p2p_op, p2p_rank, num_ranks, rank}); phi::distributed::CommContextManager::CreateNCCLCommContext( - store_, store_key, rank_, size_, "", &p2p_opts); + store_, store_key, rank_, size_, "", &p2p_opts, nccl_comm_init_option_); NCCL_CHECK(phi::dynload::ncclGroupEnd()); @@ -1009,9 +1015,10 @@ std::shared_ptr ProcessGroupNCCL::CreateProcessGroupNCCL( int rank, int size, int gid, - int64_t timeout) { - auto process_group = - std::make_shared(store, rank, size, gid, timeout); + int64_t timeout, + int nccl_comm_init_option) { + auto process_group = std::make_shared( + store, rank, size, gid, timeout, nccl_comm_init_option); ProcessGroupIdMap::GetInstance().emplace(gid, process_group); return process_group; } diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index 22d90370f16af..a57337f1d47fa 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -76,13 +76,15 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { int rank, int size, int gid, - int64_t timeout); + int64_t timeout, + int nccl_comm_init_option); ProcessGroupNCCL(const std::shared_ptr& store, int rank, int size, int gid, - int64_t timeout = 30 * 60 * 1000); + int64_t timeout = 30 * 60 * 1000, + int nccl_comm_init_option = 0); ~ProcessGroupNCCL(); std::string GetBackendName() const override { return "NCCL"; } @@ -177,6 +179,8 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { ncclComm_t NCCLComm(const Place& place) const; + const bool GetNCCLCommInitOption() { return nccl_comm_init_option_; } + private: std::shared_ptr CreateTask(const Place& place, int rank, @@ -247,6 +251,7 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream { static uint64_t s_group_call_counter; // default 30 minutes int64_t pg_timeout_; + int nccl_comm_init_option_; // optimize memory for process_group std::vector, gpuStream_t>> diff --git a/paddle/fluid/distributed/collective/reducer.cc b/paddle/fluid/distributed/collective/reducer.cc index 68ccd8f52fa10..a49dc15199d8b 100644 --- a/paddle/fluid/distributed/collective/reducer.cc +++ b/paddle/fluid/distributed/collective/reducer.cc @@ -334,6 +334,11 @@ void ConcatTensorsWithType( platform::float16>()( context, dense_tensors_, p_dense_contents); break; + case phi::DataType::BFLOAT16: + ConcatTensorsForAllReduce()( + context, dense_tensors_, p_dense_contents); + break; default: PADDLE_THROW(platform::errors::Unimplemented( "Data type (%s) is not supported when it concats tensors for " @@ -358,6 +363,11 @@ void SplitTensorsWithType( SplitTensorsForAllReduce()( context, p_dense_contents, p_dense_tensors); break; + case phi::DataType::BFLOAT16: + SplitTensorsForAllReduce()( + context, p_dense_contents, p_dense_tensors); + break; default: PADDLE_THROW(platform::errors::Unimplemented( "Data type (%s) is not supported when it splits tensors for " @@ -831,23 +841,33 @@ void EagerReducer::MarkVarReady(const size_t var_index, auto &group_tensor = 
group.dense_tensors_[inside_group_index]; const auto length = group.length_[inside_group_index]; if (is_used_var) { - auto *autograd_meta = tensors_[var_index].get_autograd_meta(); - paddle::Tensor grad_tensor = - static_cast(autograd_meta)->Grad(); - if (grad_tensor.is_dense_tensor()) { - const auto &tensor_impl = grad_tensor.impl(); - auto dense_tensor = - std::dynamic_pointer_cast(tensor_impl); - if (!dense_tensor->meta().is_contiguous()) { - grad_tensor.set_impl(std::make_shared(std::move( - paddle::experimental::Trans2Contiguous(*dense_tensor)))); + if (HasGrad(var_index)) { + auto *autograd_meta = tensors_[var_index].get_autograd_meta(); + paddle::Tensor grad_tensor = + static_cast(autograd_meta)->Grad(); + if (grad_tensor.is_dense_tensor()) { + const auto &tensor_impl = grad_tensor.impl(); + auto dense_tensor = + std::dynamic_pointer_cast(tensor_impl); + if (!dense_tensor->meta().is_contiguous()) { + grad_tensor.set_impl(std::make_shared( + paddle::experimental::Trans2Contiguous(*dense_tensor))); + } } - } - group_tensor - .ShareDataWith(*( - std::dynamic_pointer_cast(grad_tensor.impl()))) - .Resize({grad_tensor.numel()}); + group_tensor + .ShareDataWith(*(std::dynamic_pointer_cast( + grad_tensor.impl()))) + .Resize({grad_tensor.numel()}); + } else { + VLOG(3) << "Tensor[" << tensors_[var_index].name() + << "] doesn't have grad"; + auto *dev_ctx = + platform::DeviceContextPool::Instance().Get(inner_place_); + group_tensor.Resize({static_cast(length)}); + dev_ctx->Alloc(&group_tensor, group.dtype_); + phi::funcs::set_constant(*dev_ctx, &group_tensor, 0.0f); + } } else { // TODO(shenliang03): maybe save the memory by avoiding tensor // construction @@ -864,8 +884,8 @@ void EagerReducer::MarkVarReady(const size_t var_index, auto dense_tensor = std::dynamic_pointer_cast(tensor_impl); if (!dense_tensor->meta().is_contiguous()) { - grad_tensor->set_impl(std::make_shared(std::move( - paddle::experimental::Trans2Contiguous(*dense_tensor)))); + grad_tensor->set_impl(std::make_shared( + paddle::experimental::Trans2Contiguous(*dense_tensor))); } } @@ -894,7 +914,7 @@ void EagerReducer::MarkVarReady(const size_t var_index, "The sparse parameter[%d][%s] should have gradient. " "Currently, DataParallel does not support sparse " "parameters without generating gradients during training. 
" - "For example, if is_sparese=True is used in Embedding, " + "For example, if is_sparse=True is used in Embedding, " "the current step of this parameter cannot generate gradient " "because of stop_gradient/detach, where error will occur.", var_index, diff --git a/paddle/fluid/distributed/common/afs_warpper.h b/paddle/fluid/distributed/common/afs_warpper.h index 30f4f164ba5a1..03b80ef105f73 100644 --- a/paddle/fluid/distributed/common/afs_warpper.h +++ b/paddle/fluid/distributed/common/afs_warpper.h @@ -22,7 +22,7 @@ #include "paddle/common/macros.h" #include "paddle/fluid/distributed/the_one_ps.pb.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { struct FsDataConverter { diff --git a/paddle/fluid/distributed/common/registerer.h b/paddle/fluid/distributed/common/registerer.h index 5b2d4291d826c..730fae0499060 100644 --- a/paddle/fluid/distributed/common/registerer.h +++ b/paddle/fluid/distributed/common/registerer.h @@ -78,15 +78,17 @@ typedef std::map PsCoreClassMap; extern "C" { #endif -inline PsCoreClassMap &global_factory_map() { +inline PsCoreClassMap *global_factory_map() { static PsCoreClassMap *base_class = new PsCoreClassMap(); - return *base_class; + return base_class; } #ifdef __cplusplus } #endif -inline PsCoreClassMap &global_factory_map_cpp() { return global_factory_map(); } +inline PsCoreClassMap &global_factory_map_cpp() { + return *global_factory_map(); +} // typedef pa::Any Any; // typedef ::FactoryMap FactoryMap; diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 8da1ef87814de..5e2be03108294 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -176,7 +176,7 @@ bool ComputeInterceptor::IsInputReady() { flag = flag && (ready_size_map.at(i) != 0); } if (flag) { - if (scope_id_to_finish_flag.empty()) { + if (scope_id_to_finish_flag.empty()) { // NOLINT cur_scope_id_ = i; return true; } else if (scope_id_to_finish_flag.find(i) != @@ -303,7 +303,7 @@ void ComputeInterceptor::RunOps() { cur_scope_id_)); } - if (!cores_.empty()) { + if (!cores_.empty()) { // NOLINT cores_[cur_scope_id_]->Run(/*feed_names=*/{}, /*need_fetch=*/false); } else { for (auto op : node_->ops()) { diff --git a/paddle/fluid/distributed/fleet_executor/dist_model.cc b/paddle/fluid/distributed/fleet_executor/dist_model.cc index a1fd38295319e..4c19069b33705 100644 --- a/paddle/fluid/distributed/fleet_executor/dist_model.cc +++ b/paddle/fluid/distributed/fleet_executor/dist_model.cc @@ -215,7 +215,7 @@ bool DistModel::Init() { } bool DistModel::PreparePlace() { - if (config_.place == "GPU") { + if (config_.place == "GPU") { // NOLINT place_ = paddle::platform::CUDAPlace(config_.device_id); } else if (config_.place == "CPU") { place_ = paddle::platform::CPUPlace(); diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc index 89150deff544a..2f0bba29ba28b 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_client.cc @@ -20,7 +20,7 @@ #include "paddle/fluid/distributed/ps/service/coordinator_client.h" #include "paddle/fluid/framework/archive.h" -#include "paddle/fluid/string/split.h" +#include "paddle/utils/string/split.h" static const int max_port = 65535; @@ -402,7 +402,7 @@ int 
DownpourBrpcClosure::check_response(size_t request_idx, int cmd_id) { int DownpourBrpcClosure::check_save_response(size_t request_idx, int cmd_id) { int32_t feasign_size = 0; if (_cntls[request_idx]->Failed()) { - LOG(ERROR) << "resquest cmd_id:" << cmd_id + LOG(ERROR) << "request cmd_id:" << cmd_id << " failed, " "err:" << _cntls[request_idx]->ErrorText(); @@ -426,7 +426,7 @@ std::string DownpourBrpcClosure::get_response(size_t request_idx, int cmd_id) { int FlClientBrpcClosure::check_response(size_t request_idx, int cmd_id) { if (_cntls[request_idx]->Failed()) { - LOG(ERROR) << "resquest cmd_id:" << cmd_id + LOG(ERROR) << "request cmd_id:" << cmd_id << " failed, " "err:" << _cntls[request_idx]->ErrorText(); @@ -1634,8 +1634,7 @@ void BrpcPsClient::PushSparseTaskConsume() { task_list.reserve(cur_merge_size + 1); - task_list.push_back( - std::move(std::shared_ptr(async_task))); + task_list.push_back(std::shared_ptr(async_task)); while (!task_queue->Empty() && merge_count < cur_merge_size) { ++merge_count; @@ -1667,8 +1666,7 @@ void BrpcPsClient::PushSparseTaskConsume() { for_each(task_list.begin() + 1, task_list.end(), - [&request_kv_num, request_call_num, closure]( - std::shared_ptr &task) { + [closure](std::shared_ptr &task) { closure->add_timer(task->timer()); closure->add_promise(task->promise()); }); @@ -1712,7 +1710,7 @@ void BrpcPsClient::PushSparseTaskConsume() { merge_status[shard_idx].wait(); } - // meger到task_list[0] + // merge到task_list[0] auto async_task = new SparseAsyncTask(*(task_list[0].get())); task_queue->Put(std::move(async_task)); @@ -1978,8 +1976,7 @@ void BrpcPsClient::PushDenseTaskConsume() { closure->add_timer(async_task->timer()); closure->add_promise(async_task->promise()); merge_status[merge_count] = - async_merge_dense_threads.enqueue([closure, - accessor, + async_merge_dense_threads.enqueue([accessor, &total_send_data, total_send_data_size, async_task]() -> int { diff --git a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc index 8d73a563d79f1..d3623c83fa25e 100644 --- a/paddle/fluid/distributed/ps/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/ps/service/brpc_ps_server.cc @@ -140,8 +140,10 @@ std::future BrpcPsServer::SendPServer2PServerMsg( auto promise = std::make_shared>(); std::future fut = promise->get_future(); if (static_cast(to_pserver_id) >= _pserver_channels.size()) { - LOG(FATAL) << "to_pserver_id is out of range pservers, which size is " - << _pserver_channels.size(); + std::stringstream ss; + ss << "to_pserver_id is out of range pservers, which size is " + << _pserver_channels.size(); + PADDLE_THROW(phi::errors::Fatal(ss.str())); promise->set_value(-1); return fut; } @@ -262,7 +264,7 @@ void BrpcPsService::service(google::protobuf::RpcController *cntl_base, brpc::ClosureGuard done_guard(done); std::string log_label("ReceiveCmd-"); if (!request->has_table_id()) { - set_response_code(*response, -1, "PsRequestMessage.tabel_id is required"); + set_response_code(*response, -1, "PsRequestMessage.table_id is required"); return; } @@ -307,7 +309,7 @@ int32_t BrpcPsService::PullDense(Table *table, set_response_code( response, -1, - "PsRequestMessage.datas is requeired at least 1 for num of dense"); + "PsRequestMessage.datas is required at least 1 for num of dense"); return 0; } CostTimer timer("pserver_server_pull_dense"); @@ -409,7 +411,7 @@ int32_t BrpcPsService::Barrier(Table *table, if (request.params_size() < 1) { set_response_code(response, -1, - "PsRequestMessage.params is 
requeired at " + "PsRequestMessage.params is required at " "least 1 for num of sparse_key"); return 0; } @@ -436,7 +438,7 @@ int32_t BrpcPsService::PushSparseParam(Table *table, if (request.params_size() < 1) { set_response_code(response, -1, - "PsRequestMessage.params is requeired at " + "PsRequestMessage.params is required at " "least 1 for num of sparse_key"); return 0; } @@ -515,7 +517,7 @@ int32_t BrpcPsService::PullSparse(Table *table, if (request.params_size() < 1) { set_response_code(response, -1, - "PsRequestMessage.params is requeired at " + "PsRequestMessage.params is required at " "least 1 for num of sparse_key"); return 0; } @@ -565,7 +567,7 @@ int32_t BrpcPsService::PushSparse(Table *table, if (request.params_size() < 1) { set_response_code(response, -1, - "PsRequestMessage.params is requeired at " + "PsRequestMessage.params is required at " "least 1 for num of sparse_key"); return 0; } @@ -616,7 +618,7 @@ int32_t BrpcPsService::LoadOneTable(Table *table, set_response_code( response, -1, - "PsRequestMessage.datas is requeired at least 2 for path & load_param"); + "PsRequestMessage.datas is required at least 2 for path & load_param"); return -1; } if (table->Load(request.params(0), request.params(1)) != 0) { @@ -649,7 +651,7 @@ int32_t BrpcPsService::SaveOneTable(Table *table, set_response_code( response, -1, - "PsRequestMessage.datas is requeired at least 2, path&mode"); + "PsRequestMessage.datas is required at least 2, path&mode"); return -1; } table->Flush(); @@ -691,7 +693,7 @@ int32_t BrpcPsService::SaveCacheTable(Table *table, set_response_code( response, -1, - "PsRequestMessage.datas is requeired at least 3, path&mode"); + "PsRequestMessage.datas is required at least 3, path&mode"); return -1; } table->Flush(); @@ -717,7 +719,7 @@ int32_t BrpcPsService::CacheShuffle(Table *table, if (request.params_size() < 3) { set_response_code(response, -1, - "PsRequestMessage.datas is requeired at least 3, " + "PsRequestMessage.datas is required at least 3, " "path&mode&cache_threshold"); return -1; } @@ -805,7 +807,7 @@ int32_t BrpcPsService::ShrinkTable(Table *table, set_response_code( response, -1, - "PsRequestMessage.datas is requeired at least 1, threshold"); + "PsRequestMessage.datas is required at least 1, threshold"); return -1; } table->Flush(); diff --git a/paddle/fluid/distributed/ps/service/brpc_utils.h b/paddle/fluid/distributed/ps/service/brpc_utils.h index cea33219e4bcd..6206f1a6d8415 100644 --- a/paddle/fluid/distributed/ps/service/brpc_utils.h +++ b/paddle/fluid/distributed/ps/service/brpc_utils.h @@ -28,7 +28,7 @@ limitations under the License. */ #include "paddle/fluid/framework/selected_rows_utils.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/var_type.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace butil { class IOBuf; diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.cc b/paddle/fluid/distributed/ps/service/communicator/communicator.cc index 7d8ad7ebad5e8..987dfa443eea2 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.cc +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.cc @@ -20,7 +20,7 @@ limitations under the License. 
*/ #include "paddle/fluid/distributed/ps/service/brpc_ps_client.h" #include "paddle/fluid/distributed/ps/wrapper/fleet.h" #include "paddle/fluid/platform/profiler.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" #define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@" #define STEP_COUNTER "@PS_STEP_COUNTER@" @@ -254,8 +254,8 @@ void Communicator::RpcSendSparseParam(const std::string &varname, push_g_vec.push_back(tensor->data() + i * dim); } - DownpourBrpcClosure *closure = new DownpourBrpcClosure( - request_call_num, [this, request_call_num](void *done) { + DownpourBrpcClosure *closure = + new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) { int ret = 0; auto *closure = (DownpourBrpcClosure *)done; // NOLINT for (size_t i = 0; i < request_call_num; ++i) { @@ -422,8 +422,8 @@ void Communicator::SendGlobalStep(const CommContext &ctx, auto *data = out_t->mutable_data({1}, platform::CPUPlace()); data[0] = static_cast(batches); VLOG(3) << "Communicator::SendGlobalStep send: " << batches; - DownpourBrpcClosure *closure = new DownpourBrpcClosure( - request_call_num, [this, request_call_num](void *done) { + DownpourBrpcClosure *closure = + new DownpourBrpcClosure(request_call_num, [request_call_num](void *done) { int ret = 0; auto *closure = (DownpourBrpcClosure *)done; // NOLINT for (size_t i = 0; i < request_call_num; ++i) { diff --git a/paddle/fluid/distributed/ps/service/communicator/communicator.h b/paddle/fluid/distributed/ps/service/communicator/communicator.h index 3af382779c66b..c12f5034968d6 100644 --- a/paddle/fluid/distributed/ps/service/communicator/communicator.h +++ b/paddle/fluid/distributed/ps/service/communicator/communicator.h @@ -40,10 +40,10 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/split.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/selected_rows_functor.h" +#include "paddle/utils/string/split.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/service/coordinator_client.cc b/paddle/fluid/distributed/ps/service/coordinator_client.cc index 691b427d2bfde..bf8233ec975fd 100644 --- a/paddle/fluid/distributed/ps/service/coordinator_client.cc +++ b/paddle/fluid/distributed/ps/service/coordinator_client.cc @@ -20,7 +20,7 @@ #include "paddle/fluid/distributed/ps/service/brpc_ps_client.h" #include "paddle/fluid/framework/archive.h" -#include "paddle/fluid/string/split.h" +#include "paddle/utils/string/split.h" static const int MIN_PORT = 8500; static const int MAX_PORT = 65535; diff --git a/paddle/fluid/distributed/ps/service/coordinator_client.h b/paddle/fluid/distributed/ps/service/coordinator_client.h index 8db08c3fc7999..f0d1116fca268 100644 --- a/paddle/fluid/distributed/ps/service/coordinator_client.h +++ b/paddle/fluid/distributed/ps/service/coordinator_client.h @@ -81,7 +81,7 @@ class CoordinatorServiceHandle { lck.unlock(); VLOG(0) << "last_round_total_fl_clients_num: " << last_round_total_fl_clients_num - << ", has recved fl client num: " << _fl_clients_count.load(); + << ", has received fl client num: " << _fl_clients_count.load(); return; } @@ -102,7 +102,7 @@ class CoordinatorServiceHandle { timeline.Pause(); query_wait_time += timeline.ElapsedSec(); } - // LOG(WARNNING) << "fl-ps > query_wait_time exceed!"; + // LOG(WARNING) << "fl-ps > query_wait_time exceed!"; return true; }; diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_client.cc b/paddle/fluid/distributed/ps/service/graph_brpc_client.cc index e5a7cc38c5987..3725295ac7a26 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_client.cc +++ b/paddle/fluid/distributed/ps/service/graph_brpc_client.cc @@ -25,7 +25,7 @@ #include "paddle/fluid/distributed/ps/service/brpc_ps_client.h" #include "paddle/fluid/distributed/ps/table/table.h" #include "paddle/fluid/framework/archive.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/service/graph_brpc_server.cc b/paddle/fluid/distributed/ps/service/graph_brpc_server.cc index 0a8867bb66e11..29e21e7b9ed50 100644 --- a/paddle/fluid/distributed/ps/service/graph_brpc_server.cc +++ b/paddle/fluid/distributed/ps/service/graph_brpc_server.cc @@ -247,7 +247,7 @@ void GraphBrpcService::service(google::protobuf::RpcController *cntl_base, brpc::ClosureGuard done_guard(done); std::string log_label("ReceiveCmd-"); if (!request->has_table_id()) { - set_response_code(*response, -1, "PsRequestMessage.tabel_id is required"); + set_response_code(*response, -1, "PsRequestMessage.table_id is required"); return; } @@ -558,10 +558,8 @@ int32_t GraphBrpcService::sample_neighbors_across_multi_servers( auto local_promise = std::make_shared>(); std::future local_fut = local_promise->get_future(); std::vector failed(server_size, false); - std::function func = [&, - node_id_buckets, - query_idx_buckets, - request_call_num](void *done) { + std::function func = [&, node_id_buckets, query_idx_buckets]( + void *done) { local_fut.get(); std::vector actual_size; auto 
*closure = reinterpret_cast(done); diff --git a/paddle/fluid/distributed/ps/service/heter_client.h b/paddle/fluid/distributed/ps/service/heter_client.h index e6c231338ac52..36fd97d95da49 100755 --- a/paddle/fluid/distributed/ps/service/heter_client.h +++ b/paddle/fluid/distributed/ps/service/heter_client.h @@ -32,7 +32,7 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN -#include "paddle/fluid/string/split.h" +#include "paddle/utils/string/split.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/distributed/ps/service/heter_server.cc b/paddle/fluid/distributed/ps/service/heter_server.cc index 26dd4e6052c9b..0ea3ff3943f7f 100644 --- a/paddle/fluid/distributed/ps/service/heter_server.cc +++ b/paddle/fluid/distributed/ps/service/heter_server.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/distributed/ps/service/heter_server.h" -#include "paddle/fluid/string/split.h" +#include "paddle/utils/string/split.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h index 44836e7661b5f..58203c4816d44 100644 --- a/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h +++ b/paddle/fluid/distributed/ps/service/ps_service/graph_py_service.h @@ -39,8 +39,8 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/utils/string/printf.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/service/ps_service/service.cc b/paddle/fluid/distributed/ps/service/ps_service/service.cc index e66475e88d875..b3cc588076036 100644 --- a/paddle/fluid/distributed/ps/service/ps_service/service.cc +++ b/paddle/fluid/distributed/ps/service/ps_service/service.cc @@ -21,7 +21,7 @@ #include #include "paddle/fluid/distributed/ps/service/communicator/communicator.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" using namespace std; // NOLINT diff --git a/paddle/fluid/distributed/ps/service/server.h b/paddle/fluid/distributed/ps/service/server.h index bae9ab652ff74..57b697f30919b 100644 --- a/paddle/fluid/distributed/ps/service/server.h +++ b/paddle/fluid/distributed/ps/service/server.h @@ -100,7 +100,8 @@ class PSServer { int msg_type UNUSED, int to_pserver_id UNUSED, const std::string &msg UNUSED) { - LOG(FATAL) << "NotImplementError: PSServer::send_pserver2pserver_msg"; + PADDLE_THROW(phi::errors::Unimplemented( + "NotImplementError: PSServer::send_pserver2pserver_msg")); std::promise promise; std::future fut = promise.get_future(); promise.set_value(-1); @@ -130,7 +131,8 @@ class PSServer { virtual int32_t ReceiveFromPServer(int msg_type UNUSED, int pserver_id UNUSED, const std::string &msg UNUSED) { - LOG(FATAL) << "NotImplementError::PSServer::ReceiveFromPServer"; + PADDLE_THROW(phi::errors::Unimplemented( + "NotImplementError::PSServer::ReceiveFromPServer")); return -1; } diff --git a/paddle/fluid/distributed/ps/service/simple_rpc/baidu_rpc_server.cc b/paddle/fluid/distributed/ps/service/simple_rpc/baidu_rpc_server.cc index f3e501dd00ce1..9eafbc6e3733e 100644 --- a/paddle/fluid/distributed/ps/service/simple_rpc/baidu_rpc_server.cc +++ 
b/paddle/fluid/distributed/ps/service/simple_rpc/baidu_rpc_server.cc @@ -114,7 +114,7 @@ class BRpcServiceImpl : public SimpleRpcService { phi::errors::PreconditionNotMet("Service should not be nullptr.")); head.service->decrease_request(); } else { - LOG(FATAL) << "Unknown message type"; + PADDLE_THROW(phi::errors::InvalidArgument("Unknown message type")); } baidu_rpc_response->set_archive_size(0); done->Run(); @@ -188,7 +188,7 @@ void BaiduRpcServer::initialize() { cep.ip = butil::int2ip(_ips[i]); cep.port = ports[i]; if (channel_ptr->Init(cep, &option) != 0) { - LOG(FATAL) << "Failed to initialize channel"; + PADDLE_THROW(phi::errors::Fatal("Failed to initialize channel")); } LOG(INFO) << "connected to " << butil::endpoint2str(cep).c_str(); return channel_ptr; @@ -242,7 +242,7 @@ static void handle_baidu_rpc_response(brpc::Controller *cntl, phi::errors::PreconditionNotMet("Service should not be nullptr.")); head.service->decrease_request(); } else { - LOG(FATAL) << "Unknown message type"; + PADDLE_THROW(phi::errors::InvalidArgument("Unknown message type")); } } delete baidu_rpc_response; diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.cc b/paddle/fluid/distributed/ps/table/common_graph_table.cc index 7b0f513358d46..f8347e027e417 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.cc +++ b/paddle/fluid/distributed/ps/table/common_graph_table.cc @@ -30,9 +30,9 @@ #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/framework/io/fs.h" #include "paddle/fluid/platform/timer.h" -#include "paddle/fluid/string/printf.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/core/generator.h" +#include "paddle/utils/string/printf.h" +#include "paddle/utils/string/string_helper.h" COMMON_DECLARE_bool(graph_load_in_parallel); COMMON_DECLARE_bool(graph_get_neighbor_id); @@ -1621,11 +1621,10 @@ void GraphTable::clear_edge_shard() { std::vector> tasks; for (auto &type_shards : edge_shards) { for (auto &shard : type_shards) { - tasks.push_back( - load_node_edge_task_pool->enqueue([&shard, this]() -> int { - delete shard; - return 0; - })); + tasks.push_back(load_node_edge_task_pool->enqueue([&shard]() -> int { + delete shard; + return 0; + })); } } for (auto &task : tasks) task.get(); @@ -1643,11 +1642,10 @@ void GraphTable::clear_feature_shard() { std::vector> tasks; for (auto &type_shards : feature_shards) { for (auto &shard : type_shards) { - tasks.push_back( - load_node_edge_task_pool->enqueue([&shard, this]() -> int { - delete shard; - return 0; - })); + tasks.push_back(load_node_edge_task_pool->enqueue([&shard]() -> int { + delete shard; + return 0; + })); } } for (auto &task : tasks) task.get(); @@ -1665,11 +1663,10 @@ void GraphTable::clear_node_shard() { std::vector> tasks; for (auto &type_shards : node_shards) { for (auto &shard : type_shards) { - tasks.push_back( - load_node_edge_task_pool->enqueue([&shard, this]() -> int { - delete shard; - return 0; - })); + tasks.push_back(load_node_edge_task_pool->enqueue([&shard]() -> int { + delete shard; + return 0; + })); } } for (size_t i = 0; i < tasks.size(); i++) tasks[i].get(); @@ -2898,7 +2895,7 @@ int32_t GraphTable::get_nodes_ids_by_ranges( first -= total_size; second -= total_size; tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( - [&shards, this, first, second, i, &res, &mutex]() -> size_t { + [&shards, first, second, i, &res, &mutex]() -> size_t { std::vector keys; shards[i]->get_ids_by_range(first, second, &keys); @@ -3322,8 +3319,7 @@ int32_t 
GraphTable::pull_graph_list(GraphTableType table_type, int count = std::min(1 + (size + cur_size - start - 1) / step, total_size); int end = start + (count - 1) * step + 1; tasks.push_back(_shards_task_pool[i % task_pool_size_]->enqueue( - [&search_shards, this, i, start, end, step, size]() - -> std::vector { + [&search_shards, i, start, end, step, size]() -> std::vector { return search_shards[i]->get_batch(start - size, end - size, step); })); start += count * step; diff --git a/paddle/fluid/distributed/ps/table/common_graph_table.h b/paddle/fluid/distributed/ps/table/common_graph_table.h index 3077f0d6fb867..510562948ffeb 100644 --- a/paddle/fluid/distributed/ps/table/common_graph_table.h +++ b/paddle/fluid/distributed/ps/table/common_graph_table.h @@ -43,8 +43,8 @@ #include "paddle/fluid/distributed/ps/table/graph/class_macro.h" #include "paddle/fluid/distributed/ps/table/graph/graph_node.h" #include "paddle/fluid/distributed/ps/thirdparty/round_robin.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/core/utils/rw_lock.h" +#include "paddle/utils/string/string_helper.h" #ifdef PADDLE_WITH_HETERPS #include "paddle/fluid/distributed/ps/table/depends/rocksdb_warpper.h" diff --git a/paddle/fluid/distributed/ps/table/ctr_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_accessor.cc index 07175e1069527..70954f0b7ad96 100644 --- a/paddle/fluid/distributed/ps/table/ctr_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_accessor.cc @@ -16,7 +16,7 @@ #include "glog/logging.h" #include "paddle/common/flags.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc index 375014cfa37f8..2b3a27e9c47bc 100644 --- a/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_double_accessor.cc @@ -16,7 +16,7 @@ #include "glog/logging.h" #include "paddle/common/flags.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc index 746fc02487aa5..d3864be773c21 100644 --- a/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc +++ b/paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc @@ -16,7 +16,7 @@ #include "glog/logging.h" #include "paddle/common/flags.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/table/graph/graph_node.h b/paddle/fluid/distributed/ps/table/graph/graph_node.h index e5978dfbcbfb2..9cc88d2845762 100644 --- a/paddle/fluid/distributed/ps/table/graph/graph_node.h +++ b/paddle/fluid/distributed/ps/table/graph/graph_node.h @@ -26,7 +26,7 @@ #include "glog/logging.h" #include "paddle/fluid/distributed/ps/table/graph/graph_weighted_sampler.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/table/memory_dense_table.cc b/paddle/fluid/distributed/ps/table/memory_dense_table.cc index 84087605a42fb..641f4e4f73ceb 100644 --- a/paddle/fluid/distributed/ps/table/memory_dense_table.cc +++ 
b/paddle/fluid/distributed/ps/table/memory_dense_table.cc @@ -356,7 +356,7 @@ int32_t MemoryDenseTable::Save(const std::string &path, os << " "; os << values_[param_col_ids_[x]][y]; } - result_buffer_param.emplace_back(std::move(os.str())); + result_buffer_param.emplace_back(os.str()); } } else { std::ostringstream os; @@ -368,7 +368,7 @@ int32_t MemoryDenseTable::Save(const std::string &path, os << " "; os << values_[param_col_ids_[x]][y]; } - result_buffer_param.emplace_back(std::move(os.str())); + result_buffer_param.emplace_back(os.str()); } } diff --git a/paddle/fluid/distributed/ps/table/memory_dense_table.h b/paddle/fluid/distributed/ps/table/memory_dense_table.h index 9b007cca0196a..ff9af25dddea2 100644 --- a/paddle/fluid/distributed/ps/table/memory_dense_table.h +++ b/paddle/fluid/distributed/ps/table/memory_dense_table.h @@ -25,7 +25,7 @@ #include "paddle/fluid/distributed/ps/table/common_table.h" #include "paddle/fluid/distributed/ps/table/depends/dense.h" #include "paddle/fluid/distributed/ps/table/depends/initializers.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h index 4328615406895..8fc32f2d4859d 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_geo_table.h @@ -28,7 +28,7 @@ #include "paddle/fluid/distributed/ps/table/common_table.h" #include "paddle/fluid/distributed/ps/table/depends/feature_value.h" #include "paddle/fluid/distributed/ps/table/depends/geo_recorder.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc index 262f774005e27..a2f8ff346ffca 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.cc @@ -1213,18 +1213,10 @@ int32_t MemorySparseTable::PushSparse(const uint64_t *keys, size_t value_col = _value_accessor->GetAccessorInfo().size / sizeof(float); size_t mf_value_col = _value_accessor->GetAccessorInfo().mf_size / sizeof(float); - size_t update_value_col = - _value_accessor->GetAccessorInfo().update_size / sizeof(float); for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { tasks[shard_id] = _shards_task_pool[shard_id % _task_pool_size]->enqueue( - [this, - shard_id, - value_col, - mf_value_col, - update_value_col, - values, - &task_keys]() -> int { + [this, shard_id, value_col, mf_value_col, values, &task_keys]() -> int { auto &keys = task_keys[shard_id]; auto &local_shard = _local_shards[shard_id]; float data_buffer[value_col]; // NOLINT diff --git a/paddle/fluid/distributed/ps/table/memory_sparse_table.h b/paddle/fluid/distributed/ps/table/memory_sparse_table.h index 5b5a6d41c7b77..6fb2259e443a8 100644 --- a/paddle/fluid/distributed/ps/table/memory_sparse_table.h +++ b/paddle/fluid/distributed/ps/table/memory_sparse_table.h @@ -29,7 +29,7 @@ #include "paddle/fluid/distributed/ps/table/accessor.h" #include "paddle/fluid/distributed/ps/table/common_table.h" #include "paddle/fluid/distributed/ps/table/depends/feature_value.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" #define PSERVER_SAVE_SUFFIX 
".shard" diff --git a/paddle/fluid/distributed/ps/table/sparse_accessor.cc b/paddle/fluid/distributed/ps/table/sparse_accessor.cc index 835292c29d3ee..5689ccfe7a594 100644 --- a/paddle/fluid/distributed/ps/table/sparse_accessor.cc +++ b/paddle/fluid/distributed/ps/table/sparse_accessor.cc @@ -16,7 +16,7 @@ #include "glog/logging.h" #include "paddle/common/flags.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc index d72b4ee1c3d3f..6e4309a663b4d 100644 --- a/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc +++ b/paddle/fluid/distributed/ps/table/ssd_sparse_table.cc @@ -102,7 +102,6 @@ int32_t SSDSparseTable::PullSparse(float* pull_values, mf_value_size, select_value_size, pull_values, - keys, &missed_keys]() -> int { auto& keys = task_keys[shard_id]; auto& local_shard = _local_shards[shard_id]; @@ -432,8 +431,8 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, size_t value_col = _value_accessor->GetAccessorInfo().size / sizeof(float); size_t mf_value_col = _value_accessor->GetAccessorInfo().mf_size / sizeof(float); - size_t update_value_col = - _value_accessor->GetAccessorInfo().update_size / sizeof(float); + // size_t update_value_col = + // _value_accessor->GetAccessorInfo().update_size / sizeof(float); { std::vector> tasks(_real_local_shard_num); std::vector>> task_keys( @@ -445,13 +444,8 @@ int32_t SSDSparseTable::PushSparse(const uint64_t* keys, for (int shard_id = 0; shard_id < _real_local_shard_num; ++shard_id) { tasks[shard_id] = _shards_task_pool[shard_id % _shards_task_pool.size()]->enqueue( - [this, - shard_id, - value_col, - mf_value_col, - update_value_col, - values, - &task_keys]() -> int { + [this, shard_id, value_col, mf_value_col, values, &task_keys]() + -> int { auto& keys = task_keys[shard_id]; auto& local_shard = _local_shards[shard_id]; float data_buffer[value_col]; // NOLINT @@ -706,8 +700,10 @@ int32_t SSDSparseTable::SaveWithString(const std::string& path, out_str.second.data(), out_str.second.size()); if (0 != write_channel->write_line(::paddle::string::format_string( "%lu %s", out_str.first, format_value.c_str()))) { - LOG(FATAL) << "SSDSparseTable save failed, retry it! path:" - << channel_config.path; + std::stringstream ss; + ss << "SSDSparseTable save failed, retry it! path:" + << channel_config.path; + PADDLE_THROW(phi::errors::Fatal(ss.str())); } } write_channel->close(); @@ -1647,8 +1643,10 @@ int32_t SSDSparseTable::SaveWithBinary(const std::string& path, last_file_idx = region->_file_idx; } if (0 != write_channel->write(region->_buf, region->_cur)) { - LOG(FATAL) << "DownpourSparseSSDTable save failed, retry it! path:" - << channel_config.path; + std::stringstream ss; + ss << "DownpourSparseSSDTable save failed, retry it! path:" + << channel_config.path; + PADDLE_THROW(phi::errors::Fatal(ss.str())); CHECK(false); } region->reset(); @@ -1688,8 +1686,10 @@ int32_t SSDSparseTable::SaveWithBinary(const std::string& path, std::string format_value = _value_accessor->ParseToString(value, dim); if (0 != write_channel->write_line(paddle::string::format_string( "%lu %s", k, format_value.c_str()))) { - LOG(FATAL) << "SSDSparseTable save failed, retry it! path:" - << channel_config.path; + std::stringstream ss; + ss << "SSDSparseTable save failed, retry it! 
path:" + << channel_config.path; + PADDLE_THROW(phi::errors::Fatal(ss.str())); } remain -= len; cursor += len; @@ -1971,8 +1971,10 @@ int32_t SSDSparseTable::SaveWithBinary_v2(const std::string& path, last_file_idx = region->_file_idx; } if (0 != write_channel->write(region->_buf, region->_cur)) { - LOG(FATAL) << "DownpourSparseSSDTable save failed, retry it! path:" - << channel_config.path; + std::stringstream ss; + ss << "DownpourSparseSSDTable save failed, retry it! path:" + << channel_config.path; + PADDLE_THROW(phi::errors::Fatal(ss.str())); CHECK(false); } region->reset(); @@ -2001,9 +2003,10 @@ int32_t SSDSparseTable::SaveWithBinary_v2(const std::string& path, if (0 != write_channel_for_slot_feature->write( region_for_slot_feature->_buf, region_for_slot_feature->_cur)) { - LOG(FATAL) - << "DownpourSparseSSDTable save feature failed, retry it! path:" - << channel_config_for_slot_feature.path; + std::stringstream ss; + ss << "DownpourSparseSSDTable save feature failed, retry it! path:" + << channel_config_for_slot_feature.path; + PADDLE_THROW(phi::errors::Fatal(ss.str())); CHECK(false); } region_for_slot_feature->reset(); @@ -2044,8 +2047,10 @@ int32_t SSDSparseTable::SaveWithBinary_v2(const std::string& path, std::string format_value = _value_accessor->ParseToString(value, dim); if (0 != write_channel->write_line(paddle::string::format_string( "%lu %s", k, format_value.c_str()))) { - LOG(FATAL) << "SSDSparseTable save failed, retry it! path:" - << channel_config.path; + std::stringstream ss; + ss << "SSDSparseTable save failed, retry it! path:" + << channel_config.path; + PADDLE_THROW(phi::errors::Fatal(ss.str())); } remain -= len; cursor += len; @@ -2094,8 +2099,10 @@ int32_t SSDSparseTable::SaveWithBinary_v2(const std::string& path, if (0 != write_channel_for_slot_feature->write_line( paddle::string::format_string( "%lu %s", k, format_value.c_str()))) { - LOG(FATAL) << "SSDSparseTable save feature failed, retry it! path:" - << channel_config_for_slot_feature.path; + std::stringstream ss; + ss << "SSDSparseTable save feature failed, retry it! 
path:" + << channel_config_for_slot_feature.path; + PADDLE_THROW(phi::errors::Fatal(ss.str())); } remain -= len; cursor += len; diff --git a/paddle/fluid/distributed/ps/table/table.h b/paddle/fluid/distributed/ps/table/table.h index 779d6c6c32295..b3c80673aa793 100644 --- a/paddle/fluid/distributed/ps/table/table.h +++ b/paddle/fluid/distributed/ps/table/table.h @@ -32,7 +32,7 @@ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/test/ctr_accessor_test.cc b/paddle/fluid/distributed/test/ctr_accessor_test.cc index 9b71e4524625c..0288a93d71a96 100644 --- a/paddle/fluid/distributed/test/ctr_accessor_test.cc +++ b/paddle/fluid/distributed/test/ctr_accessor_test.cc @@ -79,7 +79,7 @@ TEST(downpour_feature_value_accessor_test, test_shrink) { float* value = new float[acc->GetAccessorInfo().dim]; for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) { - value[i] = i * 1.0; + value[i] = static_cast(i) * 1.0; } ASSERT_TRUE(!acc->Shrink(value)); @@ -98,7 +98,7 @@ TEST(downpour_feature_value_accessor_test, test_save) { float* value = new float[acc->GetAccessorInfo().dim]; for (auto i = 0u; i < acc->GetAccessorInfo().dim; ++i) { - value[i] = i * 1.0; + value[i] = static_cast(i) * 1.0; } // save all feature @@ -166,7 +166,7 @@ TEST(downpour_feature_value_accessor_test, test_update) { for (auto i = 0u; i < item_size; ++i) { float* p = new float[acc->GetAccessorInfo().update_dim]; for (auto j = 0u; j < acc->GetAccessorInfo().update_dim; ++j) { - p[j] = i + 1; + p[j] = static_cast(i) + 1.0; } grad[i] = p; } @@ -288,7 +288,7 @@ TEST(downpour_feature_value_accessor_test, test_string_related) { const int field_size = 15; float* value = new float[field_size]; for (auto i = 0u; i < field_size; ++i) { - value[i] = i; + value[i] = static_cast(i); } auto str = acc->ParseToString(value, 0); diff --git a/paddle/fluid/distributed/test/graph_node_split_test.cc b/paddle/fluid/distributed/test/graph_node_split_test.cc index cb47f3103883f..bc2fcea6bb75f 100644 --- a/paddle/fluid/distributed/test/graph_node_split_test.cc +++ b/paddle/fluid/distributed/test/graph_node_split_test.cc @@ -38,8 +38,8 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/utils/string/printf.h" namespace framework = paddle::framework; @@ -55,7 +55,7 @@ std::vector edges = {std::string("37\t45\t0.34"), std::string("97\t48\t0.34"), std::string("97\t247\t0.31"), std::string("97\t111\t0.21")}; -char edge_file_name[] = "edges.txt"; +char edge_file_name[] = "edges.txt"; // NOLINT std::vector nodes = { std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"), @@ -74,12 +74,12 @@ std::vector nodes = { std::string("item\t49\ta 0.21"), std::string("item\t248\ta 0.21"), std::string("item\t113\ta 0.21")}; -char node_file_name[] = "nodes.txt"; +char node_file_name[] = "nodes.txt"; // NOLINT std::vector graph_split = {std::string("0\t97")}; -char graph_split_file_name[] = "graph_split.txt"; +char graph_split_file_name[] = "graph_split.txt"; // NOLINT -void prepare_file(char file_name[], std::vector data) { +void prepare_file(char file_name[], std::vector data) { // NOLINT std::ofstream ofile; ofile.open(file_name); for (auto x : data) { diff --git a/paddle/fluid/distributed/test/graph_node_test.cc b/paddle/fluid/distributed/test/graph_node_test.cc index 8c29c2bf1df3f..55255f2b75347 100644 --- a/paddle/fluid/distributed/test/graph_node_test.cc +++ b/paddle/fluid/distributed/test/graph_node_test.cc @@ -39,8 +39,8 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/utils/string/printf.h" namespace framework = paddle::framework; namespace distributed = paddle::distributed; @@ -236,8 +236,8 @@ const char* edges[] = {"37\t45\t0.34", "59\t122\t0.21", "97\t48\t0.34", "97\t247\t0.31", - "97\t111\t0.21"}; -char edge_file_name[] = "edges.txt"; + "97\t111\t0.21"}; // NOLINT +char edge_file_name[] = "edges.txt"; // NOLINT const char* nodes[] = {"user\t37\ta 0.34\tb 13 14\tc hello\td abc", "user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd", @@ -254,10 +254,10 @@ const char* nodes[] = {"user\t37\ta 0.34\tb 13 14\tc hello\td abc", "item\t122\ta 0.21", "item\t49\ta 0.21", "item\t248\ta 0.21", - "item\t113\ta 0.21"}; -char node_file_name[] = "nodes.txt"; + "item\t113\ta 0.21"}; // NOLINT +char node_file_name[] = "nodes.txt"; // NOLINT -void prepare_file(char file_name[], bool load_edge) { +void prepare_file(char file_name[], bool load_edge) { // NOLINT std::ofstream ofile; ofile.open(file_name); if (load_edge) { diff --git a/paddle/fluid/distributed/test/graph_table_sample_test.cc b/paddle/fluid/distributed/test/graph_table_sample_test.cc index 5489129a070dd..286b19b7070ac 100644 --- a/paddle/fluid/distributed/test/graph_table_sample_test.cc +++ b/paddle/fluid/distributed/test/graph_table_sample_test.cc @@ -43,7 +43,7 @@ std::vector edges = {std::string("37\t45\t0.34"), std::string("97\t247\t0.31"), std::string("97\t111\t0.21")}; // odd id:96 48 122 112 -char edge_file_name[] = "edges.txt"; +char edge_file_name[] = "edges.txt"; // NOLINT std::vector nodes = { std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"), @@ -62,9 +62,9 @@ std::vector nodes = { std::string("item\t49\ta 0.21"), std::string("item\t248\ta 0.21"), std::string("item\t113\ta 0.21")}; -char node_file_name[] = "nodes.txt"; +char node_file_name[] = 
"nodes.txt"; // NOLINT -void prepare_file(char file_name[], std::vector data) { +void prepare_file(char file_name[], std::vector data) { // NOLINT std::ofstream ofile; ofile.open(file_name); for (auto x : data) { diff --git a/paddle/fluid/distributed/test/sparse_sgd_rule_test.cc b/paddle/fluid/distributed/test/sparse_sgd_rule_test.cc index 120d8de56f793..a7029d1e8b127 100644 --- a/paddle/fluid/distributed/test/sparse_sgd_rule_test.cc +++ b/paddle/fluid/distributed/test/sparse_sgd_rule_test.cc @@ -37,8 +37,8 @@ TEST(sparse_value_naive_sgd_test, init_and_update) { // check init_value for zero const int kItemSize = 10; - float w[kItemSize]; - float grad[kItemSize]; + float w[kItemSize]; // NOLINT + float grad[kItemSize]; // NOLINT rule.InitValue(w, w + 9, true); for (float item : w) { @@ -58,16 +58,16 @@ TEST(sparse_value_naive_sgd_test, init_and_update) { for (auto i = 0u; i < kItemSize; ++i) { grad[i] = static_cast(i + 1) * 1.0; } - float label[] = {-0.100000, - -0.200000, - -0.300000, - -0.400000, - -0.500000, - -0.600000, - -0.700000, - -0.800000, - -0.900000, - -1.000000}; + std::array label = {-0.100000, + -0.200000, + -0.300000, + -0.400000, + -0.500000, + -0.600000, + -0.700000, + -0.800000, + -0.900000, + -1.000000}; const float* ptr_grad = grad; rule.UpdateValue(w, w + 9, ptr_grad); @@ -93,7 +93,7 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) { // check init_value for zero const int kValueSize = 11; int kEmbSize = 10; - float w[kValueSize]; + float w[kValueSize]; // NOLINT rule.InitValue(w, w + 10, true); @@ -114,24 +114,24 @@ TEST(downpour_sparse_adagrad_test, test_init_and_update) { w[i] = 0; } w[kEmbSize] = 0; - float grad[kEmbSize]; + float grad[kEmbSize]; // NOLINT for (int i = 0; i < kEmbSize; ++i) { grad[i] = static_cast(i + 1) * 1.0; } const float* ptr_grad = grad; rule.UpdateValue(w, w + 10, ptr_grad); - float label[] = {-0.100000, - -0.200000, - -0.300000, - -0.400000, - -0.500000, - -0.600000, - -0.700000, - -0.800000, - -0.900000, - -1.000000, - 38.500000}; + std::array label = {-0.100000, + -0.200000, + -0.300000, + -0.400000, + -0.500000, + -0.600000, + -0.700000, + -0.800000, + -0.900000, + -1.000000, + 38.500000}; for (auto i = 0u; i < kValueSize; ++i) { ASSERT_FLOAT_EQ(w[i], label[i]); } @@ -190,14 +190,14 @@ TEST(downpour_sparse_adam_test, test_init_and_update) { grad[i] = static_cast(i + 1) * 1.0; } - float label[] = {-0.0999999642, -0.099999994, -0.099999994, -0.099999994, - -0.099999994, -0.099999994, -0.099999994, -0.100000001, - -0.100000009, -0.100000001, 0.100000024, 0.200000048, - 0.300000072, 0.400000095, 0.500000119, 0.600000143, - 0.700000167, 0.800000191, 0.900000215, 1.00000024, - 0.000999987125, 0.0039999485, 0.00899988413, 0.015999794, - 0.0249996781, 0.0359995365, 0.0489993691, 0.063999176, - 0.0809989572, 0.0999987125, 0.809999943, 0.998001039}; + std::array label = { + -0.0999999642, -0.099999994, -0.099999994, -0.099999994, -0.099999994, + -0.099999994, -0.099999994, -0.100000001, -0.100000009, -0.100000001, + 0.100000024, 0.200000048, 0.300000072, 0.400000095, 0.500000119, + 0.600000143, 0.700000167, 0.800000191, 0.900000215, 1.00000024, + 0.000999987125, 0.0039999485, 0.00899988413, 0.015999794, 0.0249996781, + 0.0359995365, 0.0489993691, 0.063999176, 0.0809989572, 0.0999987125, + 0.809999943, 0.998001039}; rule.UpdateValue(value, value + embed_dim, grad); diff --git a/paddle/fluid/eager/amp_auto_cast.h b/paddle/fluid/eager/amp_auto_cast.h index 09a2b73e2e693..ac1f1d5d16972 100644 --- a/paddle/fluid/eager/amp_auto_cast.h +++ 
b/paddle/fluid/eager/amp_auto_cast.h @@ -53,8 +53,7 @@ inline std::vector AmpAutoCasts( paddle::framework::AttributeMap cast_attrs = { {"in_dtype", paddle::framework::TransToProtoVarType(input.dtype())}, {"out_dtype", paddle::framework::TransToProtoVarType(dst_dtype)}}; - inputs_casted.emplace_back( - std::move(cast_dygraph_function(input, cast_attrs))); + inputs_casted.emplace_back(cast_dygraph_function(input, cast_attrs)); } else { inputs_casted.emplace_back(input); } diff --git a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc index 9d1451c74e65f..aa18f8cd4acb8 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/forwards/multiply_fwd_func.cc @@ -27,6 +27,15 @@ COMMON_DECLARE_bool(check_nan_inf); +bool check_if_support_elementwise_mul_mem_opt(const std::string& device_type) { + // TODO(@gexiao): replace this function with api implemented at custom repo + if (device_type == "npu") { + return true; + } else { + return false; + } +} + paddle::Tensor multiply_ad_func(const paddle::Tensor& x, const paddle::Tensor& y) { FLAGS_tensor_operants_mode = "eager"; @@ -160,7 +169,11 @@ paddle::Tensor multiply_ad_func(const paddle::Tensor& x, } // SetAttributes if needed grad_node->SetAttribute_axis(-1); +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (check_if_support_elementwise_mul_mem_opt(x.place().GetDeviceType())) { +#else if (paddle::platform::is_gpu_place(x.place())) { +#endif if (x_autograd_meta != nullptr && x_autograd_meta->StopGradient() && y_autograd_meta != nullptr && !y_autograd_meta->StopGradient()) { grad_node->SetTensorWrapper_x(x); diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc index 84162355e2f88..5d2912d4beb6a 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/add_n_node.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/imperative/tracer.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/api/all.h" #include "paddle/phi/api/lib/api_custom_impl.h" @@ -34,6 +35,19 @@ AddNGradNodeFinal::operator()( bool is_new_grad) { // Fill Zero For GradIn Tensors + // This 'Local_XXXGradNode' record event is different with + // 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, + // but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra accumulation overhead than + // 'Local_XXXGradNode'. 
+ paddle::platform::RecordEvent node_execution_inner( + "Local_AddNGradNodeFinal", + paddle::platform::TracerEventType::OperatorInner, + 1); + // Apply Gradient Hooks auto hooked_grads = ApplyGradientHooks(grads); diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/conv2d_nodes.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/conv2d_nodes.cc index 437cce80c919b..888d96b50fa3c 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/conv2d_nodes.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/conv2d_nodes.cc @@ -38,6 +38,19 @@ Conv2dGradNodeFinal::operator()( bool is_new_grad) { // Fill Zero For GradIn Tensors VLOG(3) << " Running Conv2dGradNodeFinal: " << this; + // This 'Local_XXXGradNode' record event is different with + // 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, + // but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra accumulation overhead than + // 'Local_XXXGradNode'. + paddle::platform::RecordEvent node_execution_inner( + "Local_Conv2dGradNodeFinal", + paddle::platform::TracerEventType::OperatorInner, + 1); + // Apply Gradient Hooks auto hooked_grads = ApplyGradientHooks(grads); @@ -208,6 +221,19 @@ Conv2dDoubleGradNodeFinal::operator()( egr::kSlotSmallVectorSize>& grads, bool create_graph, bool is_new_grad) { + // This 'Local_XXXGradNode' record event is different with + // 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, + // but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra accumulation overhead than + // 'Local_XXXGradNode'. + paddle::platform::RecordEvent node_execution_inner( + "Local_Conv2dDoubleGradNodeFinal", + paddle::platform::TracerEventType::OperatorInner, + 1); + // Fill Zero For GradIn Tensors const auto& input_metas = this->InputMeta(); egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[0][0], diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc index 56c1f1e61a7fc..b1f25601d066b 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc @@ -41,6 +41,19 @@ MultiplyGradNode::operator()( bool is_new_grad) { VLOG(3) << "Running AD API GRAD: " << "multiply_grad"; + // This 'Local_XXXGradNode' record event is different with + // 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, + // but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra accumulation overhead than + // 'Local_XXXGradNode'. 
+ paddle::platform::RecordEvent node_execution_inner( + "Local_MultiplyGradNode", + paddle::platform::TracerEventType::OperatorInner, + 1); + // Fill Zero For GradIn Tensors const auto& input_metas = this->InputMeta(); egr::EagerUtils::FillZeroForEmptyGradInput(&grads[0][0], input_metas[0][0]); @@ -110,7 +123,11 @@ MultiplyGradNode::operator()( // Call grad_api function - if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) { + std::string grad_op_name = "multiply_grad"; + auto need_skip = + paddle::prim::StaticCompositeContext::Instance().CheckSkipCompOps( + grad_op_name); + if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled() && !need_skip) { bool original_global_grad = egr::Controller::Instance().HasGrad(); if (!create_graph) { egr::Controller::Instance().SetHasGrad(create_graph); @@ -156,7 +173,7 @@ MultiplyGradNode::operator()( // Create Grad Node - if (!paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) { + if (!paddle::prim::PrimCommonUtils::IsEagerPrimEnabled() || need_skip) { if (trace_backward) { paddle::platform::RecordEvent node_creation_record_event( "multiply_grad node_creation", @@ -196,6 +213,7 @@ MultiplyGradNode::operator()( } VLOG(4) << "Finish AD API GRAD: multiply_grad"; + VLOG(6) << "gradnode_ptr = " << this; // LOG IF DEBUG if (VLOG_IS_ON(4)) { @@ -240,6 +258,19 @@ MultiplyDoubleGradNode::operator()( bool is_new_grad) { VLOG(3) << "Running AD API GRAD: " << "multiply_double_grad"; + // This 'Local_XXXGradNode' record event is different with + // 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, + // but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra accumulation overhead than + // 'Local_XXXGradNode'. 
+ paddle::platform::RecordEvent node_execution_inner( + "Local_MultiplyDoubleGradNode", + paddle::platform::TracerEventType::OperatorInner, + 1); + // Fill Zero For GradIn Tensors const auto& input_metas = this->InputMeta(); egr::EagerUtils::FillZeroForEmptyOptionalGradInput(&grads[0][0], @@ -356,22 +387,39 @@ MultiplyDoubleGradNode::operator()( // Call grad_api function - bool original_global_grad = egr::Controller::Instance().HasGrad(); - if (!create_graph) { - egr::Controller::Instance().SetHasGrad(create_graph); - } - paddle::prim::multiply_double_grad(x, - y, - fwd_grad_out, - fwd_grad_grad_x_optional, - fwd_grad_grad_y_optional, - axis, - api_output_0, - api_output_1, - api_output_2); - VLOG(4) << "Composite api multiply_double_grad is called "; - if (!create_graph) { - egr::Controller::Instance().SetHasGrad(original_global_grad); + std::string grad_op_name = "multiply_double_grad"; + auto need_skip = + paddle::prim::StaticCompositeContext::Instance().CheckSkipCompOps( + grad_op_name); + if (!need_skip) { + bool original_global_grad = egr::Controller::Instance().HasGrad(); + if (!create_graph) { + egr::Controller::Instance().SetHasGrad(create_graph); + } + paddle::prim::multiply_double_grad(x, + y, + fwd_grad_out, + fwd_grad_grad_x_optional, + fwd_grad_grad_y_optional, + axis, + api_output_0, + api_output_1, + api_output_2); + VLOG(4) << "Composite api multiply_double_grad is called "; + if (!create_graph) { + egr::Controller::Instance().SetHasGrad(original_global_grad); + } + } else { + paddle::experimental::multiply_double_grad(x, + y, + fwd_grad_out, + fwd_grad_grad_x_optional, + fwd_grad_grad_y_optional, + axis, + api_output_0, + api_output_1, + api_output_2); + VLOG(4) << "Fused api multiply_double_grad is called"; } // Check NaN and Inf id needed @@ -411,7 +459,16 @@ MultiplyDoubleGradNode::operator()( // Create Grad Node + if (need_skip) { + if (trace_backward) { + PADDLE_THROW(phi::errors::Unavailable( + "The Op multiply_double_grad doesn't have any grad" + "op. If you don't intend calculating higher order" + "derivatives, please set `create_graph`to False.")); + } + } VLOG(4) << "Finish AD API GRAD: multiply_double_grad"; + VLOG(6) << "gradnode_ptr = " << this; // LOG IF DEBUG if (VLOG_IS_ON(4)) { @@ -474,6 +531,19 @@ MultiplyGradNode::operator()( bool is_new_grad) { VLOG(3) << "Running AD API GRAD: " << "multiply_grad"; + // This 'Local_XXXGradNode' record event is different with + // 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, + // but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra accumulation overhead than + // 'Local_XXXGradNode'. 
+ paddle::platform::RecordEvent node_execution_inner( + "Local_MultiplyGradNode", + paddle::platform::TracerEventType::OperatorInner, + 1); + // Fill Zero For GradIn Tensors const auto& input_metas = this->InputMeta(); egr::EagerUtils::FillZeroForEmptyGradInput(&grads[0][0], input_metas[0][0]); @@ -573,6 +643,7 @@ MultiplyGradNode::operator()( "derivatives, please set `create_graph`to False.")); } VLOG(4) << "Finish AD API GRAD: multiply_grad"; + VLOG(6) << "gradnode_ptr = " << this; // LOG IF DEBUG if (VLOG_IS_ON(4)) { diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc index 15fd00ed5bbaa..0049c67b4870e 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/reshard_node.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/imperative/tracer.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" paddle::small_vector, egr::kSlotSmallVectorSize> // NOLINT @@ -29,6 +30,18 @@ ReshardGradNode::operator()( #ifdef PADDLE_WITH_DISTRIBUTE VLOG(3) << "Running AD API GRAD: " << "reshard_grad"; + // This 'Local_XXXGradNode' record event is different with + // 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, + // but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra accumulation overhead than + // 'Local_XXXGradNode'. + paddle::platform::RecordEvent node_execution_inner( + "Local_ReshardGradNode", + paddle::platform::TracerEventType::OperatorInner, + 1); // Apply Gradient Hooks auto hooked_grad = ApplyGradientHooks(grads); diff --git a/paddle/fluid/eager/api/manual/eager_manual/nodes/sync_batch_norm_node.cc b/paddle/fluid/eager/api/manual/eager_manual/nodes/sync_batch_norm_node.cc index 04bfac8ebd5c6..4e327d23e6da9 100644 --- a/paddle/fluid/eager/api/manual/eager_manual/nodes/sync_batch_norm_node.cc +++ b/paddle/fluid/eager/api/manual/eager_manual/nodes/sync_batch_norm_node.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/imperative/tracer.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/api/all.h" #include "paddle/phi/api/backward/backward_api.h" #include "paddle/phi/api/backward/sparse_bw_api.h" @@ -37,6 +38,19 @@ SyncBatchNormGradNode::operator()( bool is_new_grad) { VLOG(3) << "Running AD API GRAD: " << "sync_batch_norm_grad"; + // This 'Local_XXXGradNode' record event is different with + // 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, + // but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra accumulation overhead than + // 'Local_XXXGradNode'. 
+ paddle::platform::RecordEvent node_execution_inner( + "Local_SyncBatchNormGradNode", + paddle::platform::TracerEventType::OperatorInner, + 1); + // Fill Zero For GradIn Tensors // Apply Gradient Hooks @@ -256,6 +270,19 @@ SyncBatchNormGradNode::operator()( bool is_new_grad) { VLOG(3) << "Running AD API GRAD: " << "sync_batch_norm_grad"; + // This 'Local_XXXGradNode' record event is different with + // 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, + // but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra accumulation overhead than + // 'Local_XXXGradNode'. + paddle::platform::RecordEvent node_execution_inner( + "Local_SyncBatchNormGradNode", + paddle::platform::TracerEventType::OperatorInner, + 1); + // Fill Zero For GradIn Tensors // Apply Gradient Hooks diff --git a/paddle/fluid/eager/api/utils/hook_utils.cc b/paddle/fluid/eager/api/utils/hook_utils.cc index bc6706edb2dab..4230c5e0702d8 100644 --- a/paddle/fluid/eager/api/utils/hook_utils.cc +++ b/paddle/fluid/eager/api/utils/hook_utils.cc @@ -30,9 +30,7 @@ int64_t RegisterGradientHookForTensor( auto rank_info = EagerUtils::unsafe_autograd_meta(tensor)->OutRankInfo(); return grad_node->RegisterGradientHook( - rank_info.first, - rank_info.second, - std::move(std::make_shared(hook))); + rank_info.first, rank_info.second, std::make_shared(hook)); } void RegisterReduceHookForTensor(const paddle::Tensor& tensor, @@ -48,7 +46,7 @@ void RegisterReduceHookForTensor(const paddle::Tensor& tensor, auto accumulation_grad_node = std::dynamic_pointer_cast(grad_node); accumulation_grad_node->RegisterReduceHook( - std::move(std::make_shared(hook))); + std::make_shared(hook)); } else { PADDLE_THROW(paddle::platform::errors::Fatal( "Only can register reduce hook for leaf Tensor.")); diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 33d6da07f81a7..52c2f9b9ef123 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -26,7 +26,7 @@ #include "paddle/fluid/operators/custom_device_common_op_registry.h" #include "paddle/fluid/pybind/eager_generator.h" #include "paddle/fluid/pybind/pybind.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" // phi #include "paddle/phi/kernels/declarations.h" diff --git a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py index c13fb1cb4848c..47bed1595a465 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/generator/codegen_utils.py @@ -38,9 +38,7 @@ "tanh_grad", "tanh_double_grad", "tanh_triple_grad", - "sin_double_grad", "sin_triple_grad", - "cos_double_grad", "cos_triple_grad", "subtract_double_grad", "divide_double_grad", @@ -59,6 +57,7 @@ "conv3d_double_grad", "depthwise_conv2d_grad_grad", "concat_double_grad", + "stack_double_grad", "expand_grad", "argsort_grad", "eigh_grad", diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 74fc6b9a7dbc6..32b36ecf2eea6 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ 
b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -73,6 +73,8 @@ "add_triple_grad", "silu_double_grad", "tanh_triple_grad", + "minimum_double_grad", + "maximum_double_grad", ] # white ops list whose kernel can automaically do type promotion. @@ -209,6 +211,12 @@ class {} : public egr::GradNodeBase {{ paddle::small_vector, egr::kSlotSmallVectorSize> {}::operator()(paddle::small_vector, egr::kSlotSmallVectorSize>& grads, bool create_graph, bool is_new_grad) {{ VLOG(3) << \"Running AD API GRAD: \" << \"{}\"; + // This 'Local_XXXGradNode' record event is different with 'Global_XXXGradNode' event. + // * 'Local_XXXGradNode' will only cover execution time of this function. + // * 'Global_XXXGradNode' will not only cover execution time of this function, but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared by other OP(s), which may have extra accumulation overhead than 'Local_XXXGradNode'. + paddle::platform::RecordEvent grad_node_record_event_inner(\"Local_{}\", paddle::platform::TracerEventType::OperatorInner, 1); + // Fill Zero For GradIn Tensors {} // Apply Gradient Hooks @@ -242,7 +250,7 @@ class {} : public egr::GradNodeBase {{ VLOG(6) << "gradnode_ptr = " << this; // LOG IF DEBUG - {} +{} // Return {} }} @@ -296,25 +304,25 @@ class {} : public egr::GradNodeBase {{ VLOG(4) << \"Finish AD API: {}"; // LOG IF DEBUG - {} +{} // Returns return {}; }} """ AFTER_LOG_PRINT_TEMPLATE = """ - if(VLOG_IS_ON(4)){{ - const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s], \\n Output: [%s] }} \"; - {} - VLOG(4) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str, output_str); + if (VLOG_IS_ON(4)) {{ + const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s], \\n Output: [%s] }} \"; +{} + VLOG(4) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str, output_str); }} """ BEFORE_LOG_PRINT_TEMPLATE = """ - if(VLOG_IS_ON(3)){{ - const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s]}} \"; - {} - VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str); + if (VLOG_IS_ON(3)) {{ + const char* INPUT_PRINT_TEMPLATE = \"{{ Input: [%s]}} \"; +{} + VLOG(3) << paddle::string::Sprintf(INPUT_PRINT_TEMPLATE, input_str); }} """ @@ -346,13 +354,13 @@ class {} : public egr::GradNodeBase {{ // Check Inplace if needed {}{} // LOG IF DEBUG - {} +{} // Returns return {}; }} """ -FORWARD_BODY_BEFORE_API_CALL_TEMPLATE = """ if(require_any_grad) {{ +FORWARD_BODY_BEFORE_API_CALL_TEMPLATE = """ if (require_any_grad) {{ {} // Node Construction {} @@ -367,7 +375,7 @@ class {} : public egr::GradNodeBase {{ }} """ -FORWARD_BODY_AFTER_API_CALL_TEMPLATE = """ if(require_any_grad) {{ +FORWARD_BODY_AFTER_API_CALL_TEMPLATE = """ if (require_any_grad) {{ egr::EagerUtils::PassStopGradient({}); @@ -382,7 +390,7 @@ class {} : public egr::GradNodeBase {{ }} """ -HIGHER_ORDER_DERIVATIVE_VALUE_TEMPLATE = """ if(trace_backward) {{ +HIGHER_ORDER_DERIVATIVE_VALUE_TEMPLATE = """ if (trace_backward) {{ {} // Node Construction {} @@ -562,12 +570,12 @@ class {} : public egr::GradNodeBase {{ CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE = """ paddle::optional {}_optional; - if( {}.impl() ) {}_optional = paddle::make_optional({}); + if ({}.impl()) {}_optional = paddle::make_optional({}); """ CREATE_RECOVER_OPTIONAL_VECTOR_TENSOR_TEMPLATE = """ paddle::optional> {}_optional; - if( !{}.empty() ) {}_optional = paddle::make_optional>({}); + if (!{}.empty()) {}_optional = paddle::make_optional>({}); """ SET_GRAD_OUT_DIST_ATTR_TEMPLATE = """ @@ -593,20 +601,20 @@ class {} : public 
egr::GradNodeBase {{ CHECK_NAN_AND_INF_TEMPLATE_FORWARD = """ if (FLAGS_check_nan_inf) {{ - egr::CheckTensorHasNanOrInf("{}", {}); + egr::CheckTensorHasNanOrInf("{}", {}); }} """ CHECK_NAN_AND_INF_TEMPLATE_BACKWARD = """ if (FLAGS_check_nan_inf) {{ - try{{ - egr::CheckTensorHasNanOrInf("{}", {}); - }} catch(...) {{ - LOG(WARNING) << "There are nan/inf in ({})"; - auto forward_trace = GetForwardTrace(); - std::cout<SetTensorWrapper_{name}(*{name}_clone);}""".format_map( {"indent": indent, "name": name} @@ -1102,13 +1106,13 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): or (name in self.optional_inputs) ): if for_backward is False: - set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper_{name}(*{name});" + set_tensor_wrappers = f"{indent}if ({name}) grad_node->SetTensorWrapper_{name}(*{name});" else: - set_tensor_wrappers = f"{indent}if({name}_optional) grad_node->SetTensorWrapper_{name}(*{name}_optional);" + set_tensor_wrappers = f"{indent}if ({name}_optional) grad_node->SetTensorWrapper_{name}(*{name}_optional);" else: need_pre_contiguous_set.add(name) - set_tensor_wrappers = f"{indent}if({name}) grad_node->SetTensorWrapper_{name}(*{name}_tmp);" + set_tensor_wrappers = f"{indent}if ({name}) grad_node->SetTensorWrapper_{name}(*{name}_tmp);" else: if is_inplace_input: set_tensor_wrappers = f"{indent}auto {name}_clone = paddle::experimental::assign({name});\n{indent}grad_node->SetTensorWrapper_{name}({name}_clone);" @@ -1127,9 +1131,9 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): else: # Forwad's output as backward's input if num_fwd_outputs > 1: # Aligned with forward output position - assert ( - name in forward_outputs_position_map.keys() - ), AssertMessage(name, forward_outputs_position_map.keys()) + assert name in forward_outputs_position_map, AssertMessage( + name, forward_outputs_position_map.keys() + ) set_tensor_wrappers = ( f"{indent}grad_node->SetTensorWrapper_{name}({name});" @@ -1151,7 +1155,7 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): for name, (ttype, pos) in forward_inputs_position_map.items(): if name in need_pre_contiguous_set: pre_contiguous_list.append( - f"{indent}const auto& {name}_tmp = (require_any_grad && {name}.is_dense_tensor() && !std::dynamic_pointer_cast({name}.impl())->meta().is_contiguous()) ? paddle::Tensor(std::make_shared(std::move(paddle::experimental::Trans2Contiguous(*(std::dynamic_pointer_cast({name}.impl()))))), {name}.mutable_autograd_meta()) : {name};" + f"{indent}const auto& {name}_tmp = (require_any_grad && {name}.is_dense_tensor() && !std::dynamic_pointer_cast({name}.impl())->meta().is_contiguous()) ? 
paddle::Tensor(std::make_shared(paddle::experimental::Trans2Contiguous(*(std::dynamic_pointer_cast({name}.impl())))), {name}.mutable_autograd_meta(), {name}.name()) : {name};" ) self.inputs_call_list_tmp[pos] = ( self.inputs_call_list_tmp[pos] + '_tmp' @@ -1185,9 +1189,9 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): if is_optional: if for_backward is False: - set_grad_out_meta = f"{indent}if({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" + set_grad_out_meta = f"{indent}if ({name}.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}.get_ptr()), {pos});" else: - set_grad_out_meta = f"{indent}if({name}_optional.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}_optional.get_ptr()), {pos});" + set_grad_out_meta = f"{indent}if ({name}_optional.get_ptr() != nullptr) grad_node->SetGradOutMeta(*({name}_optional.get_ptr()), {pos});" else: if ( is_special_forward_api @@ -1209,7 +1213,7 @@ def GenerateNodeCreationCodes(self, for_backward=False, is_inplaced=False): set_out_rank_list = [] set_history_list = [] set_grad_in_meta_list = [] - num_outputs = len(forward_outputs_position_map.keys()) + num_outputs = len(forward_outputs_position_map) for name, (_, pos) in forward_outputs_position_map.items(): output_autograd_meta_name = GetAutoGradMetaName(name) set_out_rank = f"""{indent}if ({output_autograd_meta_name}) {{ @@ -1358,7 +1362,7 @@ def GenerateForwardLayoutAutotune( intermediate_outputs = self.intermediate_outputs forward_attrs_list = self.forward_attrs_list forward_outputs_position_map = self.forward_outputs_position_map - num_outputs = len(forward_outputs_position_map.keys()) - len( + num_outputs = len(forward_outputs_position_map) - len( intermediate_outputs ) # for layout autotune attr @@ -1481,9 +1485,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): indent = GetIndent(1) # Get Function Args - num_inputs = len(forward_attrs_list) + len( - forward_inputs_position_map.keys() - ) + num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map) inputs_args_definition_list = ["" for i in range(num_inputs)] inputs_args_declaration_list = ["" for i in range(num_inputs)] inputs_call_list = ["" for i in range(num_inputs)] @@ -1512,7 +1514,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): self.is_forward_only and is_inplaced and forward_inplace_map - and name in forward_inplace_map.keys() + and name in forward_inplace_map ): arg_str = f"paddle::optional& {name}" else: @@ -1535,7 +1537,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): if ( is_inplaced and forward_inplace_map - and name in forward_inplace_map.keys() + and name in forward_inplace_map ): arg_str = f"paddle::Tensor& {name}" amp_tensors_vector_list.append(f"{{{name}}}") @@ -1558,7 +1560,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): self.is_forward_only and is_inplaced and forward_inplace_map - and name in forward_inplace_map.keys() + and name in forward_inplace_map ): arg_str = f"paddle::optional>& {name}" else: @@ -1576,7 +1578,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): if ( is_inplaced and forward_inplace_map - and name in forward_inplace_map.keys() + and name in forward_inplace_map ): arg_str = f"std::vector& {name}" else: @@ -1623,7 +1625,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): if is_inplaced and len(forward_outputs_position_map) == 1: api_out_type = "auto&" forward_call_str = f"{indent}{api_out_type} api_result = 
paddle::experimental::{namespace}{function_name}({inputs_call_args_str});" - num_outputs = len(forward_outputs_position_map.keys()) - len( + num_outputs = len(forward_outputs_position_map) - len( intermediate_outputs ) @@ -1710,7 +1712,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): self.forward_api_name[-1] != '_' or self.forward_api_name == 'assign_out_' ): - for inplace_name in forward_inplace_map.keys(): + for inplace_name in forward_inplace_map: if ( not self.is_forward_only and forward_api_name not in inplace_check_blacklist @@ -1765,7 +1767,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): # 2. Get Output AutoGradMeta outputs_autograd_meta_list = [] - num_fwd_outputs = len(forward_outputs_position_map.keys()) + num_fwd_outputs = len(forward_outputs_position_map) for name, (rtype, pos) in forward_outputs_position_map.items(): output_autograd_meta_name = GetAutoGradMetaName(name) @@ -1828,9 +1830,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): f"return {forward_ad_function_name}({amp_inputs_call_args_str});" ) if is_inplaced or (forward_api_name == "cast"): - amp_logic_str = "\n VLOG(5) << \" No AMP for {} because it is a inplace or cast api. \"; ".format( - forward_ad_function_name - ) + amp_logic_str = f"\n VLOG(5) << \" No AMP for {forward_ad_function_name} because it is a inplace or cast api. \"; " else: amp_logic_str = AMP_LOGIC_TEMPLATE.format( kernel_trans2_op_name_str, @@ -1857,11 +1857,7 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): return_value=type_promote_call_list, ) else: - type_promotion_logic_str = ( - "\n VLOG(5) << \" No Type Promotion for {} api. \"; ".format( - forward_ad_function_name - ) - ) + type_promotion_logic_str = f"\n VLOG(5) << \" No Type Promotion for {forward_ad_function_name} api. \"; " # Forward layout autotune layout_autotune_list_str = " ".join( layout_autotune_list @@ -1882,22 +1878,20 @@ def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): for name, (ttype, pos) in forward_inputs_position_map.items(): var_str += f"\n{indent} const char* TENSOR_{name.upper()}_TEMPLATE = \" \\n( {name} , [%s]), \";" var_str += f"\n{indent} std::string input_{name}_str = paddle::string::Sprintf(TENSOR_{name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({name}));" - var_str += f"\n{indent} input_str += input_{name}_str; " + var_str += f"\n{indent} input_str += input_{name}_str;" before_log_str = BEFORE_LOG_PRINT_TEMPLATE.format(var_str) for name, (ttype, pos) in forward_outputs_position_map.items(): var_str += f"\n{indent} const char* TENSOR_{name.upper()}_TEMPLATE = \" \\n( {name} , [%s]), \";" var_str += f"\n{indent} std::string output_{name}_str = paddle::string::Sprintf(TENSOR_{name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({name}));" - var_str += f"\n{indent} output_str += output_{name}_str; " + var_str += f"\n{indent} output_str += output_{name}_str;" log_str = AFTER_LOG_PRINT_TEMPLATE.format(var_str) # Generate forward_definition_str and forward_declaration_str if self.is_forward_only: if len(amp_tensors_vector_list) == 0: - amp_logic_str = "\n VLOG(7) << \" No AMP for {} because it has no input. \"; ".format( - forward_ad_function_name - ) + amp_logic_str = f"\n VLOG(7) << \" No AMP for {forward_ad_function_name} because it has no input. 
\"; " self.forward_definition_str += ( FORWARD_ONLY_FUNCTION_TEMPLATE.format( returns_type_str, @@ -1958,10 +1952,7 @@ def GenerateInplacedForwardDygraphFunctions(self): forward_api_name = self.forward_api_name forward_api_contents = self.forward_api_contents - if ( - forward_api_name != "sum" - and "inplace" in forward_api_contents.keys() - ): + if forward_api_name != "sum" and "inplace" in forward_api_contents: # Function Definition and Declaration Generation self.GenerateForwardDefinitionAndDeclaration(is_inplaced=True) self.UpdateCoreOpsInformation(is_inplaced=True) @@ -1976,10 +1967,8 @@ def UpdateCoreOpsInformation(self, is_inplaced): forward_outputs_position_map = self.forward_outputs_position_map forward_attrs_list = self.forward_attrs_list - num_args = len(forward_inputs_position_map.keys()) + len( - forward_attrs_list - ) - num_returns = len(forward_outputs_position_map.keys()) + num_args = len(forward_inputs_position_map) + len(forward_attrs_list) + num_returns = len(forward_outputs_position_map) fwd_api_name = "" + forward_api_name core_ops_returns_info[fwd_api_name] = ["" for i in range(num_returns)] @@ -2042,7 +2031,7 @@ def __init__( def TransformToNextGradName(self, string): name_mapping = self.to_next_grad_name_mapping - if string in name_mapping.keys(): + if string in name_mapping: return name_mapping[string] return string @@ -2072,6 +2061,7 @@ def RecordGrad2NextGradNameMapping(self, next_node_generator): self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name def GenerateHigherOrderNodeCreationCode(self): + indent = GetIndent(1) has_higher_order_node = False namespace = self.namespace grad_api_contents = self.grad_api_contents @@ -2081,6 +2071,7 @@ def GenerateHigherOrderNodeCreationCode(self): next_grad_node_creation_str = "" next_grad_node_out_list = [] next_node_generator = None + if next_grad_api_contents: # Fake forward_api_contents and backward_api_contents forward_api_contents = grad_api_contents @@ -2107,30 +2098,46 @@ def GenerateHigherOrderNodeCreationCode(self): is_composite_grad_api = ( False if self.composite_func_info == {} else True ) - if is_composite_grad_api: if next_grad_node_creation_str != '': - next_grad_node_creation_str = f""" - if (!paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{ - {next_grad_node_creation_str} - }} - """ + next_grad_node_creation_str = [ + line if len(line) else line + for line in next_grad_node_creation_str.split("\n") + ] + next_grad_node_creation_str = [ + (indent + line if i >= 1 and len(line) else line) + for line in next_grad_node_creation_str + ] + next_grad_node_creation_str = [ + (indent + line if len(line) else line) + for line in next_grad_node_creation_str + ] + next_grad_node_creation_str = "\n".join( + next_grad_node_creation_str + ) + if self.backward_api_name in prim_white_list: + next_grad_node_creation_str = "" + else: + next_grad_node_creation_str = f""" + if (!paddle::prim::PrimCommonUtils::IsEagerPrimEnabled() || need_skip) {{ +{next_grad_node_creation_str} + }} +""" else: if not ( self.grad_api_contents["backward_op"] in prim_white_list or is_invoke_forward_api ): next_grad_node_creation_str = f""" - if (!paddle::prim::PrimCommonUtils::IsEagerPrimEnabled()) {{ - if(trace_backward) {{ - PADDLE_THROW(phi::errors::Unavailable( - \"The Op {self.backward_api_name} doesn't have any grad\" - \"op. 
If you don't intend calculating higher order\" - \"derivatives, please set `create_graph`to False.\")); + if (!paddle::prim::PrimCommonUtils::IsEagerPrimEnabled() || need_skip) {{ + if (trace_backward) {{ + PADDLE_THROW(phi::errors::Unavailable( + \"The Op {self.backward_api_name} doesn't have any grad\" + \"op. If you don't intend calculating higher order\" + \"derivatives, please set `create_graph`to False.\")); + }} }} - }} - """ - +""" if next_node_generator is not None: has_higher_order_node = True return ( @@ -2143,7 +2150,7 @@ def GenerateHigherOrderNodeCreationCode(self): ) # TODO(Ruting):Integrate invoke and composite as composite so the rest branch canbe covered elif not is_invoke_forward_api and not is_composite_grad_api: - next_grad_node_creation_str = f""" if(trace_backward) {{ + next_grad_node_creation_str = f""" if (trace_backward) {{ PADDLE_THROW(phi::errors::Unavailable( \"The Op {self.backward_api_name} doesn't have any grad\" \"op. If you don't intend calculating higher order\" @@ -2273,8 +2280,8 @@ def GenerateNodeDefinition( # Construct grad_api function args # Order: TensorWrappers, GradTensors, Attributes grad_api_args_len = ( - len(backward_forward_inputs_map.keys()) - + len(backward_grad_inputs_map.keys()) + len(backward_forward_inputs_map) + + len(backward_grad_inputs_map) + len(backward_attrs_list) ) grad_api_args = ["" for i in range(grad_api_args_len)] @@ -2325,7 +2332,7 @@ def GenerateNodeDefinition( is_optional = name in self.optional_inputs tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name});" - if backward_inplace_map and name in backward_inplace_map.keys(): + if backward_inplace_map and name in backward_inplace_map: if has_higher_order_node: if ( transformed_tensor_name @@ -2401,7 +2408,7 @@ def GenerateNodeDefinition( get_tensor_str = f"{indent}auto& {transformed_tensor_name} = hooked_grads[{fwd_position}][0];" # Inplace in backward op - if backward_inplace_map and name in backward_inplace_map.keys(): + if backward_inplace_map and name in backward_inplace_map: if has_higher_order_node: if ( transformed_tensor_name @@ -2464,7 +2471,7 @@ def GenerateNodeDefinition( get_grad_in_args_str = "\n".join(get_grad_in_args_list) # Grad Function Call String - slot_num_bwd_outputs = len(self.forward_inputs_position_map.keys()) + slot_num_bwd_outputs = len(self.forward_inputs_position_map) grad_api_namespace = f"paddle::experimental::{namespace}" composite_grad_api_namespace = f"paddle::prim::{namespace}" grad_function_prepare_str = f""" @@ -2508,7 +2515,7 @@ def GenerateNodeDefinition( backward_inplace_map and name in backward_inplace_map.values() ): - inplace_str = f""" if (api_output_{out_index} != nullptr && can_be_inplaced) {{ + inplace_str = f"""if (api_output_{out_index} != nullptr && can_be_inplaced) {{ egr::EagerUtils::HandleViewBetweenInputAndOutput({inplace_grad_input_str}, api_output_{out_index}); }}""" if has_higher_order_node: @@ -2520,7 +2527,7 @@ def GenerateNodeDefinition( }}""" need_gen_trace_backward_for_inplace = True else: - inplace_for_grad_outs_str += inplace_str + inplace_for_grad_outs_str += " " + inplace_str grad_function_prepare_str += f""" auto* api_output_{out_index} = (out_metas[{fwd_position}].empty() || out_metas[{fwd_position}][0].IsStopGradient()) ? 
nullptr : &returns[{fwd_position}][0];""" @@ -2570,43 +2577,106 @@ def GenerateNodeDefinition( grad_function_call_str = f""" if (trace_backward) {{ {indent}{autograd_api_out} api_output = {autograd_api}; - {out_assign_str}}} else {{ + {out_assign_str}{indent}}} else {{ {indent}{autograd_api_out} api_output = paddle::experimental::{self.namespace}{self.grad_api_contents['invoke']}; {out_assign_str}{indent}}} - """ - # TODO(Ruting):using composite only when we don't have backward kernel in the future. +""" elif is_composite_grad_api: - if composite_grad_api_name in prim_white_list: - grad_function_call_str = f""" + has_kernel_impl = "kernel" in self.grad_api_contents + + def _gen_api_call_code_block( + in_prim_white_list: bool, + has_kernel_impl: bool, + indention: int, + ): + """This function will generate code block for calling composite or + kernel grad api as shown below. + + // Call grad_api function + + XXX <-- Generated code by this function + XXX <-- Generated code by this function + ... <-- Generated code by this function + ... <-- Generated code by this function + + // Check NaN and Inf id needed + + Args: + in_prim_white_list (bool): Whether current op in `prim_white_list`. + has_kernel_impl (bool): Whether current op has kernel implementation. + indention (int): Number of single space for whole code block indention. + """ + if in_prim_white_list: + code = f""" +bool original_global_grad = egr::Controller::Instance().HasGrad(); +if (!create_graph) {{ +{indent}egr::Controller::Instance().SetHasGrad(create_graph); +}} +{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str}); +VLOG(4) << "Composite api {composite_grad_api_name} is called"; +if (!create_graph) {{ +{indent}egr::Controller::Instance().SetHasGrad(original_global_grad); +}} +""" + else: + code = f""" +std::string grad_op_name = "{composite_grad_api_name}"; +auto need_skip = paddle::prim::StaticCompositeContext::Instance().CheckSkipCompOps(grad_op_name); +if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled() && !need_skip) {{ {indent}bool original_global_grad = egr::Controller::Instance().HasGrad(); -{indent}if(!create_graph){{ +{indent}if (!create_graph) {{ {indent}{indent}egr::Controller::Instance().SetHasGrad(create_graph); - }} - {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str}); - VLOG(4) << "Composite api {composite_grad_api_name} is called "; -{indent}if(!create_graph){{ +{indent}}} +{indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str}); +{indent}VLOG(4) << "Composite api {composite_grad_api_name} is called"; +{indent}if (!create_graph) {{ {indent}{indent}egr::Controller::Instance().SetHasGrad(original_global_grad); - }} - """ +{indent}}}""" + if has_kernel_impl: + code = ( + code + + f""" +}} else {{ +{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str}); +{indent}VLOG(4) << "Fused api {backward_api_name} is called"; +}} +""" + ) + else: + code = ( + code + + f""" +}} else {{ + PADDLE_THROW(phi::errors::Unavailable( + \"The grad op of {self.backward_api_name} doesn't implemented yet.\")); +}} +""" + ) + # make indention for all line(s) in code + code = "\n".join( + [ + (f"{' ' * indention}{line}" if len(line) else line) + for line in code.split("\n") + ] + ) + + return code + + if ( + self.backward_api_name not in prim_white_list + and not has_kernel_impl + ): + grad_function_call_str = 
_gen_api_call_code_block( + self.backward_api_name in prim_white_list, + has_kernel_impl, + 0, + ) else: - grad_function_call_str = f""" - std::string grad_op_name = "{composite_grad_api_name}"; - auto need_skip = paddle::prim::StaticCompositeContext::Instance().CheckSkipCompOps(grad_op_name); - if (paddle::prim::PrimCommonUtils::IsEagerPrimEnabled() && !need_skip) {{ -{indent}bool original_global_grad = egr::Controller::Instance().HasGrad(); -{indent}if(!create_graph){{ -{indent}{indent}egr::Controller::Instance().SetHasGrad(create_graph); - }} - {indent}{composite_grad_api_namespace}{composite_grad_api_name}{composite_template_name}({composite_grad_api_args_str}); - {indent}VLOG(4) << "Composite api {composite_grad_api_name} is called "; -{indent}if(!create_graph){{ -{indent}{indent}egr::Controller::Instance().SetHasGrad(original_global_grad); - }} - }}else{{ - {indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str}); - {indent}VLOG(4) << "Fused api {backward_api_name} is called "; - }} - """ + grad_function_call_str = _gen_api_call_code_block( + self.backward_api_name in prim_white_list, + has_kernel_impl, + 2, + ) else: grad_function_call_str = f""" {indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});""" @@ -2630,7 +2700,7 @@ def GenerateNodeDefinition( outputs_autograd_meta_list = [] # TODO(jiabin): Optimize this with SetStopGradient instead of Pass Stop gradient - num_fwd_outputs = len(backward_grad_outputs_map.keys()) + num_fwd_outputs = len(backward_grad_outputs_map) for name, ( rtype, pos, @@ -2649,7 +2719,7 @@ def GenerateNodeDefinition( auto& {transformed_tensor_name} = returns[{pos}][0]; egr::AutogradMeta* {output_autograd_meta_name} = returns[{pos}][0].initialized() ? egr::EagerUtils::autograd_meta(&{transformed_tensor_name}) : nullptr; if ({output_autograd_meta_name}) {output_autograd_meta_name}->SetStopGradient(false); - """ +""" else: assert IsVectorTensorType(rtype) @@ -2658,7 +2728,7 @@ def GenerateNodeDefinition( auto& {transformed_tensor_name} = returns[{pos}]; std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name}); std::vector* {output_autograd_meta_name} = &{output_autograd_meta_vec_name}; - for(auto* meta : {output_autograd_meta_vec_name}){{ + for(auto* meta : {output_autograd_meta_vec_name}) {{ meta->SetStopGradient(false); }} """ @@ -2666,7 +2736,7 @@ def GenerateNodeDefinition( output_autograd_meta = f""" auto& {transformed_tensor_name} = returns[{pos}]; std::vector {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{transformed_tensor_name}); - for(auto* meta : {output_autograd_meta_vec_name}){{ + for(auto* meta : {output_autograd_meta_vec_name}) {{ meta->SetStopGradient(false); }} """ @@ -2674,7 +2744,7 @@ def GenerateNodeDefinition( outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list) - returns_str = f"{indent}if(NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" + returns_str = f"{indent}if (NeedComplexToRealConversion()) HandleComplexGradToRealGrad(&returns);\n" returns_str += f"{indent}return returns;\n" grad_node_name = GetGradNodeName(self.backward_api_name) @@ -2689,7 +2759,7 @@ def GenerateNodeDefinition( new_name = self.TransformToNextGradName(name) var_str += f"\n{indent} const char* TENSOR_{new_name.upper()}_TEMPLATE = \" \\n( {new_name} , [%s]), \";" var_str += f"\n{indent} std::string input_{new_name}_str = paddle::string::Sprintf(TENSOR_{new_name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({new_name}));" - var_str += 
f"\n{indent} input_str += input_{new_name}_str; " + var_str += f"\n{indent} input_str += input_{new_name}_str;" for ( name, @@ -2698,7 +2768,7 @@ def GenerateNodeDefinition( new_name = self.TransformToNextGradName(name) var_str += f"\n{indent} const char* TENSOR_{new_name.upper()}_TEMPLATE = \" \\n( {new_name} , [%s]), \";" var_str += f"\n{indent} std::string input_{new_name}_str = paddle::string::Sprintf(TENSOR_{new_name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({new_name}));" - var_str += f"\n{indent} input_str += input_{new_name}_str; " + var_str += f"\n{indent} input_str += input_{new_name}_str;" before_log_str = BEFORE_LOG_PRINT_TEMPLATE.format(var_str) @@ -2710,13 +2780,14 @@ def GenerateNodeDefinition( new_name = self.TransformToNextGradName(name) var_str += f"\n{indent} const char* TENSOR_{new_name.upper()}_TEMPLATE = \" \\n ( {new_name} , [%s]), \";" var_str += f"\n{indent} std::string output_{new_name}_str = paddle::string::Sprintf(TENSOR_{new_name.upper()}_TEMPLATE, egr::EagerUtils::TensorStr({new_name}));" - var_str += f"\n{indent} output_str += output_{new_name}_str; " + var_str += f"\n{indent} output_str += output_{new_name}_str;" log_str = AFTER_LOG_PRINT_TEMPLATE.format(var_str) self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format( grad_node_name, self.backward_api_name, + grad_node_name, fill_zero_str, get_grad_in_args_str, grad_function_prepare_str, @@ -2787,7 +2858,7 @@ def __init__( def CollectIsForwardOnly(self, forward_api_contents): self.is_forward_only = ( - False if 'backward' in forward_api_contents.keys() else True + False if 'backward' in forward_api_contents else True ) def ParseYamlContents(self): @@ -2802,11 +2873,11 @@ def ParseYamlContents(self): def GetBackwardAPIContents(self, forward_api_contents): grad_api_dict = self.grad_api_dict - if 'backward' not in forward_api_contents.keys(): + if 'backward' not in forward_api_contents: return None backward_api_name = forward_api_contents['backward'] - assert backward_api_name in grad_api_dict.keys(), AssertMessage( + assert backward_api_name in grad_api_dict, AssertMessage( backward_api_name, grad_api_dict.keys() ) backward_api_contents = grad_api_dict[backward_api_name] diff --git a/paddle/fluid/eager/backward.cc b/paddle/fluid/eager/backward.cc index 027ebba18be96..1fa69a37302e4 100644 --- a/paddle/fluid/eager/backward.cc +++ b/paddle/fluid/eager/backward.cc @@ -253,10 +253,6 @@ std::vector RunBackward( while (!queue.empty()) { GradNodeBase* node = queue.front(); VLOG(3) << "Preparing GradNode:" << node->name() << " addr:" << node; - paddle::platform::RecordEvent node_record_event( - std::string((*node).name()), - paddle::platform::TracerEventType::Operator, - 1); if (queue.size() > 1 && node_in_degree_map[node] != 0) { queue.pop_front(); @@ -280,14 +276,29 @@ std::vector RunBackward( EnforceGradNodeHasInput(node); VLOG(7) << "Run Backward Kernel with GradTensorHolder."; + + // This 'Global_XXXGradNode' record event is different with + // 'Local_XXXGradNode' event. + // * 'Global_XXXGradNode' will not only cover execution time of this + // function, but also include gradient + // accumulation when the output(s) of corresponding forward OP are shared + // by other OP(s), which may have extra overhead of accumulation than + // 'Local_XXXGradNode'. + // * 'Local_XXXGradNode' will only cover execution time of GradNode + // function. 
+ paddle::platform::RecordEvent grad_node_record_event( + "Global_" + std::string((*node).name()), + paddle::platform::TracerEventType::Operator, + 1); + // Run Pre Backward Node and get outputs paddle::small_vector, kSlotSmallVectorSize> grad_output_tensors = (*node)( node_input_buffer->Buffers(), create_graph, is_general_grad); if (!inputs.empty() && is_general_grad) { - GeneralGrad::Instance().SetResultForEnddingNodes(grad_output_tensors, - node); + GeneralGrad::Instance().SetResultForEndingNodes(grad_output_tensors, + node); } // retain_grad or not @@ -382,8 +393,7 @@ std::vector RunBackward( "Node's in-degree cannot be negative.", next_node->name())); - auto add_next_node_func = [&node_in_degree_map, - &queue](GradNodeBase* next_node) { + auto add_next_node_func = [&queue](GradNodeBase* next_node) { if (dynamic_cast(next_node)) { queue.push_front(next_node); } else { diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index 9b6318c7a43ed..e252868ebcaff 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -209,8 +209,8 @@ RunCustomOpNode::operator()(paddle::small_vector, ->meta() .is_contiguous()) { tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous(*( - std::dynamic_pointer_cast(tensor.impl())))))); + paddle::experimental::Trans2Contiguous(*( + std::dynamic_pointer_cast(tensor.impl()))))); } } @@ -436,7 +436,7 @@ RunCustomOpDoubleGradNode::operator()( << " to tmp_outputs: " << grad_output_idx; for (size_t j = 0; j < OutputMeta()[grad_output_idx].size(); j++) { outs[grad_output_idx] - .emplace_back(/* init it incase of copy nullptr of shared_ptr */ + .emplace_back(/* init it in case of copy nullptr of shared_ptr */ std::make_shared( phi::DataType::UNDEFINED), egr::Controller::Instance().GenerateUniqueName( diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index b843e081c29be..d3debf77df14f 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -163,12 +163,11 @@ static std::vector> RunInferShapeFunc( for (size_t i = 0; i < ctx.InputRange().size(); ++i) { const auto& input_pair = ctx.InputRangeAt(i); if (input_pair.first == input_pair.second - 1) { - input_shapes.emplace_back( - std::move(ctx.InputAt(input_pair.first).shape())); + input_shapes.emplace_back(ctx.InputAt(input_pair.first).shape()); } else { std::vector> shapes; for (size_t j = input_pair.first; j < input_pair.second; j++) { - shapes.push_back(std::move(ctx.InputAt(j).shape())); + shapes.push_back(ctx.InputAt(j).shape()); } vec_input_shapes.emplace_back(std::move(shapes)); } @@ -558,7 +557,7 @@ std::vector> RunInferShapeFn( out_dims = RunInferShapeFunc(ctx, infer_shape_func, inputs, outputs, inplace_map); } else { - if (is_forward) { + if (is_forward) { // NOLINT out_dims = RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); } else { out_dims = @@ -592,7 +591,7 @@ std::vector> RunInferDtypeFn( out_dtypes = RunInferDtypeFunc(ctx, infer_dtype_func, inputs, outputs, inplace_map); } else { - if (is_forward) { + if (is_forward) { // NOLINT out_dtypes = RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); } else { out_dtypes = @@ -800,8 +799,8 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, ->meta() .is_contiguous()) { 
tensor.set_impl(std::make_shared( - std::move(paddle::experimental::Trans2Contiguous( - *(std::dynamic_pointer_cast(tensor.impl())))))); + paddle::experimental::Trans2Contiguous( + *(std::dynamic_pointer_cast(tensor.impl()))))); } } diff --git a/paddle/fluid/eager/general_grad.h b/paddle/fluid/eager/general_grad.h index 443455619cae6..5ced385700f4f 100644 --- a/paddle/fluid/eager/general_grad.h +++ b/paddle/fluid/eager/general_grad.h @@ -124,15 +124,15 @@ class GeneralGrad { } visited.insert(target_node); if (!(depending_nodes_)[target_node].empty()) { - auto precedding_nodes = (depending_nodes_)[target_node]; - for (auto pre_nodes : precedding_nodes) { + auto preceding_nodes = (depending_nodes_)[target_node]; + for (auto pre_nodes : preceding_nodes) { queue.push_back(pre_nodes); needed_nodes_.emplace(pre_nodes); if (IsInputTargetNodes(pre_nodes)) { input_target_nodes_on_path.emplace(pre_nodes); } } - } else { // startup_ops have no precedding nodes + } else { // startup_ops have no preceding nodes VLOG(6) << "Emplace startup_ops"; startup_ops.emplace(target_node); needed_nodes_.emplace(target_node); @@ -143,7 +143,7 @@ class GeneralGrad { input_target_nodes_inputmeta_map_) { if (!input_target_nodes_on_path.count( target_nodes_inputmeta_pair.first)) { - endding_nodes_.emplace(target_nodes_inputmeta_pair.first); + ending_nodes_.emplace(target_nodes_inputmeta_pair.first); } } @@ -236,12 +236,12 @@ class GeneralGrad { } // TODO(jiabin): Some check here. } - void SetResultForEnddingNodes( + void SetResultForEndingNodes( paddle::small_vector, kSlotSmallVectorSize> grad_output, GradNodeBase* node) { - if (IsEnddingNodes(node)) { - VLOG(6) << "Set result for endding_nodes_ with grad_output_tensors"; + if (IsEndingNodes(node)) { + VLOG(6) << "Set result for ending_nodes_ with grad_output_tensors"; results_map_[node] = std::make_shared(grad_output[0][0]); } } @@ -270,14 +270,14 @@ class GeneralGrad { target_node->RegisterGradientHook( rank_info.first, rank_info.second, - std::move(std::make_shared(hook))); + std::make_shared(hook)); return tmp; } // Register Hook to fetch input's gradients, when input's grad node is not an - // endding node in backward graph. If input's grad node is an endding node in + // ending node in backward graph. If input's grad node is an ending node in // backward graph, use grad node's output as inputs' gradients and no need to - // register Hook. Please note that endding node must be GradNodeAccumulation + // register Hook. Please note that ending node must be GradNodeAccumulation // after ModifyBackwardGraph function. void RegisterFetchGradHook(const std::vector& inputs) { VLOG(6) << "Running in RegisterFetchGradHook."; @@ -296,8 +296,8 @@ class GeneralGrad { if (orig_to_copied_node_map_.count(target_node)) { target_node = orig_to_copied_node_map_[target_node].get(); - if (copied_node_to_endding_node_map_.count(target_node)) { - VLOG(6) << "No need to call FetchGradForTensor for endding_nodes"; + if (copied_node_to_ending_node_map_.count(target_node)) { + VLOG(6) << "No need to call FetchGradForTensor for ending_nodes"; continue; } } @@ -309,7 +309,7 @@ class GeneralGrad { "stop_gradient=True.", i)); - if (!IsEnddingNodes(target_node)) { + if (!IsEndingNodes(target_node)) { // Fetch grad for tensor in target_node on path. 
auto fetched_grad = FetchGradForTensor(inputs[i], target_node); results_map_[target_node] = fetched_grad; @@ -321,9 +321,9 @@ class GeneralGrad { void SetNodeToAccumulationNode(GradNodeBase* node) { if (dynamic_cast(node)) return; if (!(depending_nodes_)[node].empty()) { - // Find precedding_nodes of current node. - auto precedding_nodes = (depending_nodes_)[node]; - for (auto pre_nodes : precedding_nodes) { + // Find preceding_nodes of current node. + auto preceding_nodes = (depending_nodes_)[node]; + for (auto pre_nodes : preceding_nodes) { paddle::small_vector, kSlotSmallVectorSize>& pre_nodes_edges = pre_nodes->MutableOutputMeta(); for (size_t i = 0; i < pre_nodes_edges.size(); i++) { @@ -332,21 +332,21 @@ class GeneralGrad { if (edge_.GetGradNode() == node) { Edge& pre_node_edge = pre_nodes_edges[i][j].GetMutableEdge(); - if (copied_node_to_endding_node_map_.count(node)) { + if (copied_node_to_ending_node_map_.count(node)) { pre_node_edge.SetGradNode( - copied_node_to_endding_node_map_[node]); + copied_node_to_ending_node_map_[node]); } else { auto autograd_meta = egr::AutogradMeta(edge_); std::shared_ptr shared_grad_node_accumulation = std::make_shared(&autograd_meta); pre_node_edge.SetGradNode(shared_grad_node_accumulation); - copied_node_to_endding_node_map_[node] = + copied_node_to_ending_node_map_[node] = shared_grad_node_accumulation; } auto* grad_node = pre_node_edge.GetGradNode(); needed_nodes_.emplace(grad_node); - endding_nodes_.emplace(grad_node); + ending_nodes_.emplace(grad_node); input_target_nodes_inputmeta_map_[grad_node] = input_target_nodes_inputmeta_map_[node]; @@ -384,7 +384,7 @@ class GeneralGrad { } visited.insert(node); - if (IsInputTargetNodes(node) && IsEnddingNodes(node)) { + if (IsInputTargetNodes(node) && IsEndingNodes(node)) { SetNodeToAccumulationNode(node); continue; } @@ -413,7 +413,7 @@ class GeneralGrad { } if (meta.size() != 1 && IsNeededNodes(node) && - !IsNeededNodes(next_node.get()) && !IsEnddingNodes(node)) { + !IsNeededNodes(next_node.get()) && !IsEndingNodes(node)) { VLOG(3) << "Get stop edge from grad_node: " << node->name() << " : " << node << " to:" << next_node->name() << ", " << next_node.get() << " with output rank info: " << i @@ -448,8 +448,8 @@ class GeneralGrad { auto* target_node = auto_grad_meta->GetMutableGradNode().get(); if (orig_to_copied_node_map_.count(target_node)) { target_node = orig_to_copied_node_map_[target_node].get(); - if (copied_node_to_endding_node_map_.count(target_node)) { - target_node = copied_node_to_endding_node_map_[target_node].get(); + if (copied_node_to_ending_node_map_.count(target_node)) { + target_node = copied_node_to_ending_node_map_[target_node].get(); } } else { VLOG(6) << "Unable to find target node in " @@ -480,7 +480,7 @@ class GeneralGrad { bool IsNeededNodes(GradNodeBase* node) { return needed_nodes_.count(node); } - bool IsEnddingNodes(GradNodeBase* node) { return endding_nodes_.count(node); } + bool IsEndingNodes(GradNodeBase* node) { return ending_nodes_.count(node); } bool IsInputTargetNodes(GradNodeBase* node) { auto iter = input_target_nodes_inputmeta_map_.find(node); @@ -621,9 +621,9 @@ class GeneralGrad { results_map_.clear(); copied_grad_nodes_.clear(); orig_to_copied_node_map_.clear(); - copied_node_to_endding_node_map_.clear(); + copied_node_to_ending_node_map_.clear(); needed_nodes_.clear(); - endding_nodes_.clear(); + ending_nodes_.clear(); } private: @@ -649,8 +649,8 @@ class GeneralGrad { std::unordered_set needed_nodes_; // Record which grad_node has been transformed to 
AccumulationNode std::unordered_map> - copied_node_to_endding_node_map_; - std::unordered_set endding_nodes_; + copied_node_to_ending_node_map_; + std::unordered_set ending_nodes_; DISABLE_COPY_AND_ASSIGN(GeneralGrad); }; diff --git a/paddle/fluid/eager/grad_node_info.cc b/paddle/fluid/eager/grad_node_info.cc index 2a97f5bf35e90..ce7f7caf1f44c 100644 --- a/paddle/fluid/eager/grad_node_info.cc +++ b/paddle/fluid/eager/grad_node_info.cc @@ -261,6 +261,106 @@ void GradNodeBase::SetGradInMeta(const std::vector& fwd_out, } } +void GradNodeBase::SetGradInMeta(const std::vector& fwd_out, + size_t slot_rank) { + VLOG(7) << "Set GradSlotMeta for Grad Inputs"; + size_t slot_size = fwd_out.size(); + PADDLE_ENFORCE_LE( + slot_rank, + (bwd_in_meta_.size() - 1), + paddle::platform::errors::InvalidArgument( + "Slot Rank should less equal than bwd_in_meta_ size, since " + "bwd_in_meta_ is designed to hold as same num as backward " + "inputs.")); + auto& metas = bwd_in_meta_.at(slot_rank); + // Init stop gradient vector before use to avoid push back + if (metas.size() < slot_size) { + VLOG(7) << "Init bwd_in_meta_ with slot rank: " << slot_rank; + metas.resize(slot_size); + } + for (size_t i = 0; i < slot_size; i++) { + auto& meta = metas[i]; + const auto& fwd_out_tensor = *fwd_out[i]; + auto* fwd_out_meta = + egr::EagerUtils::nullable_autograd_meta(fwd_out_tensor); + PADDLE_ENFORCE_NOT_NULL(fwd_out_meta, + paddle::platform::errors::PreconditionNotMet( + "Bwd_in_meta should only be called while " + "autograd_meta is not null. If you got this " + "error, it indicates bugs in framework.")); + if (fwd_out_meta && fwd_out_meta->StopGradient()) { + // Set Stop Gradient only when its true or non-initialized autograd_meta, + // since all default value is false. + meta.SetStopGradient(fwd_out_meta->StopGradient()); + } + + if (!fwd_out_tensor.initialized()) { + if (fwd_out_tensor.defined() && fwd_out_tensor.is_dist_tensor() && + phi::distributed::NeedComputationClipForPP(fwd_out_tensor.impl())) { + VLOG(3) << "Tensor " << fwd_out_tensor.name() << " is DistTensor," + << " and needs computation clip for pipeline parallel." 
+ << " Still SetGradInMeta for it."; + } else { + VLOG(7) << "Skip Configuring GradSlotMeta for uninitialized GradInput " + "Tensor"; + return; + } + } + + // Record TensorMeta + if (phi::DenseTensor::classof(fwd_out_tensor.impl().get())) { + // Only Copy Meta + phi::DenseTensor* dense_tensor = + static_cast(fwd_out_tensor.impl().get()); + + PADDLE_ENFORCE_NE( + dense_tensor->meta().dtype, + phi::DataType::UNDEFINED, + paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta " + "with phi::DataType::UNDEFINED," + "which is illegal.")); + meta.SetTensorMeta(dense_tensor->meta()); + meta.SetPlace(fwd_out_tensor.place()); + + if (dense_tensor->type() == phi::DataType::COMPLEX64 || + dense_tensor->type() == phi::DataType::COMPLEX128) { + need_complex_to_real_ = true; + } + } else if (phi::distributed::DistTensor::classof( + fwd_out_tensor.impl().get())) { + // Only Copy Meta + meta.SetDistAttr(static_cast( + fwd_out_tensor.impl().get()) + ->dist_attr()); + meta.SetDistTensorGlobalDims(static_cast( + fwd_out_tensor.impl().get()) + ->dims()); + SetIsRunAutoParallel(true); + + auto dense_tensor = static_cast( + fwd_out_tensor.impl().get()) + ->value(); + + PADDLE_ENFORCE_NE( + dense_tensor.meta().dtype, + phi::DataType::UNDEFINED, + paddle::platform::errors::Fatal("Attempting to copy DenseTensorMeta " + "with phi::DataType::UNDEFINED," + "which is illegal.")); + meta.SetTensorMeta(dense_tensor.meta()); + meta.SetPlace(fwd_out_tensor.place()); + + if (dense_tensor.type() == phi::DataType::COMPLEX64 || + dense_tensor.type() == phi::DataType::COMPLEX128) { + need_complex_to_real_ = true; + } + } else { + VLOG(7) << "Unable to initialize the DenseTensorMeta of GradSlotMeta " + "with non-DenseTensor argument."; + } + } +} + void GradNodeBase::SetGradOutMeta(const paddle::Tensor& fwd_in, size_t slot_rank) { auto* fwd_in_meta = egr::EagerUtils::nullable_autograd_meta(fwd_in); diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 7b5e36f4d5cdc..73eedaba9e4f3 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -250,7 +250,8 @@ class GradNodeBase { void SetGradInMeta(const std::vector& fwd_out, size_t slot_rank); void SetGradInMeta(const paddle::Tensor& fwd_out, size_t slot_rank); - + void SetGradInMeta(const std::vector& fwd_out, + size_t slot_rank); void SetGradOutMeta(const std::vector& fwd_in, size_t slot_rank); void SetGradOutMeta(const std::vector& fwd_in, diff --git a/paddle/fluid/eager/grad_tensor_holder.cc b/paddle/fluid/eager/grad_tensor_holder.cc index dac55f8f5462f..47f41b5a4f93b 100644 --- a/paddle/fluid/eager/grad_tensor_holder.cc +++ b/paddle/fluid/eager/grad_tensor_holder.cc @@ -79,7 +79,7 @@ void GradTensorHolder::CopyValueFromTensor(size_t slot_id, // Create new tensor->impl and fill it with 1.0 if (t.defined()) { // Fill 1.0, use full to support complex, one_like don't support it. 
- if (t.is_dense_tensor()) { + if (t.is_dense_tensor()) { // NOLINT buffer_[slot_id][rank] = paddle::experimental::full(t.shape(), 1, t.dtype(), t.place()); } else if (t.is_sparse_csr_tensor() || t.is_sparse_coo_tensor()) { diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h index f6b8e21cd8b17..cdb4de66ae189 100644 --- a/paddle/fluid/eager/to_static/run_program_op_func.h +++ b/paddle/fluid/eager/to_static/run_program_op_func.h @@ -20,9 +20,12 @@ #include "paddle/fluid/eager/eager_tensor.h" #include "paddle/fluid/eager/to_static/run_program_op_node.h" #include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/framework/tensor_ref_array.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/pir/include/core/block.h" +#include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/value.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h" // Filter params without grads in global block. In this case, we will // tag its AutogradMeta with stop_gradient = True to avoid fault from @@ -119,9 +122,10 @@ static std::vector Trans2ContiguousTensors( .is_contiguous()) { res.emplace_back( std::make_shared( - std::move(paddle::experimental::Trans2Contiguous( - *(std::dynamic_pointer_cast(t.impl()))))), - t.mutable_autograd_meta()); + paddle::experimental::Trans2Contiguous( + *(std::dynamic_pointer_cast(t.impl())))), + t.mutable_autograd_meta(), + t.name()); } else { res.emplace_back(t); } @@ -244,8 +248,9 @@ inline void pir_run_program_ad_func( trace_backward, &p_autograd_x, &p_autograd_params); // Create Middle Output for GradNode. - auto middle_size = - PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm")).size(); + auto middle_values = + PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fm")); + auto middle_size = middle_values.size(); auto output_size = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("fo")).size(); auto middles = std::vector(); @@ -264,8 +269,14 @@ inline void pir_run_program_ad_func( grad_node->GetMiddle().resize(middle_size); grad_node->GetOutputs().resize(output_size); for (size_t i = 0; i < middle_size; ++i) { - grad_node->GetMiddle()[i] = - paddle::Tensor(std::make_shared()); + auto middle_value = middle_values[i]; + if (middle_value.type().isa()) { + grad_node->GetMiddle()[i] = + paddle::Tensor(std::make_shared()); + } else if (middle_value.type().isa()) { + grad_node->GetMiddle()[i] = paddle::Tensor( + std::make_shared()); + } middles.push_back(&grad_node->GetMiddle()[i]); } diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h index fdebfbb1e3771..af91fe9e0c08e 100644 --- a/paddle/fluid/eager/to_static/run_program_op_node.h +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -19,6 +19,7 @@ #include "paddle/fluid/eager/tensor_wrapper.h" #include "paddle/fluid/framework/executor_cache.h" #include "paddle/fluid/framework/new_executor/interpretercore.h" +#include "paddle/fluid/framework/tensor_ref_array.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/fluid/operators/run_program_op.h" @@ -84,14 +85,78 @@ static std::vector GetTensorsName( return in_names; } +static bool IsVariableRefArray(const Tensor &tensor) { + return paddle::framework::VariableRefArray::classof(tensor.impl().get()); +} + +static auto GetNameFromValue(const ::pir::Block *block, + const std::vector<::pir::Value> 
&values, + bool allow_input, + bool allow_output) { + PADDLE_ENFORCE_EQ( + allow_input || allow_output, + true, + paddle::platform::errors::InvalidArgument( + "GetNameFromValue should allow input or output at least one.")); + // we use name here, later value is used directly. + std::unordered_map<::pir::Value, std::string> value2name; + if (allow_input) { + for (auto &kwarg : block->kwargs()) { + value2name[kwarg.second] = kwarg.first; + } + } + for (auto &op : *block) { + std::string name; + if (allow_input && op.name() == "pd_op.data") { + name = + op.attributes().at("name").dyn_cast().AsString(); + value2name[op.results()[0].Value::impl()] = name; + } else if (allow_output && op.name() == "builtin.set_parameter") { + name = op.attributes() + .at("parameter_name") + .dyn_cast() + .AsString(); + value2name[op.operand(0).source()] = name; + } else if (allow_output && op.name() == "builtin.shadow_output") { + name = op.attributes() + .at("output_name") + .dyn_cast() + .AsString(); + value2name[op.operand(0).source()] = name; + } else if (allow_input && op.name() == "builtin.parameter") { + name = op.attributes() + .at("parameter_name") + .dyn_cast() + .AsString(); + value2name[op.result(0).Value::impl()] = name; + } else if (allow_input && op.name() == "builtin.constant") { + if (op.isa()) { + name = op.dyn_cast().tensor_name(); + value2name[op.result(0).Value::impl()] = name; + } + } + } + std::vector names; + std::transform(values.begin(), + values.end(), + std::back_inserter(names), + [&value2name](const ::pir::Value &v) { + if (!value2name.count(v)) + return std::string(paddle::framework::kFakeVarName); + return value2name.at(v); + }); + return names; +} + static void CheckInputVarStatus(const Tensor &tensor) { - PADDLE_ENFORCE_EQ(tensor.defined() && tensor.is_dense_tensor(), - true, - paddle::platform::errors::InvalidArgument( - "The input tensor %s of " - "RunProgram(Grad)Op holds " - "wrong type. Expect type is DenseTensor.", - tensor.name())); + PADDLE_ENFORCE_EQ( + tensor.defined() && + (tensor.is_dense_tensor() || IsVariableRefArray(tensor)), + true, + paddle::platform::errors::InvalidArgument( + "The input tensor %s of RunProgram(Grad)Op holds " + "wrong type. Expect type is DenseTensor or VariableRefArray.", + tensor.name())); } static void CheckOutputVarStatus(const paddle::framework::Variable &src_var, @@ -120,46 +185,32 @@ static void CheckOutputVarStatus(const paddle::framework::Variable &src_var, "RunProgram(Grad)Op's internal scope holds " "wrong type. Expect type is SelectedRows", name)); + } else if (IsVariableRefArray(dst_tensor)) { + auto &src_tensor = src_var.Get(); + PADDLE_ENFORCE_EQ(paddle::framework::VariableRefArray::classof(&src_tensor), + true, + paddle::platform::errors::InvalidArgument( + "The output tensor %s get from " + "RunProgram(Grad)Op's internal scope holds " + "wrong type. 
Expect type is VariableRefArray", + name)); } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "The RunProgram(Grad)Op only support output " - "variable of type LoDTensor or SelectedRows", + "variable of type DenseTensor, SelectedRows or VariableRefArray", name)); } } -static void ShareTensorsIntoScope(const std::vector &tensors, - paddle::framework::Scope *scope) { - for (size_t i = 0; i < tensors.size(); ++i) { - VLOG(4) << "Share Tensor Into Scope: " << i; - auto name = tensors[i].name(); - if (name == paddle::framework::kFakeVarName || - name == paddle::framework::kEmptyVarName) { - continue; - } - auto *var = scope->Var(name); - CheckInputVarStatus(tensors[i]); - // share tensor - auto tensor_base = tensors[i].impl(); - if (phi::DenseTensor::classof(tensor_base.get())) { - auto *dst_tensor = var->GetMutable(); - auto t = std::dynamic_pointer_cast(tensor_base); - *dst_tensor = *t; - } else if (phi::SelectedRows::classof(tensor_base.get())) { - auto *dst_tensor = var->GetMutable(); - auto t = std::dynamic_pointer_cast(tensor_base); - *dst_tensor = *t; - } - } -} - static void ShareTensorsIntoScopeWithName( const std::vector &tensors, const std::vector &tensor_names, paddle::framework::Scope *scope) { for (size_t i = 0; i < tensors.size(); ++i) { auto name = tensor_names[i]; - if (name == paddle::framework::kFakeVarName) { + VLOG(4) << "Share Tensor Into Scope: " << name; + if (name == paddle::framework::kFakeVarName || + name == paddle::framework::kEmptyVarName) { continue; } auto *var = scope->Var(name); @@ -174,102 +225,28 @@ static void ShareTensorsIntoScopeWithName( auto *dst_tensor = var->GetMutable(); auto t = std::dynamic_pointer_cast(tensor_base); *dst_tensor = *t; + } else if (paddle::framework::VariableRefArray::classof( + tensor_base.get())) { + auto *dst_tensor = var->GetMutable(); + auto t = std::dynamic_pointer_cast( + tensor_base); + *dst_tensor = *t; } } } -static auto GetNameFromValue(const ::pir::Block *block, - const std::vector<::pir::Value> &values, - bool is_input) { - // we use name here, later value is used directly. 
- std::unordered_map<::pir::Value, std::string> value2name; - if (is_input) { - for (auto &kwarg : block->kwargs()) { - value2name[kwarg.second] = kwarg.first; - } - } - for (auto &op : *block) { - std::string name; - if (is_input && op.name() == "pd_op.data") { - name = - op.attributes().at("name").dyn_cast().AsString(); - value2name[op.results()[0].Value::impl()] = name; - } else if (!is_input && op.name() == "builtin.set_parameter") { - name = op.attributes() - .at("parameter_name") - .dyn_cast() - .AsString(); - value2name[op.operand(0).source()] = name; - } else if (!is_input && op.name() == "builtin.shadow_output") { - name = op.attributes() - .at("output_name") - .dyn_cast() - .AsString(); - value2name[op.operand(0).source()] = name; - } else if (is_input && op.name() == "builtin.parameter") { - name = op.attributes() - .at("parameter_name") - .dyn_cast() - .AsString(); - value2name[op.result(0).Value::impl()] = name; - } else if (is_input && op.name() == "builtin.constant") { - if (op.isa()) { - name = op.dyn_cast().tensor_name(); - value2name[op.result(0).Value::impl()] = name; - } - } - } - std::vector names; - std::transform(values.begin(), - values.end(), - std::back_inserter(names), - [&value2name](const ::pir::Value &v) { - if (!value2name.count(v)) - return std::string(paddle::framework::kFakeVarName); - return value2name.at(v); - }); - return names; -} +static void ShareTensorsIntoScope(const std::vector &tensors, + paddle::framework::Scope *scope) { + const std::vector names = + [&](const std::vector &tensors) { + std::vector names; + for (auto &t : tensors) { + names.push_back(t.name()); + } + return names; + }(tensors); -static void ShareTensorsFromScope( - const std::vector &tensors, - const paddle::framework::BlockDesc &global_block, - paddle::framework::Scope *scope) { - for (size_t i = 0; i < tensors.size(); ++i) { - // NOTE: In case of setting out_tmp.stop_gradient = True in model code, all - // parameters before generating out_tmp have no @GRAD, it will raise error - // because we can't find them in scope. So we skip sharing these vars or - // var@GRAD if they don't appear in global block. - auto &name = tensors[i]->name(); - if (name == paddle::framework::kEmptyVarName || - name == paddle::framework::kFakeVarName || !global_block.HasVar(name)) { - VLOG(2) << "find tensor name is " << name << ", skip it!"; - continue; - } - // NOTE: Here skip not found var is dangerous, if a bug is caused here, - // the result is grad calculation error, which will be very hidden! 
- auto *var = scope->FindVar(name); - PADDLE_ENFORCE_NOT_NULL( - var, - paddle::platform::errors::NotFound("The output tensor %s is not in " - "RunProgram(Grad)Op'" - "s internal scope.", - name)); - CheckOutputVarStatus(*var, *tensors[i]); - // share tensor - if (var->IsType()) { - auto &src_tensor = var->Get(); - auto *dst_tensor = const_cast( - dynamic_cast(tensors[i]->impl().get())); - VLOG(4) << "share " << name << " from scope"; - *dst_tensor = src_tensor; - } else if (var->IsType()) { - auto &src_tensor = var->Get(); - auto *dst_tensor = const_cast( - dynamic_cast(tensors[i]->impl().get())); - *dst_tensor = src_tensor; - } - } + ShareTensorsIntoScopeWithName(tensors, names, scope); } static void ShareTensorsIntoScopeByValue( @@ -277,12 +254,7 @@ static void ShareTensorsIntoScopeByValue( const std::vector &tensors, const std::vector<::pir::Value> &values, paddle::framework::Scope *scope) { - auto names = GetNameFromValue(block, values, true); - if (VLOG_IS_ON(4)) { - for (auto &s : names) { - VLOG(4) << "ShareTensorIntoScopeByValue name: " << s; - } - } + auto names = GetNameFromValue(block, values, true, false); ShareTensorsIntoScopeWithName(tensors, names, scope); } @@ -291,11 +263,16 @@ static void ShareTensorsFromScopeByValue( const std::vector &tensors, const std::vector<::pir::Value> &values, paddle::framework::Scope *scope) { - auto names = GetNameFromValue(block, values, false); + // NOTE(SigureMo): If the program has an inplace chain connecting + // an input value to an output value, the output value will be + // replaced with the input value, so we set the `allow_input` to + // `true` in `GetNameFromValue` + auto names = GetNameFromValue(block, values, true, true); for (size_t i = 0; i < tensors.size(); ++i) { auto &name = names[i]; auto &value = values[i]; - VLOG(2) << "share " << name << " from scope"; + VLOG(4) << "Share Tensor From Scope: " << name; + if (value.impl() == nullptr) { // skip stop_gradient. 
continue; @@ -320,6 +297,17 @@ static void ShareTensorsFromScopeByValue( auto *dst_tensor = const_cast( dynamic_cast(tensors[i]->impl().get())); *dst_tensor = src_tensor; + } else if (var->IsType()) { + auto &src_tensor = var->Get(); + auto *dst_tensor = const_cast( + dynamic_cast( + tensors[i]->impl().get())); + *dst_tensor = src_tensor; + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The RunProgram(Grad)Op only support output " + "variable of type DenseTensor, SelectedRows or VariableRefArray", + name)); } } } @@ -350,6 +338,17 @@ static void ShareTensorsFromScopeWithPartialBlock( auto *dst_tensor = const_cast( dynamic_cast(tensors[i]->impl().get())); *dst_tensor = src_tensor; + } else if (var->IsType()) { + auto &src_tensor = var->Get(); + auto *dst_tensor = const_cast( + dynamic_cast( + tensors[i]->impl().get())); + *dst_tensor = src_tensor; + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The RunProgram(Grad)Op only support output " + "variable of type DenseTensor, SelectedRows or VariableRefArray", + name)); } } } @@ -489,15 +488,14 @@ inline void PirRunProgramAPI( VLOG(10) << is_test << program_id; - auto &interpretercore_info_cache = - paddle::framework::InterpreterCoreInfoCache::Instance(); + auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/false, - /*in_pir_mode=*/true)) { + if (!cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/true)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -532,20 +530,20 @@ inline void PirRunProgramAPI( // *backward_program); // update interpretercore skip_gc_var - auto skip_names = - details::GetNameFromValue(forward_global_block, middle_values, false); + auto skip_names = details::GetNameFromValue( + forward_global_block, middle_values, false, true); auto skip_names_set = std::set(skip_names.begin(), skip_names.end()); auto no_need_buffer_values = PADDLE_GET_CONST(std::vector<::pir::Value>, attrs.at("no_need_buffers")); auto no_need_buffer_names = details::GetNameFromValue( - forward_global_block, no_need_buffer_values, false); + forward_global_block, no_need_buffer_values, false, true); for (auto &name : no_need_buffer_names) { VLOG(4) << "Find no need buffer vars with name:" << name; skip_names_set.erase(name); } - skip_names = - details::GetNameFromValue(forward_global_block, output_values, false); + skip_names = details::GetNameFromValue( + forward_global_block, output_values, false, true); skip_names_set.insert(skip_names.begin(), skip_names.end()); details::print_collection(skip_names_set); interpreter_core->SetSkipGcVars(skip_names_set); @@ -554,7 +552,7 @@ inline void PirRunProgramAPI( // input_vars.insert(input_names.begin(), input_names.end()); // interpreter_core->SetJitInputVars(input_vars); - // interpretercore_info_cache.UpdateSkipEagerDeleteVars( + // cache.UpdateSkipEagerDeleteVars( // program_id, global_inner_scope, false, skip_eager_delete_vars); } else { paddle::platform::RecordEvent record_event( @@ -563,12 +561,11 @@ inline void PirRunProgramAPI( 1); VLOG(2) << "Get interpretercore cache by program:" << program_id; // Step 1. 
get cache interpretercore - auto &cached_value = - interpretercore_info_cache.GetMutable(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/false, - /*in_pir_mode=*/true); + auto &cached_value = cache.GetMutable(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/true); interpreter_core = cached_value.core_; // Step 2. update scope for cache interpretercore details::ShareTensorsIntoScopeByValue( @@ -702,15 +699,14 @@ inline void RunProgramAPI( backward_program = backward_global_block->Program(); } - auto &interpretercore_info_cache = - paddle::framework::InterpreterCoreInfoCache::Instance(); + auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/false, - /*in_pir_mode=*/in_pir_pt_mode)) { + if (!cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/in_pir_pt_mode)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -776,13 +772,12 @@ inline void RunProgramAPI( VLOG(6) << s.str(); } - interpretercore_info_cache.UpdateSkipEagerDeleteVars( - program_id, - global_inner_scope, - place_hash_key, - false, - in_pir_pt_mode, - skip_eager_delete_vars); + cache.UpdateSkipEagerDeleteVars(program_id, + global_inner_scope, + place_hash_key, + false, + in_pir_pt_mode, + skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); } else { paddle::platform::RecordEvent record_event( @@ -791,12 +786,11 @@ inline void RunProgramAPI( 1); VLOG(2) << "Get interpretercore cache by program:" << program_id; // Step 1. get cache interpretercore - auto &cached_value = - interpretercore_info_cache.GetMutable(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/false, - /*in_pir_mode=*/in_pir_pt_mode); + auto &cached_value = cache.GetMutable(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/in_pir_pt_mode); interpreter_core = cached_value.core_; // Step 2. update scope for cache interpretercore details::ShareTensorsIntoScopeWithName(x, input_names, global_inner_scope); @@ -881,15 +875,14 @@ inline void RunProgramGradAPI( details::Trans2ContiguousTensorsInplace(out_grad); auto out_grad_names = details::GetTensorsName(out_grad); - auto &interpretercore_info_cache = - paddle::framework::InterpreterCoreInfoCache::Instance(); + auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/true, - /*in_pir_mode=*/in_pir_pt_mode)) { + if (!cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + /*in_pir_mode=*/in_pir_pt_mode)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -929,13 +922,13 @@ inline void RunProgramGradAPI( // share threadpool // NOTE(zhiqiu): this only works interpreter_core is executed strictly // after the related fwd_interpreter_core. 
- if (interpretercore_info_cache.Has(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/false, - /*in_pir_mode=*/in_pir_pt_mode)) { + if (cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/in_pir_pt_mode)) { auto fwd_interpreter_core = - interpretercore_info_cache + cache .GetMutable(program_id, global_inner_scope, place_hash_key, @@ -963,13 +956,12 @@ inline void RunProgramGradAPI( paddle::framework::details::AppendSkipDeletionVars(param_grad_names, &skip_eager_delete_vars); interpreter_core->SetSkipGcVars(skip_eager_delete_vars); - interpretercore_info_cache.UpdateSkipEagerDeleteVars( - program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/true, - in_pir_pt_mode, - skip_eager_delete_vars); + cache.UpdateSkipEagerDeleteVars(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + in_pir_pt_mode, + skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); } else { paddle::platform::RecordEvent record_event( @@ -977,12 +969,11 @@ inline void RunProgramGradAPI( paddle::platform::TracerEventType::UserDefined, 1); VLOG(2) << "Get interpretercore cache by program:" << program_id; - auto &cached_value = - interpretercore_info_cache.GetMutable(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/true, - /*in_pir_mode=*/in_pir_pt_mode); + auto &cached_value = cache.GetMutable(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + /*in_pir_mode=*/in_pir_pt_mode); interpreter_core = cached_value.core_; // update scope @@ -1027,8 +1018,8 @@ inline void PirRunProgramGradAPI( const std::vector &x, const std::vector ¶ms, const std::vector &out_grad, - const std::vector &middles, - const std::vector &out, + std::vector &middles, // NOLINT + std::vector &out, // NOLINT const std::vector &step_scope, // NOLINT const paddle::framework::AttributeMap &attrs, std::vector &x_grad, // NOLINT @@ -1087,15 +1078,18 @@ inline void PirRunProgramGradAPI( details::ShareTensorsIntoScopeByValue( backward_global_block, params, parameter_values, global_inner_scope); - auto &interpretercore_info_cache = - paddle::framework::InterpreterCoreInfoCache::Instance(); + // Clear out and middles to avoid hold memory until backward finish. + out.clear(); + middles.clear(); + + auto &cache = paddle::framework::InterpreterCoreInfoCache::Instance(); std::shared_ptr interpreter_core = nullptr; - if (!interpretercore_info_cache.Has(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/true, - /*in_pir_mode=*/true)) { + if (!cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + /*in_pir_mode=*/true)) { paddle::platform::RecordEvent record_event( "create_new_interpretercore", paddle::platform::TracerEventType::UserDefined, @@ -1120,12 +1114,12 @@ inline void PirRunProgramGradAPI( // share threadpool // NOTE(zhiqiu): this only works interpreter_core is executed strictly // after the related fwd_interpreter_core. 
- if (interpretercore_info_cache.Has(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/false, - /*in_pir_mode=*/true)) { - auto fwd_interpreter_core = interpretercore_info_cache + if (cache.Has(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/false, + /*in_pir_mode=*/true)) { + auto fwd_interpreter_core = cache .GetMutable(program_id, global_inner_scope, place_hash_key, @@ -1139,20 +1133,19 @@ inline void PirRunProgramGradAPI( // get all eager gc vars std::set skip_eager_delete_vars; - auto skip_names = - details::GetNameFromValue(backward_global_block, x_grad_values, false); + auto skip_names = details::GetNameFromValue( + backward_global_block, x_grad_values, false, true); skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end()); - skip_names = - details::GetNameFromValue(backward_global_block, p_grad_values, false); + skip_names = details::GetNameFromValue( + backward_global_block, p_grad_values, false, true); skip_eager_delete_vars.insert(skip_names.begin(), skip_names.end()); interpreter_core->SetSkipGcVars(skip_eager_delete_vars); - interpretercore_info_cache.UpdateSkipEagerDeleteVars( - program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/true, - /*in_pir_mode=*/true, - skip_eager_delete_vars); + cache.UpdateSkipEagerDeleteVars(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + /*in_pir_mode=*/true, + skip_eager_delete_vars); VLOG(2) << "Get skip GC vars size is: " << skip_eager_delete_vars.size(); details::print_collection(skip_eager_delete_vars); } else { @@ -1161,12 +1154,11 @@ inline void PirRunProgramGradAPI( paddle::platform::TracerEventType::UserDefined, 1); VLOG(2) << "Get interpretercore cache by program:" << program_id; - auto &cached_value = - interpretercore_info_cache.GetMutable(program_id, - global_inner_scope, - place_hash_key, - /*is_grad=*/true, - /*in_pir_mode=*/true); + auto &cached_value = cache.GetMutable(program_id, + global_inner_scope, + place_hash_key, + /*is_grad=*/true, + /*in_pir_mode=*/true); interpreter_core = cached_value.core_; if (interpreter_core->GetVariableScope()->GetMutableScope() != @@ -1519,12 +1511,19 @@ class PirGradNodeRunProgram : public egr::GradNodeBase { x_grad_values.size())); // TODO(dev): Need an elegant way to determine information of grad_tensor, - // such as: name, tensor type(DenseTensor or SelectedRows). + // such as: name, tensor type (DenseTensor, SelectedRows or + // VariableRefArray). for (size_t i = 0; i < x.size(); i++) { if (x[i].is_dense_tensor()) { x_grad->emplace_back(std::make_shared()); } else if (x[i].is_selected_rows()) { x_grad->emplace_back(std::make_shared()); + } else if (details::IsVariableRefArray(x[i])) { + x_grad->emplace_back( + std::make_shared()); + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The grad tensor type is not supported.")); } } } diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 15486bbb1580a..5f8a768cd65dd 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -34,9 +34,9 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -paddle::any GetAttrValue(const Attribute& attr); +TEST_API paddle::any GetAttrValue(const Attribute& attr); -Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); +TEST_API Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); Attribute GetAttrValue(const proto::VarDesc::Attr& attr_desc); @@ -350,9 +350,10 @@ class AttrReader { }; paddle::experimental::Scalar MakeScalarFromProto(const proto::Scalar& v); -proto::Scalar MakeScalarProto(const paddle::experimental::Scalar& v); -paddle::experimental::Scalar MakeScalarFromAttribute(const Attribute& v); -std::vector MakeScalarsFromAttribute( +TEST_API proto::Scalar MakeScalarProto(const paddle::experimental::Scalar& v); +TEST_API paddle::experimental::Scalar MakeScalarFromAttribute( + const Attribute& v); +TEST_API std::vector MakeScalarsFromAttribute( const Attribute& v); void CanonicalizeScalarAttrs(const proto::OpProto& op_proto, AttributeMap* attrs); diff --git a/paddle/fluid/framework/block_desc.cc b/paddle/fluid/framework/block_desc.cc index 7ba2ebc8fe027..d5533f5ea6e1d 100644 --- a/paddle/fluid/framework/block_desc.cc +++ b/paddle/fluid/framework/block_desc.cc @@ -320,7 +320,7 @@ void BlockDesc::MoveFrom(BlockDesc *block) { std::vector old_block_desc; // NOTE(GhostScreaming): don't use program->proto()->blocks_size(), // previous assignment of new Variable in vars_ use std::move, - // which makes 'var_ptr' which holded by 'block' a nullptr. + // which makes 'var_ptr' which held by 'block' a nullptr. // block->Program()->proto() will calls Flush() at first, // a null var_ptr will cause segmentation fault. int block_size = static_cast(program->Size()); diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 46416f17b3cd0..6c7d9bdb29e64 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -32,11 +32,11 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/api/all.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/utils/any.h" +#include "paddle/utils/string/string_helper.h" #ifdef PADDLE_WITH_CUSTOM_DEVICE #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/phi/backends/device_manager.h" @@ -147,7 +147,7 @@ static void RunKernelFunc( in_name)); VLOG(3) << "Custom Operator: KernelFunc's input " << in_name << " is optional dtype with None input"; - kernel_ctx.EmplaceBackInput(std::move(paddle::Tensor())); + kernel_ctx.EmplaceBackInput(paddle::Tensor()); } } } @@ -215,7 +215,7 @@ static void RunKernelFunc( VLOG(3) << "Custom Operator: InferDtype - inplace optional outputs : " << out_name << " is None."; true_out_ptrs.emplace_back(nullptr); - kernel_ctx.EmplaceBackOutput(std::move(paddle::Tensor())); + kernel_ctx.EmplaceBackOutput(paddle::Tensor()); continue; } // general/inplace vector outputs @@ -252,7 +252,7 @@ static void RunKernelFunc( VLOG(3) << "Custom Operator: InferDtype - inplace optional outputs : " << out_name << " is None."; true_out_ptrs.emplace_back(nullptr); - kernel_ctx.EmplaceBackOutput(std::move(paddle::Tensor())); + kernel_ctx.EmplaceBackOutput(paddle::Tensor()); continue; } // general/inplace Tensor outputs diff --git a/paddle/fluid/framework/custom_operator_utils.h b/paddle/fluid/framework/custom_operator_utils.h index 31b0793c8fb6a..994544357dc64 100644 --- a/paddle/fluid/framework/custom_operator_utils.h +++ b/paddle/fluid/framework/custom_operator_utils.h @@ -17,13 +17,16 @@ limitations under the License. */ #include #include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/core/enforce.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { constexpr char kCustomDialectPrefix[] = "custom_op."; // NOLINT +constexpr char kGradSuffix[] = "_grad"; // NOLINT +constexpr char kDoubleGradSuffix[] = "_grad_grad"; // NOLINT + namespace detail { // dynamic lib load func @@ -93,10 +96,10 @@ inline static const OpMetaInfo* GetGradOpInfoByFwdPirName( } pos = custom_name.length(); - if (custom_name.find("_grad_grad") != custom_name.npos) { - pos = custom_name.find("_grad_grad"); - } else if (custom_name.find("_grad") != custom_name.npos) { - pos = custom_name.find("_grad"); + if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) { + pos = custom_name.find(kDoubleGradSuffix); + } else if (custom_name.find(kGradSuffix) != custom_name.npos) { + pos = custom_name.find(kGradSuffix); } auto custom_name_prefix = custom_name.substr(0, pos); auto map_iter = @@ -106,10 +109,10 @@ inline static const OpMetaInfo* GetGradOpInfoByFwdPirName( } const auto& vec_op_meta = map_iter->second; const OpMetaInfo* ret = nullptr; - if (custom_name.find("_grad_grad") != custom_name.npos) { + if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) { PADDLE_THROW("Custom op : " + custom_name_prefix + " doesn't support triple grad."); - } else if (custom_name.find("_grad") != custom_name.npos) { + } else if (custom_name.find(kGradSuffix) != custom_name.npos) { bool has_double_grad = vec_op_meta.size() >= 3; ret = has_double_grad ? 
&(vec_op_meta[2]) : nullptr; } else { @@ -130,10 +133,10 @@ inline static const OpMetaInfo& GetOpInfoByPirName( } pos = custom_name.length(); - if (custom_name.find("_grad_grad") != custom_name.npos) { - pos = custom_name.find("_grad_grad"); - } else if (custom_name.find("_grad") != custom_name.npos) { - pos = custom_name.find("_grad"); + if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) { + pos = custom_name.find(kDoubleGradSuffix); + } else if (custom_name.find(kGradSuffix) != custom_name.npos) { + pos = custom_name.find(kGradSuffix); } auto custom_name_prefix = custom_name.substr(0, pos); auto map_iter = @@ -142,9 +145,9 @@ inline static const OpMetaInfo& GetOpInfoByPirName( PADDLE_THROW("The info of custom op : " + custom_name + " is not exists!"); } const auto& vec_op_meta = map_iter->second; - if (custom_name.find("_grad_grad") != custom_name.npos) { + if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) { return vec_op_meta[2]; - } else if (custom_name.find("_grad") != custom_name.npos) { + } else if (custom_name.find(kGradSuffix) != custom_name.npos) { return vec_op_meta[1]; } else { return vec_op_meta[0]; @@ -161,10 +164,10 @@ inline static bool HasGradOp(const std::string& fwd_pir_op_name) { } pos = custom_name.length(); - if (custom_name.find("_grad_grad") != custom_name.npos) { - pos = custom_name.find("_grad_grad"); - } else if (custom_name.find("_grad") != custom_name.npos) { - pos = custom_name.find("_grad"); + if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) { + pos = custom_name.find(kDoubleGradSuffix); + } else if (custom_name.find(kGradSuffix) != custom_name.npos) { + pos = custom_name.find(kGradSuffix); } auto custom_name_prefix = custom_name.substr(0, pos); auto map_iter = @@ -174,10 +177,10 @@ inline static bool HasGradOp(const std::string& fwd_pir_op_name) { " is not exists!"); } const auto& vec_op_meta = map_iter->second; - if (custom_name.find("_grad_grad") != custom_name.npos) { + if (custom_name.find(kDoubleGradSuffix) != custom_name.npos) { // custom op only support double grad, there will not have triple grad op return false; - } else if (custom_name.find("_grad") != custom_name.npos) { + } else if (custom_name.find(kGradSuffix) != custom_name.npos) { // vec_op_meta.size() == 3 means the op has double grad op return vec_op_meta.size() > 2UL; } else { @@ -247,7 +250,8 @@ static std::vector> RunDefaultInferShape( const std::vector>>& vec_input_shapes, const std::unordered_map& vec_input_name2id_map) { std::vector> output_shapes; - auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(custom_op_meta); + auto& inplace_reverse_map = + OpMetaInfoHelper::GetInplaceReverseMap(custom_op_meta); // Op is grad op if (custom_op_meta.IsGradOp() || custom_op_meta.IsDoubleGradOp()) { bool is_double_grad = custom_op_meta.IsDoubleGradOp(); @@ -278,6 +282,10 @@ static std::vector> RunDefaultInferShape( bwd_input_name) != bwd_inputs_name.end()) { int input_index = input_name2id_map.at(bwd_input_name); auto input_shape = input_shapes[input_index]; + if (input_shape.size() == 0) { + // if optional tensor is None, we don't need to infer shape + continue; + } output_shapes.push_back(input_shape); } else { PADDLE_ENFORCE_EQ( @@ -299,7 +307,8 @@ static std::vector> RunDefaultInferShape( } // Op is forward op - if (inplace_map.empty()) { // general case, assure single input and output + if (inplace_reverse_map + .empty()) { // general case, assure single input and output VLOG(3) << "Custom Operator: Default InferShape - share ddim."; if (input_shapes.size() 
== 1) { output_shapes = input_shapes; @@ -311,15 +320,21 @@ static std::vector> RunDefaultInferShape( "and only one output without setting the InferShapeFn. ")); } } else { // inplace case - for (auto const& pair : inplace_map) { - if (paddle::framework::detail::IsDuplicableVar(pair.second)) { - int input_index = vec_input_name2id_map.at(pair.first); + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(custom_op_meta); + for (auto& output : outputs) { + auto input_name = inplace_reverse_map.at(output); + if (paddle::framework::detail::IsDuplicableVar(output)) { + int input_index = vec_input_name2id_map.at(input_name); auto input_shape = vec_input_shapes[input_index]; output_shapes.insert( output_shapes.end(), input_shape.begin(), input_shape.end()); } else { - int input_index = input_name2id_map.at(pair.first); + int input_index = input_name2id_map.at(input_name); auto input_shape = input_shapes[input_index]; + if (input_shape.size() == 0) { + // if optional tensor is None, we don't need to infer shape + continue; + } output_shapes.push_back(input_shape); } } @@ -334,7 +349,8 @@ static std::vector RunDefaultInferDtype( const std::vector>& vec_input_dtypes, const std::unordered_map& vec_input_name2id_map) { std::vector output_dtypes; - auto& inplace_map = OpMetaInfoHelper::GetInplaceMap(custom_op_meta); + auto& inplace_reverse_map = + OpMetaInfoHelper::GetInplaceReverseMap(custom_op_meta); // Op is grad op if (custom_op_meta.IsGradOp() || custom_op_meta.IsDoubleGradOp()) { bool is_double_grad = custom_op_meta.IsDoubleGradOp(); @@ -357,6 +373,10 @@ static std::vector RunDefaultInferDtype( bwd_input_name) != bwd_inputs_name.end()) { int input_index = input_name2id_map.at(bwd_input_name); auto input_dtype = input_dtypes[input_index]; + if (input_dtype == DataType::UNDEFINED) { + // if optional tensor is None, we don't need to infer dtype + continue; + } output_dtypes.push_back(input_dtype); } else { // If there is no corresponding input for the output, set float as @@ -368,7 +388,8 @@ static std::vector RunDefaultInferDtype( return output_dtypes; } - if (inplace_map.empty()) { // general case, assure single input and output + if (inplace_reverse_map + .empty()) { // general case, assure single input and output VLOG(3) << "Custom Operator: Default InferDtype - share ddim."; if (input_dtypes.size() == 1) { output_dtypes = input_dtypes; @@ -380,15 +401,21 @@ static std::vector RunDefaultInferDtype( "and only one output without setting the InferDtypeFn. 
")); } } else { // inplace case - for (auto const& pair : inplace_map) { - if (paddle::framework::detail::IsDuplicableVar(pair.second)) { - int input_index = vec_input_name2id_map.at(pair.first); + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(custom_op_meta); + for (auto& output : outputs) { + auto input_name = inplace_reverse_map.at(output); + if (paddle::framework::detail::IsDuplicableVar(output)) { + int input_index = vec_input_name2id_map.at(input_name); auto input_dtype = vec_input_dtypes[input_index]; output_dtypes.insert( output_dtypes.end(), input_dtype.begin(), input_dtype.end()); } else { - int input_index = input_name2id_map.at(pair.first); + int input_index = input_name2id_map.at(input_name); auto input_dtype = input_dtypes[input_index]; + if (input_dtype == DataType::UNDEFINED) { + // if optional tensor is None, we don't need to infer dtype + continue; + } output_dtypes.push_back(input_dtype); } } @@ -405,7 +432,57 @@ static std::vector> RunInferShape( const std::unordered_map& vec_input_name2id_map, const std::vector& custom_attrs) { if (infershape_func) { - return infershape_func(input_shapes, vec_input_shapes, custom_attrs); + std::vector> infershape_result = + infershape_func(input_shapes, vec_input_shapes, custom_attrs); + std::vector> complete_result; + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(custom_op_meta); + const auto& inplace_reverse_map = + paddle::OpMetaInfoHelper::GetInplaceReverseMap(custom_op_meta); + + // The real output shape result is ( infershape func result + inplace output + // result), because the infershape doesn't create output shape that belongs + // to inplace output. + size_t infershape_result_index = 0; + for (auto& out_name : outputs) { + if (paddle::framework::detail::IsDuplicableVar(out_name)) { + PADDLE_ENFORCE( + inplace_reverse_map.find(out_name) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manually.")); + auto in_name = inplace_reverse_map.at(out_name); + if (custom_op_meta.IsGradOp() || custom_op_meta.IsDoubleGradOp()) { + const auto& bwd_op_name = + paddle::OpMetaInfoHelper::GetOpName(custom_op_meta); + bool is_double_grad_op = + (bwd_op_name.find(kDoubleGradSuffix) != bwd_op_name.npos) ? 
true + : false; + in_name = + paddle::framework::detail::NoGrad(out_name, is_double_grad_op); + } + auto index = vec_input_name2id_map.at(in_name); + const auto& vec_input_shape = vec_input_shapes[index]; + complete_result.insert(complete_result.end(), + vec_input_shape.begin(), + vec_input_shape.end()); + } else { + if (inplace_reverse_map.find(out_name) != inplace_reverse_map.end()) { + auto in_name = inplace_reverse_map.at(out_name); + auto index = input_name2id_map.at(in_name); + if (input_shapes[index].size() == 0) { + // if optional tensor is None, we don't need to infer shape, + continue; + } + complete_result.push_back(input_shapes[index]); + } else { + complete_result.push_back(infershape_result[infershape_result_index]); + infershape_result_index++; + } + } + } + return complete_result; } else { return RunDefaultInferShape(custom_op_meta, input_shapes, @@ -424,7 +501,57 @@ static std::vector RunInferDtype( const std::unordered_map& vec_input_name2id_map, const std::vector& custom_attrs) { if (inferdtype_func) { - return inferdtype_func(input_dtypes, vec_input_dtypes, custom_attrs); + std::vector complete_result; + const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(custom_op_meta); + const auto& inplace_reverse_map = + paddle::OpMetaInfoHelper::GetInplaceReverseMap(custom_op_meta); + std::vector inferdtype_result = + inferdtype_func(input_dtypes, vec_input_dtypes, custom_attrs); + + // The real output dtype result is ( infershape func dtype + inplace output + // dtype), because the inferdtype doesn't create output dtype that belongs + // to inplace output. + size_t inferdtype_result_index = 0; + for (auto& out_name : outputs) { + if (paddle::framework::detail::IsDuplicableVar(out_name)) { + PADDLE_ENFORCE( + inplace_reverse_map.find(out_name) != inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Custom operator only supports `paddle::Vec(...)` inputs and " + "cannot support `paddle::Vec(...)` output without setting " + "InplaceMap. If you have to use `paddle::Vec(...)` output, " + "please indicate it by setting InplaceMap manually.")); + auto in_name = inplace_reverse_map.at(out_name); + if (custom_op_meta.IsGradOp() || custom_op_meta.IsDoubleGradOp()) { + const auto& bwd_op_name = + paddle::OpMetaInfoHelper::GetOpName(custom_op_meta); + bool is_double_grad_op = + (bwd_op_name.find(kDoubleGradSuffix) != bwd_op_name.npos) ? 
true + : false; + in_name = + paddle::framework::detail::NoGrad(out_name, is_double_grad_op); + } + auto index = vec_input_name2id_map.at(in_name); + const auto& vec_input_dtype = vec_input_dtypes[index]; + complete_result.insert(complete_result.end(), + vec_input_dtype.begin(), + vec_input_dtype.end()); + } else { + if (inplace_reverse_map.find(out_name) != inplace_reverse_map.end()) { + auto in_name = inplace_reverse_map.at(out_name); + auto index = input_name2id_map.at(in_name); + if (input_dtypes[index] == DataType::UNDEFINED) { + // if optional tensor is None, we don't need to infer dtype + continue; + } + complete_result.push_back(input_dtypes[index]); + } else { + complete_result.push_back(inferdtype_result[inferdtype_result_index]); + inferdtype_result_index++; + } + } + } + return complete_result; } else { return RunDefaultInferDtype(custom_op_meta, input_dtypes, diff --git a/paddle/fluid/framework/data_feed.cc b/paddle/fluid/framework/data_feed.cc index cec1f664ce0f1..9489d22e34d21 100644 --- a/paddle/fluid/framework/data_feed.cc +++ b/paddle/fluid/framework/data_feed.cc @@ -1813,7 +1813,7 @@ int PaddleBoxDataFeed::Next() { this->batch_size_ = index; VLOG(3) << "pv_batch_size_=" << this->batch_size_ << ", thread_id=" << thread_id_; - if (this->batch_size_ != 0) { + if (this->batch_size_ != 0) { // NOLINT PutToFeedVec(pv_vec); } else { VLOG(3) << "finish reading, output_pv_channel_ size=" @@ -2113,7 +2113,7 @@ void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) { finish_init_ = true; input_type_ = data_feed_desc.input_type(); size_t pos = pipe_command_.find(".so"); - if (pos != std::string::npos) { + if (pos != std::string::npos) { // NOLINT pos = pipe_command_.rfind('|'); if (pos == std::string::npos) { so_parser_name_ = pipe_command_; @@ -2129,7 +2129,7 @@ void SlotRecordInMemoryDataFeed::Init(const DataFeedDesc& data_feed_desc) { #if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS) gpu_graph_data_generator_.SetConfig(data_feed_desc); #endif - if (gpu_graph_mode_) { + if (gpu_graph_mode_) { // NOLINT train_mode_ = true; } else { train_mode_ = data_feed_desc.graph_config().gpu_graph_training(); @@ -2780,7 +2780,7 @@ int SlotRecordInMemoryDataFeed::Next() { this->batch_size_ = batch.second; VLOG(3) << "batch_size_=" << this->batch_size_ << ", thread_id=" << thread_id_; - if (this->batch_size_ != 0) { + if (this->batch_size_ != 0) { // NOLINT PutToFeedVec(&records_[batch.first], this->batch_size_); } else { VLOG(3) << "finish reading for heterps, batch size zero, thread_id=" diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 14b2e87b56e7c..9228f2701f584 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -41,7 +41,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/timer.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" #if defined(PADDLE_WITH_CUDA) #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_utils.h" #include "paddle/fluid/platform/cuda_device_guard.h" diff --git a/paddle/fluid/framework/data_layout_transform.h b/paddle/fluid/framework/data_layout_transform.h index 1b5639d5be981..b9b4b7a8308b4 100644 --- a/paddle/fluid/framework/data_layout_transform.h +++ b/paddle/fluid/framework/data_layout_transform.h @@ -47,11 +47,11 @@ struct CastDataLayout { std::vector GetAxis(const DataLayout& from, const DataLayout& to); -void TransDataLayout(const phi::KernelKey& kernel_type_for_var, - const phi::KernelKey& expected_kernel_type, - const phi::DenseTensor& in, - phi::DenseTensor* out, - const phi::Place& place); +TEST_API void TransDataLayout(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_type, + const phi::DenseTensor& in, + phi::DenseTensor* out, + const phi::Place& place); void TransDataLayout(phi::DataLayout from_layout, phi::DataLayout to_layout, diff --git a/paddle/fluid/framework/data_set.cc b/paddle/fluid/framework/data_set.cc index 0c48c6e1a25ad..231428c5a3721 100644 --- a/paddle/fluid/framework/data_set.cc +++ b/paddle/fluid/framework/data_set.cc @@ -966,7 +966,7 @@ void DatasetImpl::DynamicAdjustChannelNum(int channel_num, CHECK(output_channels_data_size == 0); // NOLINT cur_channel = 1; } - if (cur_channel == 0) { + if (cur_channel == 0) { // NOLINT origin_channels = &multi_output_channel_; other_channels = &multi_consume_channel_; origin_pv_channels = &multi_pv_output_; @@ -1111,8 +1111,8 @@ void DatasetImpl::CreateReaders() { if (input_pv_channel_ != nullptr) { readers_[i]->SetInputPvChannel(input_pv_channel_.get()); } - if (cur_channel_ == 0 && - static_cast(channel_idx) < multi_output_channel_.size()) { + if (cur_channel_ == 0 && static_cast(channel_idx) < + multi_output_channel_.size()) { // NOLINT readers_[i]->SetOutputChannel(multi_output_channel_[channel_idx].get()); readers_[i]->SetConsumeChannel(multi_consume_channel_[channel_idx].get()); readers_[i]->SetOutputPvChannel(multi_pv_output_[channel_idx].get()); @@ -1441,40 +1441,39 @@ void MultiSlotDataset::GenerateLocalTablesUnlock(int table_id, } } }; - auto gen_func = - [this, &shard_num, &feadim, &local_map_tables, &consume_func](int i) { - std::vector vec_data; - std::vector> task_keys(shard_num); - std::vector> task_futures; - this->multi_output_channel_[i]->Close(); - this->multi_output_channel_[i]->ReadAll(vec_data); - for (auto& item : vec_data) { - for (auto& feature : item.uint64_feasigns_) { - int shard = - static_cast(feature.sign().uint64_feasign_ % shard_num); - task_keys[shard].push_back(feature.sign().uint64_feasign_); - } - } + auto gen_func = [this, &shard_num, &feadim, &consume_func](int i) { + std::vector vec_data; + std::vector> task_keys(shard_num); + std::vector> task_futures; + this->multi_output_channel_[i]->Close(); + this->multi_output_channel_[i]->ReadAll(vec_data); + for (auto& item : vec_data) { + for (auto& feature : item.uint64_feasigns_) { + int shard = + static_cast(feature.sign().uint64_feasign_ % shard_num); + task_keys[shard].push_back(feature.sign().uint64_feasign_); + } + } - for (int shard_id = 0; shard_id < shard_num; shard_id++) { - task_futures.emplace_back(consume_task_pool_[shard_id]->enqueue( - consume_func, shard_id, feadim, 
task_keys[shard_id])); - } + for (int shard_id = 0; shard_id < shard_num; shard_id++) { + task_futures.emplace_back(consume_task_pool_[shard_id]->enqueue( + consume_func, shard_id, feadim, task_keys[shard_id])); + } - multi_output_channel_[i]->Open(); - multi_output_channel_[i]->Write(std::move(vec_data)); - vec_data.clear(); - vec_data.shrink_to_fit(); - for (auto& tk : task_keys) { - tk.clear(); - std::vector().swap(tk); - } - task_keys.clear(); - std::vector>().swap(task_keys); - for (auto& tf : task_futures) { - tf.wait(); - } - }; + multi_output_channel_[i]->Open(); + multi_output_channel_[i]->Write(std::move(vec_data)); + vec_data.clear(); + vec_data.shrink_to_fit(); + for (auto& tk : task_keys) { + tk.clear(); + std::vector().swap(tk); + } + task_keys.clear(); + std::vector>().swap(task_keys); + for (auto& tf : task_futures) { + tf.wait(); + } + }; for (size_t i = 0; i < threads.size(); i++) { threads[i] = std::thread(gen_func, i); } @@ -1722,7 +1721,7 @@ void MultiSlotDataset::PreprocessChannel( const std::set& slots_to_replace, std::unordered_set& index_slots) { // NOLINT int out_channel_size = 0; - if (cur_channel_ == 0) { + if (cur_channel_ == 0) { // NOLINT for (auto& item : multi_output_channel_) { out_channel_size += static_cast(item->Size()); } @@ -1757,7 +1756,7 @@ void MultiSlotDataset::PreprocessChannel( input_channel_->ReadAll(slots_shuffle_original_data_); } else { CHECK(out_channel_size > 0); // NOLINT - if (cur_channel_ == 0) { + if (cur_channel_ == 0) { // NOLINT for (auto& item : multi_output_channel_) { std::vector vec_data; item->Close(); @@ -1792,7 +1791,7 @@ void MultiSlotDataset::PreprocessChannel( } else { // if already have original data for slots shuffle, clear channel input_channel_->Clear(); - if (cur_channel_ == 0) { + if (cur_channel_ == 0) { // NOLINT for (auto& item : multi_output_channel_) { if (!item) { continue; @@ -1808,22 +1807,22 @@ void MultiSlotDataset::PreprocessChannel( } } } - int end_size = 0; - if (cur_channel_ == 0) { - for (auto& item : multi_output_channel_) { - if (!item) { - continue; - } - end_size += static_cast(item->Size()); - } - } else { - for (auto& item : multi_consume_channel_) { - if (!item) { - continue; - } - end_size += static_cast(item->Size()); - } - } + // int end_size = 0; + // if (cur_channel_ == 0) { // NOLINT + // for (auto& item : multi_output_channel_) { + // if (!item) { + // continue; + // } + // end_size += static_cast(item->Size()); + // } + // } else { + // for (auto& item : multi_consume_channel_) { + // if (!item) { + // continue; + // } + // end_size += static_cast(item->Size()); + // } + // } CHECK(input_channel_->Size() == 0) << "input channel should be empty before slots shuffle"; } diff --git a/paddle/fluid/framework/data_transform.cc b/paddle/fluid/framework/data_transform.cc index 9bb07bb47ea0f..039ed3ffc2441 100644 --- a/paddle/fluid/framework/data_transform.cc +++ b/paddle/fluid/framework/data_transform.cc @@ -82,7 +82,7 @@ void TransformData(const phi::KernelKey &expected_kernel_type, phi::funcs::make_memory_desc(out, lin); out.set_mem_desc(out_mem_desc); } else { - // Case2 - transfrom from ONEDNN OPKernel to Non-ONEDNN OPKernel + // Case2 - transform from ONEDNN OPKernel to Non-ONEDNN OPKernel // Do transform via ONEDNN lib PADDLE_ENFORCE(lin == DataLayout::ONEDNN && lout != DataLayout::ONEDNN, platform::errors::InvalidArgument( @@ -97,12 +97,12 @@ void TransformData(const phi::KernelKey &expected_kernel_type, place); } } else { - // Case3 - transfrom between Non-ONEDNN OPKernels + // Case3 - 
transform between Non-ONEDNN OPKernels TransDataLayout( kernel_type_for_var, expected_kernel_type, in, &out, place); } #else - // Case3 - transfrom between Non-ONEDNN OPKernels + // Case3 - transform between Non-ONEDNN OPKernels TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out, place); #endif transformed = true; diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h index d2344fb68d3e4..b5fa02eeb2bc8 100644 --- a/paddle/fluid/framework/data_type.h +++ b/paddle/fluid/framework/data_type.h @@ -29,7 +29,7 @@ namespace paddle { namespace framework { TEST_API std::string DataTypeToString(const proto::VarType::Type type); -extern size_t SizeOfType(proto::VarType::Type type); +TEST_API extern size_t SizeOfType(proto::VarType::Type type); template struct IsComplex : public std::false_type {}; @@ -123,7 +123,7 @@ _ForEachDataType_(DefineDataTypeTrait); #undef DefineDataTypeTrait -extern proto::VarType::Type ToDataType(std::type_index type); +TEST_API extern proto::VarType::Type ToDataType(std::type_index type); extern std::type_index ToTypeIndex(proto::VarType::Type type); template diff --git a/paddle/fluid/framework/data_type_transform.h b/paddle/fluid/framework/data_type_transform.h index 2ec193b675097..aa25fb3653013 100644 --- a/paddle/fluid/framework/data_type_transform.h +++ b/paddle/fluid/framework/data_type_transform.h @@ -28,10 +28,10 @@ class OpKernelType; using KernelTypePair = std::pair; -void TransDataType(const phi::KernelKey& kernel_type_for_var, - const phi::KernelKey& expected_kernel_type, - const phi::DenseTensor& in, - phi::DenseTensor* out); +TEST_API void TransDataType(const phi::KernelKey& kernel_type_for_var, + const phi::KernelKey& expected_kernel_type, + const phi::DenseTensor& in, + phi::DenseTensor* out); void TransDataType(const phi::DenseTensor& in, const paddle::framework::proto::VarType::Type& type, phi::DenseTensor* out); diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc index 1114fea8a23f7..4c78b12fd4ac4 100644 --- a/paddle/fluid/framework/details/all_reduce_op_handle.cc +++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc @@ -260,7 +260,7 @@ void AllReduceOpHandle::AllReduceFunc( size_t size = numel * SizeOfType(framework::TransToProtoVarType(trg.dtype())); - RunAndRecordEvent(p, [&trg, var, p, size] { + RunAndRecordEvent(p, [&trg, var, size] { auto dst_ptr = var->GetMutable()->data(); platform::CPUPlace cpu_place; memory::Copy(cpu_place, dst_ptr, cpu_place, trg.data(), size); diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.h b/paddle/fluid/framework/details/async_ssa_graph_executor.h index ae7b81e6ada75..bca1f0b460ff4 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.h @@ -32,7 +32,7 @@ struct VarInfo { bool persistable_; }; -class AsyncSSAGraphExecutor : public SSAGraphExecutor { +class AsyncSSAGraphExecutor final : public SSAGraphExecutor { public: AsyncSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, diff --git a/paddle/fluid/framework/details/broadcast_op_handle.h b/paddle/fluid/framework/details/broadcast_op_handle.h index c41ed77f0e274..2b685d62c6d94 100644 --- a/paddle/fluid/framework/details/broadcast_op_handle.h +++ b/paddle/fluid/framework/details/broadcast_op_handle.h @@ -35,7 +35,7 @@ class Node; } // namespace framework namespace platform { #if defined(PADDLE_WITH_NCCL) || 
defined(PADDLE_WITH_RCCL) -struct NCCLContextMap; +class NCCLContextMap; #endif #if defined(PADDLE_WITH_XPU_BKCL) struct BKCLContextMap; diff --git a/paddle/fluid/framework/details/exception_holder.h b/paddle/fluid/framework/details/exception_holder.h index 1fb802b3f651d..5f5f4f65b8fc9 100644 --- a/paddle/fluid/framework/details/exception_holder.h +++ b/paddle/fluid/framework/details/exception_holder.h @@ -41,7 +41,7 @@ class ExceptionHolder { } catch (std::exception& ex) { Catch(ex); } catch (...) { - LOG(FATAL) << "Unknown exception caught."; + PADDLE_THROW(phi::errors::Fatal("Unknown exception caught.")); } } diff --git a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc index 19cf30d24db40..66c62085faed2 100644 --- a/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.cc @@ -49,8 +49,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor( /*disable_setting_default_stream_for_allocator=*/true, /*stream_priority=*/0); if (ir::IsTopologySortOperationsUnique(*graph_)) { - VLOG(10) - << "Change thread number to 1 because the toposort order is unique"; + VLOG(10) << "Change thread number to 1 because the topology sort order is " + "unique"; strategy_.num_threads_ = 1; traced_ops_.clear(); for (auto *op_node : TopologySortOperations(*graph_)) { diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc index 27be4b7717635..25108148af349 100644 --- a/paddle/fluid/framework/details/fetch_op_handle.cc +++ b/paddle/fluid/framework/details/fetch_op_handle.cc @@ -39,7 +39,7 @@ FetchOpHandle::~FetchOpHandle() = default; void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) { PADDLE_THROW(platform::errors::PermissionDenied( - "No nodes need to wait FetchOp. Unexpceted Error.")); + "No nodes need to wait FetchOp. 
Unexpected Error.")); } static void CheckDims(const framework::DDim &tensor_dims, diff --git a/paddle/fluid/framework/details/fused_broadcast_op_handle.h b/paddle/fluid/framework/details/fused_broadcast_op_handle.h index 18eab1ed688b5..5ff89f71a6557 100644 --- a/paddle/fluid/framework/details/fused_broadcast_op_handle.h +++ b/paddle/fluid/framework/details/fused_broadcast_op_handle.h @@ -32,7 +32,7 @@ class Node; } // namespace ir } // namespace framework namespace platform { -struct NCCLContextMap; +class NCCLContextMap; } // namespace platform } // namespace paddle diff --git a/paddle/fluid/framework/details/graph_test_base.h b/paddle/fluid/framework/details/graph_test_base.h index 2f50556e771ee..09d7dcc863aed 100644 --- a/paddle/fluid/framework/details/graph_test_base.h +++ b/paddle/fluid/framework/details/graph_test_base.h @@ -44,7 +44,7 @@ class DummyOp : public OperatorBase { class SumOpMaker : public OpProtoAndCheckerMaker { public: - void Make() { + void Make() override { AddInput("X", "").AsDuplicable(); AddOutput("Out", ""); AddComment(""); @@ -53,7 +53,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker { class AssignOpMaker : public OpProtoAndCheckerMaker { public: - void Make() { + void Make() override { AddInput("X", "").AsDuplicable(); AddOutput("Out", ""); AddComment(""); @@ -62,7 +62,7 @@ class AssignOpMaker : public OpProtoAndCheckerMaker { class SplitOpMaker : public OpProtoAndCheckerMaker { public: - void Make() { + void Make() override { AddInput("X", ""); AddOutput("Out", "").AsDuplicable(); AddComment(""); diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc index 551a10f1ccacd..d18cee16b19a6 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cc +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc @@ -264,7 +264,7 @@ void CheckOpHasNanOrInf(const framework::OperatorBase& op, if (IsSkipOp(op)) return; - if (op_var_nan_inf_white_list().count(op.Type()) == 0) { + if (op_var_nan_inf_white_list().count(op.Type()) == 0) { // NOLINT // NOTE. vname may destruct in the end of this func. 
for (auto& vname : op.OutputVars(true)) { auto* var = exec_scope.FindVar(vname); diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index 6da7f9f8c2041..7a137b050bed7 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -64,7 +64,9 @@ class OpHandleBase { virtual bool GetSkipRunning() const { return skip_running_; } - virtual void SetSkipRunning(bool skip_runing) { skip_running_ = skip_runing; } + virtual void SetSkipRunning(bool skip_running) { + skip_running_ = skip_running; + } virtual std::string Name() const = 0; diff --git a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h index 88c8b1cbfb294..3414c7361e040 100644 --- a/paddle/fluid/framework/details/parallel_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/parallel_ssa_graph_executor.h @@ -27,7 +27,7 @@ namespace paddle { namespace framework { namespace details { -class ParallelSSAGraphExecutor : public SSAGraphExecutor { +class ParallelSSAGraphExecutor final : public SSAGraphExecutor { public: enum FeedStatus { kNone = 0, // No feed diff --git a/paddle/fluid/framework/details/reduce_op_handle.h b/paddle/fluid/framework/details/reduce_op_handle.h index 2eb0ad2923211..166bd2c0f2861 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.h +++ b/paddle/fluid/framework/details/reduce_op_handle.h @@ -36,7 +36,7 @@ class Node; } // namespace ir } // namespace framework namespace platform { -struct NCCLContextMap; +class NCCLContextMap; } // namespace platform } // namespace paddle #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h index 9351b8c0c31a3..801280108b9b5 100644 --- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h +++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h @@ -34,7 +34,7 @@ namespace paddle { namespace framework { namespace details { -struct ScaleLossGradOpHandle : public OpHandleBase { +struct ScaleLossGradOpHandle final : public OpHandleBase { ScaleLossGradOpHandle(ir::Node *node, size_t num_dev, Scope *scope, diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc index 9d275b0fd4c2e..355b179599ce9 100644 --- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc @@ -70,7 +70,7 @@ static void RunProgramDescs(const ProgramDescs &programs, FetchResultType ScopeBufferedSSAGraphExecutor::Run( const std::vector &fetch_tensors, bool return_merged) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::IsCUDAGraphCapturing()) { strategy_.num_iteration_per_drop_scope_ = std::numeric_limits::max(); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h index 4a94dd917540c..0633bffd5bdfb 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.h +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.h @@ -47,7 +47,7 @@ struct OpDependentData { size_t num_ops_{0}; }; -class ThreadedSSAGraphExecutor : public SSAGraphExecutor { +class ThreadedSSAGraphExecutor final : public SSAGraphExecutor { public: 
ThreadedSSAGraphExecutor(const ExecutionStrategy &strategy, const std::vector &local_scopes, diff --git a/paddle/fluid/framework/device_worker.cc b/paddle/fluid/framework/device_worker.cc index bf83e965f3887..da794486ae866 100644 --- a/paddle/fluid/framework/device_worker.cc +++ b/paddle/fluid/framework/device_worker.cc @@ -387,31 +387,31 @@ void DeviceWorker::DumpField(const Scope& scope, VLOG(3) << dims.size() << " " << dims[0] << " * " << dims[1]; continue; } - size_t acutal_thread_num = + size_t actual_thread_num = std::min(static_cast(batch_size), tensor_iterator_thread_num); - for (size_t i = 0; i < acutal_thread_num; i++) { - size_t average_size = batch_size / acutal_thread_num; + for (size_t i = 0; i < actual_thread_num; i++) { + size_t average_size = batch_size / actual_thread_num; size_t begin = - average_size * i + std::min(batch_size % acutal_thread_num, i); + average_size * i + std::min(batch_size % actual_thread_num, i); size_t end = - begin + average_size + (i < batch_size % acutal_thread_num ? 1 : 0); + begin + average_size + (i < batch_size % actual_thread_num ? 1 : 0); threads[i] = std::thread(set_output_str, begin, end, tensor); } - for (size_t i = 0; i < acutal_thread_num; i++) threads[i].join(); + for (size_t i = 0; i < actual_thread_num; i++) threads[i].join(); } auto end1 = std::chrono::steady_clock::now(); auto tt = std::chrono::duration_cast(end1 - start1); VLOG(2) << "writing a batch takes " << tt.count() << " us"; - size_t acutal_thread_num = + size_t actual_thread_num = std::min(static_cast(batch_size), tensor_iterator_thread_num); - for (size_t i = 0; i < acutal_thread_num; i++) { - size_t average_size = batch_size / acutal_thread_num; + for (size_t i = 0; i < actual_thread_num; i++) { + size_t average_size = batch_size / actual_thread_num; size_t begin = - average_size * i + std::min(batch_size % acutal_thread_num, i); + average_size * i + std::min(batch_size % actual_thread_num, i); size_t end = - begin + average_size + (i < batch_size % acutal_thread_num ? 1 : 0); + begin + average_size + (i < batch_size % actual_thread_num ? 1 : 0); for (size_t j = begin + 1; j < end; j++) { if (!ars[begin].empty() && !ars[j].empty()) ars[begin] += "\n"; ars[begin] += ars[j]; diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index d7714808ff08a..f288494549ce4 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -44,7 +44,7 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/reader/blocking_queue.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/timer.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace paddle { namespace framework { @@ -60,20 +60,21 @@ class Scope; namespace paddle { namespace framework { -std::string PrintLodTensor(phi::DenseTensor* tensor, - int64_t start, - int64_t end, - char separator = ',', - bool need_leading_separator = false); -void PrintLodTensor(phi::DenseTensor* tensor, - int64_t start, - int64_t end, - std::string& output_str, // NOLINT - char separator = ',', - bool need_leading_separator = false, - int num_decimals = 9); -std::pair GetTensorBound(phi::DenseTensor* tensor, int index); -bool CheckValidOutput(phi::DenseTensor* tensor, size_t batch_size); +TEST_API std::string PrintLodTensor(phi::DenseTensor* tensor, + int64_t start, + int64_t end, + char separator = ',', + bool need_leading_separator = false); +TEST_API void PrintLodTensor(phi::DenseTensor* tensor, + int64_t start, + int64_t end, + std::string& output_str, // NOLINT + char separator = ',', + bool need_leading_separator = false, + int num_decimals = 9); +TEST_API std::pair GetTensorBound(phi::DenseTensor* tensor, + int index); +TEST_API bool CheckValidOutput(phi::DenseTensor* tensor, size_t batch_size); class FleetWrapper; diff --git a/paddle/fluid/framework/dist_multi_trainer.cc b/paddle/fluid/framework/dist_multi_trainer.cc index 6fd95267ef6ab..119b6e569cef3 100644 --- a/paddle/fluid/framework/dist_multi_trainer.cc +++ b/paddle/fluid/framework/dist_multi_trainer.cc @@ -157,7 +157,7 @@ void DistMultiTrainer::Run() { std::vector> wait_futures; CHECK_EQ(static_cast(pool.size()), thread_num_); for (int i = 0; i < thread_num_; ++i) { - if (!debug_) { + if (!debug_) { // NOLINT wait_futures.emplace_back( pool[i]->Run([this, i]() { workers_[i]->TrainFiles(); })); } else { diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 27c7a7a7af276..8c6795bac3a95 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -82,7 +82,9 @@ message PpConfig { optional bool sharding_comm_overlap = 4 [ default = false ]; optional bool profiling = 5 [ default = false ]; optional bool release_gradients = 6 [ default = false ]; - optional bool overlap_p2p_comm = 7 [default = true]; + optional bool overlap_p2p_comm = 7 [default = false]; + optional bool clear_every_step_cache = 8 [default = false]; + optional bool use_batch_p2p_comm = 9 [default = true]; } message DygraphShardingConfig { @@ -91,6 +93,7 @@ message DygraphShardingConfig { optional bool comm_overlap = 3 [ default = false ]; optional bool split_param = 4 [ default = false ]; optional bool fuse_optimizer = 5 [ default = true ]; + optional bool use_reduce_avg = 6 [ default = true ]; } message HybridConfig { diff --git a/paddle/fluid/framework/dlpack_tensor.h b/paddle/fluid/framework/dlpack_tensor.h index 943ee88b67695..f39d91b84ee3d 100644 --- a/paddle/fluid/framework/dlpack_tensor.h +++ b/paddle/fluid/framework/dlpack_tensor.h @@ -28,7 +28,8 @@ class DLPackTensor { std::remove_reference::type; // int64_t // lanes is only used in CPU to enable vectorization - explicit DLPackTensor(const phi::DenseTensor& tensor, LaneType lanes = 1); + TEST_API explicit DLPackTensor(const phi::DenseTensor& tensor, + LaneType lanes = 1); inline operator const ::DLTensor&() const { return t_; } diff --git 
a/paddle/fluid/framework/downpour_lite_worker.cc b/paddle/fluid/framework/downpour_lite_worker.cc index 3d453c018c1d5..e86856bf1b2ff 100644 --- a/paddle/fluid/framework/downpour_lite_worker.cc +++ b/paddle/fluid/framework/downpour_lite_worker.cc @@ -410,7 +410,8 @@ void DownpourLiteWorker::TrainFilesWithProfiler() { fprintf(stderr, "push dense time percent: %f\n", push_dense_time / total_time * 100); - fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time); + fprintf( + stderr, "%6.2f instances/s\n", total_inst / total_time); // NOLINT } } timeline.Start(); diff --git a/paddle/fluid/framework/downpour_worker.cc b/paddle/fluid/framework/downpour_worker.cc index 6ce2967a08f1f..0d5bd66297c53 100644 --- a/paddle/fluid/framework/downpour_worker.cc +++ b/paddle/fluid/framework/downpour_worker.cc @@ -334,8 +334,9 @@ void DownpourWorker::AdjustInsWeight() { } float ins_weight = 1.0; if (nid_show >= 0 && nid_show < nid_adjw_threshold) { - ins_weight = log(M_E + (nid_adjw_threshold - nid_show) / - nid_adjw_threshold * nid_adjw_ratio); + ins_weight = static_cast( + log(M_E + (nid_adjw_threshold - nid_show) / nid_adjw_threshold * + nid_adjw_ratio)); // count nid adjw insnum and weight ++nid_adjw_num; nid_adjw_weight += ins_weight; diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index d935e9ea066bd..fbc2565e755fa 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -99,7 +99,7 @@ void Executor::CreateVariables(const ProgramDesc& pdesc, while (ancestor_scope->parent()) { ancestor_scope = ancestor_scope->parent(); } - if (ancestor_scope != scope) { + if (ancestor_scope != scope) { // NOLINT for (auto& var : global_block.AllVars()) { if (var->Name() == framework::kEmptyVarName) { continue; diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 0d6e4ea09c47a..0be2a603502cb 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" -#include "paddle/fluid/pir/transforms/inplace_pass.h" +#include "paddle/fluid/pir/transforms/general/inplace_pass.h" #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include "paddle/pir/include/core/program.h" #include "paddle/pir/include/core/value.h" @@ -312,9 +312,8 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( int64_t program_id, framework::Scope *scope, const int64_t &place_hash_key) { - auto &interpretercore_info_cache = - framework::InterpreterCoreInfoCache::Instance(); - if (interpretercore_info_cache.Size() > 256000u /* max_cached_size*/) { + auto &cache = framework::InterpreterCoreInfoCache::Instance(); + if (cache.Size() > 256000u /* max_cached_size*/) { PADDLE_THROW(platform::errors::Fatal( "The cached info size has exceeded max_cached_size: 256000, " "which will cause error. 
")); @@ -328,7 +327,7 @@ std::shared_ptr CreateProgramInterpreterCoreInfoToCache( core.reset(new InterpreterCore( place, program_desc.Block(0), scope, execution_config)); - auto &cached_value = interpretercore_info_cache.GetMutable( + auto &cached_value = cache.GetMutable( program_id, scope, place_hash_key, is_grad, /*in_pir_mode=*/false); cached_value.core_ = core; return core; @@ -341,9 +340,8 @@ std::shared_ptr CreatePirInterpreterCoreInfoToCache( int64_t program_id, framework::Scope *scope, const int64_t &place_hash_key) { - auto &interpretercore_info_cache = - framework::InterpreterCoreInfoCache::Instance(); - if (interpretercore_info_cache.Size() > 256000u /* max_cached_size*/) { + auto &cache = framework::InterpreterCoreInfoCache::Instance(); + if (cache.Size() > 256000u /* max_cached_size*/) { PADDLE_THROW(platform::errors::Fatal( "The cached info size has exceeded max_cached_size: 256000, " "which will cause error. ")); @@ -357,7 +355,7 @@ std::shared_ptr CreatePirInterpreterCoreInfoToCache( core.reset(new InterpreterCore( place, {}, ir_program->block(), scope, execution_config)); - auto &cached_value = interpretercore_info_cache.GetMutable( + auto &cached_value = cache.GetMutable( program_id, scope, place_hash_key, is_grad, /*in_pir_mode=*/true); cached_value.core_ = core; cached_value.ir_prog_ = std::move(ir_program); diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h index 10ca69f42862e..f9afaabec79dc 100644 --- a/paddle/fluid/framework/executor_cache.h +++ b/paddle/fluid/framework/executor_cache.h @@ -27,7 +27,7 @@ #include "paddle/fluid/framework/parallel_executor.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/platform/macros.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/pir/include/core/dialect.h" diff --git a/paddle/fluid/framework/fleet/box_wrapper.h b/paddle/fluid/framework/fleet/box_wrapper.h index 2dee617925773..33b861f892c51 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.h +++ b/paddle/fluid/framework/fleet/box_wrapper.h @@ -45,7 +45,7 @@ limitations under the License. */ #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/timer.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" #define BUF_SIZE 1024 * 1024 extern void comlog_set_log_level(int log_level); diff --git a/paddle/fluid/framework/fleet/gloo_wrapper.cc b/paddle/fluid/framework/fleet/gloo_wrapper.cc index 277004b6dc164..fbd16f0a1f592 100644 --- a/paddle/fluid/framework/fleet/gloo_wrapper.cc +++ b/paddle/fluid/framework/fleet/gloo_wrapper.cc @@ -12,7 +12,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/fleet/gloo_wrapper.h" #include "paddle/fluid/framework/io/fs.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace gloo { namespace transport { @@ -165,7 +165,7 @@ void HdfsStore::wait(const std::vector& keys, int32_t last_check_rank = -1; for (size_t i = 0; i < check_key_status.size(); ++i) { if (!check_key_status[i]) { - last_check_rank = i; + last_check_rank = static_cast(i); break; } } @@ -252,7 +252,7 @@ void ParallelConnectContext::connectFullMesh( connect_threads[i].reset(new std::thread( [&store, &transportContext, total_add_size, this]( size_t thread_idx, size_t thread_num) -> void { - for (int i = thread_idx; i < size; i += thread_num) { + for (int i = thread_idx; i < size; i += thread_num) { // NOLINT if (i == rank) { continue; } diff --git a/paddle/fluid/framework/fleet/heter_ps/graph_sampler.h b/paddle/fluid/framework/fleet/heter_ps/graph_sampler.h index 6e7d0ba9ca734..ac915ed547fb7 100644 --- a/paddle/fluid/framework/fleet/heter_ps/graph_sampler.h +++ b/paddle/fluid/framework/fleet/heter_ps/graph_sampler.h @@ -28,8 +28,8 @@ #include "paddle/fluid/distributed/ps/table/common_graph_table.h" #include "paddle/fluid/framework/fleet/heter_ps/gpu_graph_node.h" #include "paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h" -#include "paddle/fluid/string/printf.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/printf.h" +#include "paddle/utils/string/string_helper.h" #ifdef PADDLE_WITH_HETERPS namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/fleet/heter_ps/test_sample_rate.cu b/paddle/fluid/framework/fleet/heter_ps/test_sample_rate.cu index be4ea8137194c..595ace5368f9b 100644 --- a/paddle/fluid/framework/fleet/heter_ps/test_sample_rate.cu +++ b/paddle/fluid/framework/fleet/heter_ps/test_sample_rate.cu @@ -43,8 +43,8 @@ #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/utils/string/printf.h" using paddle::framework; namespace platform = paddle::platform; diff --git a/paddle/fluid/framework/fleet/metrics.cc b/paddle/fluid/framework/fleet/metrics.cc index 58e1e195fbab7..57fe43fb44624 100644 --- a/paddle/fluid/framework/fleet/metrics.cc +++ b/paddle/fluid/framework/fleet/metrics.cc @@ -219,7 +219,7 @@ void BasicAucCalculator::calculate_bucket_error() { } } } else { - double* table[2] = {&_table[0][0], &_table[1][0]}; + double* table[2] = {&_table[0][0], &_table[1][0]}; // NOLINT for (int i = 0; i < _table_size; i++) { double click = table[1][i]; double show = table[0][i] + table[1][i]; @@ -301,7 +301,7 @@ void BasicAucCalculator::add_uid_unlock_data(double pred, WuaucRecord record; record.uid_ = uid; record.label_ = label; - record.pred_ = pred; + record.pred_ = static_cast(pred); wuauc_records_.emplace_back(std::move(record)); } diff --git a/paddle/fluid/framework/fleet/metrics.h b/paddle/fluid/framework/fleet/metrics.h index 700a1cece17f3..91b25ce132a1a 100644 --- a/paddle/fluid/framework/fleet/metrics.h +++ b/paddle/fluid/framework/fleet/metrics.h @@ -32,7 +32,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/platform/timer.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" #if defined(PADDLE_WITH_GLOO) #include diff --git a/paddle/fluid/framework/framework.proto b/paddle/fluid/framework/framework.proto index 545286fb04a5b..1f4414af3c07f 100644 --- a/paddle/fluid/framework/framework.proto +++ b/paddle/fluid/framework/framework.proto @@ -18,7 +18,7 @@ package paddle.framework.proto; // Any incompatible changes to ProgramDesc and its dependencies should // raise the version defined version.h. // -// Serailization and Deserialization codes should be modified in a way +// Serialization and Deserialization codes should be modified in a way // that supports old versions following the version and compatibility policy. message Version { optional int64 version = 1 [ default = 0 ]; } diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index dd795e190bdd2..dcfe096edf7b0 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -106,7 +106,7 @@ class GradOpDescMakerBase { "BUG from operator developer:" " for input argument with a list of variables, " " drop_empty_grad is not allowed because it makes" - " the correspondence bewteen a variable and its gradient" + " the correspondence between a variable and its gradient" " ambiguous.")); std::vector dropped_ret_val; diff --git a/paddle/fluid/framework/heter_section_worker.cc b/paddle/fluid/framework/heter_section_worker.cc index 65902f6c2d0c7..09e14bff65596 100644 --- a/paddle/fluid/framework/heter_section_worker.cc +++ b/paddle/fluid/framework/heter_section_worker.cc @@ -126,9 +126,9 @@ void HeterSectionWorker::Initialize(const TrainerDesc& desc) { bool is_first_stage = (pipeline_stage_ == 0); bool is_last_stage = (pipeline_stage_ + 1 == num_pipeline_stages_); - if (is_first_stage) { + if (is_first_stage) { // NOLINT for (auto& op_desc : program_->Block(0).AllOps()) { - auto op = std::move(OpRegistry::CreateOp(*op_desc)); + auto op = OpRegistry::CreateOp(*op_desc); auto op_type = op->Type(); if (listen_op_ == nullptr && op_type == "heter_listen_and_serv") { listen_op_ = std::move(op); @@ -142,11 +142,11 @@ void HeterSectionWorker::Initialize(const TrainerDesc& desc) { } else if (is_last_stage) { for (auto& op_desc : program_->Block(0).AllOps()) { if (listen_op_ == nullptr) { - listen_op_ = std::move(OpRegistry::CreateOp(*op_desc)); + listen_op_ = OpRegistry::CreateOp(*op_desc); } } for (auto& op_desc : program_->Block(1).AllOps()) { - auto op = std::move(OpRegistry::CreateOp(*op_desc)); + auto op = OpRegistry::CreateOp(*op_desc); int op_role = op->Attr(std::string("op_role")); bool is_forward_op = (op_role == static_cast(OpRole::kForward)) || (op_role == (static_cast(OpRole::kForward) | @@ -161,7 +161,7 @@ void HeterSectionWorker::Initialize(const TrainerDesc& desc) { } else { for (auto& op_desc : program_->Block(0).AllOps()) { if (listen_op_ == nullptr) { - listen_op_ = std::move(OpRegistry::CreateOp(*op_desc)); + listen_op_ = OpRegistry::CreateOp(*op_desc); } } for (auto& op_desc : program_->Block(1).AllOps()) { @@ -507,7 +507,7 @@ void HeterSectionWorker::PrintFetchVars() { if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) { time_t curtime; time(&curtime); - char mbstr[80]; + char mbstr[80]; // NOLINT std::strftime( mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S", 
std::localtime(&curtime)); std::stringstream ss; diff --git a/paddle/fluid/framework/heter_service.proto b/paddle/fluid/framework/heter_service.proto index b1edbedf927ed..fd8a63bf56e96 100644 --- a/paddle/fluid/framework/heter_service.proto +++ b/paddle/fluid/framework/heter_service.proto @@ -24,8 +24,8 @@ enum VarType { // VariableMessage is serialized paddle variable message. // NOTICE(gongwb):don't modify this proto if you are not -// not familar with how we serialize in sendrecvop_utils.h -// and deserilize it in variable_response.h. +// not familiar with how we serialize in sendrecvop_utils.h +// and deserialize it in variable_response.h. message VariableMessage { enum Type { // Pod Types diff --git a/paddle/fluid/framework/hetercpu_worker.cc b/paddle/fluid/framework/hetercpu_worker.cc index 0959b0ae33442..77cc1bc9f8ad6 100644 --- a/paddle/fluid/framework/hetercpu_worker.cc +++ b/paddle/fluid/framework/hetercpu_worker.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/fleet/fleet_wrapper.h" #include "paddle/fluid/framework/fleet/heter_wrapper.h" #include "paddle/fluid/platform/cpu_helper.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" #if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS) diff --git a/paddle/fluid/framework/infershape_utils.cc b/paddle/fluid/framework/infershape_utils.cc index bcf72be80decb..37352b4d47138 100644 --- a/paddle/fluid/framework/infershape_utils.cc +++ b/paddle/fluid/framework/infershape_utils.cc @@ -559,16 +559,15 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, for (auto& in_name : input_names) { if (ctx->HasInputs(in_name)) { - auto input_var = std::move(ctx->GetInputVarPtrs(in_name)); + auto input_var = ctx->GetInputVarPtrs(in_name); if (input_var.size() == 1) { infer_meta_context.EmplaceBackInput( - std::move(CompatMetaTensor(input_var[0], ctx->IsRuntime()))); + CompatMetaTensor(input_var[0], ctx->IsRuntime())); } else { paddle::small_vector inputs; for (const auto& in : input_var) { - inputs.emplace_back( - std::move(CompatMetaTensor(in, ctx->IsRuntime()))); + inputs.emplace_back(CompatMetaTensor(in, ctx->IsRuntime())); } infer_meta_context.EmplaceBackInputs(std::move(inputs)); } @@ -576,8 +575,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, // Note: Because the input of InferMetaFn is const MetaTensor&, // so when we prepare input MetaTensor by InferMetaContext->InputAt(), // we need to return a const reference of empty MetaTensor - infer_meta_context.EmplaceBackInput( - std::move(CompatMetaTensor(ctx->IsRuntime()))); + infer_meta_context.EmplaceBackInput(CompatMetaTensor(ctx->IsRuntime())); } } @@ -631,7 +629,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name)); } } else if (ctx->HasInput(attr_name)) { - auto infershape_input = std::move(ctx->GetInputVarPtrs(attr_name)); + auto infershape_input = ctx->GetInputVarPtrs(attr_name); if (infershape_input.size() == 1) { if (ctx->IsRuntime()) { Variable* var = PADDLE_GET_CONST(Variable*, infershape_input[0]); @@ -658,13 +656,13 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, if (attr_ptr && !is_attr_var) { auto& attr = *attr_ptr; switch (AttrTypeID(attr)) { - case framework::proto::AttrType::INTS: - infer_meta_context.EmplaceBackAttr(std::move( - phi::IntArray(PADDLE_GET_CONST(std::vector, attr)))); + case framework::proto::AttrType::INTS: // NOLINT + infer_meta_context.EmplaceBackAttr( + 
phi::IntArray(PADDLE_GET_CONST(std::vector, attr))); break; case framework::proto::AttrType::LONGS: - infer_meta_context.EmplaceBackAttr(std::move( - phi::IntArray(PADDLE_GET_CONST(std::vector, attr)))); + infer_meta_context.EmplaceBackAttr( + phi::IntArray(PADDLE_GET_CONST(std::vector, attr))); break; case framework::proto::AttrType::INT: infer_meta_context.EmplaceBackAttr( @@ -677,7 +675,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_name)); } } else if (ctx->HasInputs(attr_name) || ctx->HasInput(attr_name)) { - auto infershape_inputs = std::move(ctx->GetInputVarPtrs(attr_name)); + auto infershape_inputs = ctx->GetInputVarPtrs(attr_name); if (ctx->IsRuntime()) { // If is in runtime, we will get tensor's value for IntArray // and push it into attrs @@ -688,10 +686,10 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, } if (infershape_inputs.size() != 1) { infer_meta_context.EmplaceBackAttr( - std::move(framework::MakePhiIntArrayFromVarList(vars))); + framework::MakePhiIntArrayFromVarList(vars)); } else { infer_meta_context.EmplaceBackAttr( - std::move(framework::MakePhiIntArrayFromVar(*vars[0]))); + framework::MakePhiIntArrayFromVar(*vars[0])); } } else { // If is not in runtime, we will set default value(-1) for IntArray @@ -836,7 +834,7 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, attr_names[i])); } break; - case phi::AttributeType::FLOAT32S: + case phi::AttributeType::FLOAT32S: // NOLINT infer_meta_context.EmplaceBackAttr( PADDLE_GET_CONST(std::vector, attr)); break; @@ -868,32 +866,29 @@ CompatInferMetaContext BuildInferMetaContext(InferShapeContext* ctx, for (auto& out_name : output_names) { if (ctx->HasOutputs(out_name, true)) { - auto output_var = std::move(ctx->GetOutputVarPtrs(out_name)); + auto output_var = ctx->GetOutputVarPtrs(out_name); if (output_var.size() == 1) { infer_meta_context.EmplaceBackOutput( - std::move(CompatMetaTensor(output_var[0], ctx->IsRuntime()))); + CompatMetaTensor(output_var[0], ctx->IsRuntime())); } else { paddle::small_vector outputs; for (const auto& out : output_var) { if (ctx->IsRuntime()) { if (PADDLE_GET_CONST(Variable*, out)) { - outputs.emplace_back( - std::move(CompatMetaTensor(out, ctx->IsRuntime()))); + outputs.emplace_back(CompatMetaTensor(out, ctx->IsRuntime())); continue; } } else if (PADDLE_GET_CONST(VarDesc*, out)) { - outputs.emplace_back( - std::move(CompatMetaTensor(out, ctx->IsRuntime()))); + outputs.emplace_back(CompatMetaTensor(out, ctx->IsRuntime())); continue; } - outputs.emplace_back(std::move(CompatMetaTensor(ctx->IsRuntime()))); + outputs.emplace_back(CompatMetaTensor(ctx->IsRuntime())); } infer_meta_context.EmplaceBackOutputs(std::move(outputs)); } } else { - infer_meta_context.EmplaceBackOutput( - std::move(CompatMetaTensor(ctx->IsRuntime()))); + infer_meta_context.EmplaceBackOutput(CompatMetaTensor(ctx->IsRuntime())); } } diff --git a/paddle/fluid/framework/io/crypto/aes_cipher.cc b/paddle/fluid/framework/io/crypto/aes_cipher.cc index 8802dc1b12158..158d25a6957f7 100644 --- a/paddle/fluid/framework/io/crypto/aes_cipher.cc +++ b/paddle/fluid/framework/io/crypto/aes_cipher.cc @@ -65,7 +65,7 @@ std::string AESCipher::EncryptInternal(const std::string& plaintext, std::string ciphertext; m_filter->Attach(new CryptoPP::StringSink(ciphertext)); CryptoPP::Redirector* filter_redirector = new CryptoPP::Redirector(*m_filter); - CryptoPP::StringSource(plaintext, true, filter_redirector); + CryptoPP::StringSource ss(plaintext, true, filter_redirector); if 
(need_iv) { return iv_ + ciphertext; } @@ -96,7 +96,7 @@ std::string AESCipher::DecryptInternal(const std::string& ciphertext, std::string plaintext; m_filter->Attach(new CryptoPP::StringSink(plaintext)); CryptoPP::Redirector* filter_redirector = new CryptoPP::Redirector(*m_filter); - CryptoPP::StringSource( + CryptoPP::StringSource ss( ciphertext.substr(ciphertext_beg), true, filter_redirector); return plaintext; @@ -124,7 +124,7 @@ std::string AESCipher::AuthenticatedEncryptInternal( std::string ciphertext; m_filter->Attach(new CryptoPP::StringSink(ciphertext)); CryptoPP::Redirector* filter_redirector = new CryptoPP::Redirector(*m_filter); - CryptoPP::StringSource(plaintext, true, filter_redirector); + CryptoPP::StringSource ss(plaintext, true, filter_redirector); if (need_iv) { ciphertext = iv_.append(ciphertext); } @@ -155,7 +155,7 @@ std::string AESCipher::AuthenticatedDecryptInternal( std::string plaintext; m_filter->Attach(new CryptoPP::StringSink(plaintext)); CryptoPP::Redirector* filter_redirector = new CryptoPP::Redirector(*m_filter); - CryptoPP::StringSource( + CryptoPP::StringSource ss( ciphertext.substr(ciphertext_beg), true, filter_redirector); PADDLE_ENFORCE_EQ( m_filter->GetLastResult(), diff --git a/paddle/fluid/framework/io/fs.h b/paddle/fluid/framework/io/fs.h index 842f816d85792..cfff4f1d31790 100644 --- a/paddle/fluid/framework/io/fs.h +++ b/paddle/fluid/framework/io/fs.h @@ -23,7 +23,7 @@ #include "glog/logging.h" #include "paddle/fluid/framework/io/shell.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/io/save_load_tensor.cc b/paddle/fluid/framework/io/save_load_tensor.cc index 2ed37b6aa3874..b8a52e9c44fbf 100644 --- a/paddle/fluid/framework/io/save_load_tensor.cc +++ b/paddle/fluid/framework/io/save_load_tensor.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/io/save_paddle2cinn_varmap.cc b/paddle/fluid/framework/io/save_paddle2cinn_varmap.cc index 02587e0cfc21d..f4debede0a616 100644 --- a/paddle/fluid/framework/io/save_paddle2cinn_varmap.cc +++ b/paddle/fluid/framework/io/save_paddle2cinn_varmap.cc @@ -13,7 +13,7 @@ limitations under the License. */ #include #include #include "glog/logging.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" #include "paddle/phi/core/enforce.h" namespace paddle { diff --git a/paddle/fluid/framework/io/save_runtime_graph.cc b/paddle/fluid/framework/io/save_runtime_graph.cc index cfb03cca8d4ed..6d06fff535620 100644 --- a/paddle/fluid/framework/io/save_runtime_graph.cc +++ b/paddle/fluid/framework/io/save_runtime_graph.cc @@ -15,7 +15,7 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/node.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/io/shell.cc b/paddle/fluid/framework/io/shell.cc index cc893fefbb34f..fa449c1b10867 100644 --- a/paddle/fluid/framework/io/shell.cc +++ b/paddle/fluid/framework/io/shell.cc @@ -58,7 +58,7 @@ static int close_open_fds_internal() { long d_ino = 0; // NOLINT off_t d_off; unsigned short d_reclen = 0; // NOLINT - char d_name[256]; + char d_name[256]; // NOLINT }; int dir_fd = -1; @@ -66,7 +66,7 @@ static int close_open_fds_internal() { PADDLE_THROW(platform::errors::Unavailable("Failed to open proc/self/fd.")); return -1; } - char buffer[sizeof(linux_dirent)]; + char buffer[sizeof(linux_dirent)]; // NOLINT for (;;) { int bytes = 0; @@ -187,8 +187,8 @@ std::shared_ptr shell_popen(const std::string& cmd, std::string real_cmd = "set -o pipefail; " + cmd; - int pipe_fds[2]; - if (pipe(pipe_fds) != 0) { + std::array pipe_fds; + if (pipe(pipe_fds.data()) != 0) { *err_no = -1; return nullptr; } @@ -300,17 +300,17 @@ std::pair, std::shared_ptr> shell_p2open( std::string real_cmd = "set -o pipefail; " + cmd; - int pipein_fds[2]; - int pipeout_fds[2]; - if (pipe(pipein_fds) != 0) { + std::array pipein_fds; + std::array pipeout_fds; + if (pipe(pipein_fds.data()) != 0) { return {nullptr, nullptr}; } - if (pipe(pipeout_fds) != 0) { + if (pipe(pipeout_fds.data()) != 0) { return {nullptr, nullptr}; } - int child_pid = - shell_p2open_fork_internal(real_cmd.c_str(), pipein_fds, pipeout_fds); + int child_pid = shell_p2open_fork_internal( + real_cmd.c_str(), pipein_fds.data(), pipeout_fds.data()); close(pipein_fds[1]); close(pipeout_fds[0]); diff --git a/paddle/fluid/framework/io/shell.h b/paddle/fluid/framework/io/shell.h index 487c2aa95d05a..2b99adeb277a0 100644 --- a/paddle/fluid/framework/io/shell.h +++ b/paddle/fluid/framework/io/shell.h @@ -38,8 +38,8 @@ #include #include -#include "paddle/fluid/string/string_helper.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" +#include "paddle/utils/string/string_helper.h" #if defined(__arm__) || defined(__aarch64__) || defined(__ARM_NEON) || \ defined(__ARM_NEON__) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 765fa1779b0e5..cb8093298d9bb 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -322,6 +322,8 @@ if(WITH_XPU) ${XPU_PASS_DEPS}) pass_library(sine_pos_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(quant_dequant_xpu_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(roformer_relative_pos_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) endif() cc_library( diff --git a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc index a05a096daf928..f1657d4db5fdc 100644 --- a/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc +++ b/paddle/fluid/framework/ir/auto_mixed_precision_pass.cc @@ -96,7 +96,8 @@ inline bool VarNodeHasDtype(Node* var_node) { auto type = var_node->Var()->GetType(); return (type == VarType::SELECTED_ROWS) || (type == VarType::LOD_TENSOR) || (type == VarType::LOD_TENSOR_ARRAY) || (type == VarType::STRINGS) || - (type == VarType::VOCAB); + (type == VarType::VOCAB) || (type == VarType::SPARSE_COO) || + (type == VarType::SPARSE_CSR); } inline bool 
IsFP32(VarType::Type type) { return type == VarType::FP32; } @@ -123,12 +124,21 @@ void DoInsertCastOp(Graph* graph, const std::string& x_name, const std::string& out_name, const int in_dtype, - const int out_dtype) { - desc.SetType("cast"); - desc.SetInput("X", {x_name}); - desc.SetOutput("Out", {out_name}); - desc.SetAttr("in_dtype", in_dtype); - desc.SetAttr("out_dtype", out_dtype); + const int out_dtype, + const VarType::Type t) { + if (t == VarType::SPARSE_COO || t == VarType::SPARSE_CSR) { + desc.SetType("sparse_cast"); + desc.SetInput("x", {x_name}); + desc.SetOutput("out", {out_name}); + desc.SetAttr("index_dtype", -1); + desc.SetAttr("value_dtype", to_type); + } else { + desc.SetType("cast"); + desc.SetInput("X", {x_name}); + desc.SetOutput("Out", {out_name}); + desc.SetAttr("in_dtype", in_dtype); + desc.SetAttr("out_dtype", out_dtype); + } desc.SetAttr("use_mkldnn", false); desc.SetAttr("with_quant_attr", false); desc.Flush(); @@ -140,17 +150,21 @@ void DoInsertCastOp(Graph* graph, std::string cast_output_name = var_node->Var()->Name() + "_cast_auto_mixed.tmp_" + std::to_string((*suffix)++); + VarType::Type var_type = var_node->Var()->GetType(); framework::OpDesc cast_op_desc(block_desc); update_cast_desc(cast_op_desc, cast_input_name, cast_output_name, static_cast(from_type), - static_cast(to_type)); + static_cast(to_type), + var_type); auto* cast_op_node = graph->CreateOpNode(&cast_op_desc); auto* cast_output_vardesc = block_desc->Var(cast_output_name); + cast_output_vardesc->SetType(var_type); cast_output_vardesc->SetPersistable(false); cast_output_vardesc->SetDataType(to_type); cast_output_vardesc->SetShape(var_node->Var()->GetShape()); + cast_output_vardesc->Flush(); auto* cast_output_node = graph->CreateVarNode(cast_output_vardesc); IR_NODE_LINK_TO(cast_op_node, cast_output_node); (*cache)[var_node] = cast_output_node; @@ -452,8 +466,8 @@ void AutoMixedPrecisionPass::GetOpPrecision() const { } } - // if op's input var and output var is not dense tensor, the op should - // not run at low precision. + // op's input var and output var only support + // dense/sparse_coo/sparse_csr tensor. 
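// --- Editor's note (illustrative sketch, not part of the patch) -------------
// The auto_mixed_precision_pass hunks above extend low-precision eligibility
// from dense LOD_TENSOR variables to SPARSE_COO / SPARSE_CSR as well, and emit
// a "sparse_cast" op (inputs "x"/"out") instead of "cast" ("X"/"Out") when the
// variable is sparse. The standalone helpers below mirror that type-driven
// dispatch with hypothetical names; they are not the pass itself.
#include <string>

enum class DemoVarKind { kDenseTensor, kSparseCoo, kSparseCsr, kOther };

inline bool DemoSupportsLowPrecision(DemoVarKind k) {
  switch (k) {
    case DemoVarKind::kDenseTensor:
    case DemoVarKind::kSparseCoo:
    case DemoVarKind::kSparseCsr:
      return true;
    default:
      return false;
  }
}

inline std::string DemoCastOpTypeFor(DemoVarKind k) {
  // Sparse variables need the sparse-aware cast kernel.
  return (k == DemoVarKind::kSparseCoo || k == DemoVarKind::kSparseCsr)
             ? "sparse_cast"
             : "cast";
}
// ----------------------------------------------------------------------------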
for (auto* in_var_node : op_node->inputs) { CHECK_EQ(in_var_node->IsVar(), true); auto* real_in_var_node = real_vars_.at(in_var_node->Var()->Name()); @@ -461,7 +475,9 @@ void AutoMixedPrecisionPass::GetOpPrecision() const { support_low_precision = support_low_precision && - (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR); + (real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR || + real_in_var_node->Var()->GetType() == VarType::SPARSE_COO || + real_in_var_node->Var()->GetType() == VarType::SPARSE_CSR); } for (auto* out_var_node : op_node->outputs) { CHECK_EQ(out_var_node->IsVar(), true); @@ -470,7 +486,9 @@ void AutoMixedPrecisionPass::GetOpPrecision() const { support_low_precision = support_low_precision && - (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR); + (real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR || + real_out_var_node->Var()->GetType() == VarType::SPARSE_COO || + real_out_var_node->Var()->GetType() == VarType::SPARSE_CSR); } } @@ -634,6 +652,23 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert( if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { return true; } + } else if (GetOpOriginalType(op_desc->Type()) == "sparse_batch_norm") { + auto vecs = op_desc->Input("bias"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("mean"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("scale"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("variance"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } } else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") { auto vecs = op_desc->Input("Bias"); if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { @@ -670,37 +705,15 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert( if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { return true; } - } - - if (backend_ == phi::Backend::XPU) { - if (GetOpOriginalType(op_desc->Type()) == "layer_norm") { - auto vecs = op_desc->Input("Bias"); - if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { - return true; - } - vecs = op_desc->Input("Scale"); - if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { - return true; - } - } else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") { - auto vecs = op_desc->Input("Bias"); - if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { - return true; - } - vecs = op_desc->Input("Scale"); - if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { - return true; - } - } else if (GetOpOriginalType(op_desc->Type()) == "quantize_linear" || - GetOpOriginalType(op_desc->Type()) == "dequantize_linear") { - auto vecs = op_desc->Input("Scale"); - if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { - return true; - } - vecs = op_desc->Input("ZeroPoint"); - if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { - return true; - } + } else if (GetOpOriginalType(op_desc->Type()) == "quantize_linear" || + GetOpOriginalType(op_desc->Type()) == "dequantize_linear") { + auto vecs = op_desc->Input("Scale"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Input("ZeroPoint"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; } } @@ -728,18 +741,36 @@ bool 
AutoMixedPrecisionPass::OutputVarsNotConvert( if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { return true; } - } - - if (backend_ == phi::Backend::XPU) { - if (GetOpOriginalType(op_desc->Type()) == "layer_norm") { - auto vecs = op_desc->Output("Mean"); - if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { - return true; - } - vecs = op_desc->Output("Variance"); - if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { - return true; - } + } else if (GetOpOriginalType(op_desc->Type()) == "sparse_batch_norm") { + auto vecs = op_desc->Output("mean_out"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Output("variance_out"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Output("saved_mean"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Output("saved_variance"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Output("reserve_space"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + } else if (GetOpOriginalType(op_desc->Type()) == "layer_norm" || + GetOpOriginalType(op_desc->Type()) == "group_norm") { + auto vecs = op_desc->Output("Mean"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; + } + vecs = op_desc->Output("Variance"); + if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) { + return true; } } diff --git a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc index 44cb004fec172..966f4ea14967d 100644 --- a/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc +++ b/paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc @@ -134,7 +134,7 @@ class CoalesceGradTensorPass : public ir::Pass { auto &pinned_var_set = graph->GetOrInit(details::kPinnedVars); - if (IsUnifiedDtype(p_g_dense_grad, vars_info)) { + if (IsUnifiedDtype(p_g_dense_grad, vars_info)) { // NOLINT RecordGradients(p_g_dense_grad, vars_info, &pinned_var_set); CoalesceTensors(vars_info, p_g_dense_grad, &result); } else { diff --git a/paddle/fluid/framework/ir/constant_folding_pass.cc b/paddle/fluid/framework/ir/constant_folding_pass.cc index 4375043544dc8..099209db48840 100644 --- a/paddle/fluid/framework/ir/constant_folding_pass.cc +++ b/paddle/fluid/framework/ir/constant_folding_pass.cc @@ -13,9 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/framework/ir/constant_folding_pass.h" + #include #include #include "glog/logging.h" + +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" @@ -23,8 +27,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/framework/convert_utils.h" - namespace paddle { namespace framework { namespace ir { @@ -51,6 +53,37 @@ struct ConstantFolding : public PatternBase { }; } // namespace patterns +namespace { +std::unordered_set GetControlFlowVarNames(ir::Graph *graph) { + std::unordered_set control_flow_ops{"while", + "conditional_block"}; + std::unordered_set control_flow_var_names; + for (auto *node : graph->Nodes()) { + if (!node->IsOp() || control_flow_ops.count(node->Op()->Type()) == 0) + continue; + for (auto const &in_names : node->Op()->Inputs()) { + auto var_names = in_names.second; + control_flow_var_names.insert(var_names.begin(), var_names.end()); + } + for (auto const &out_names : node->Op()->Outputs()) { + auto var_names = out_names.second; + control_flow_var_names.insert(var_names.begin(), var_names.end()); + } + } + return control_flow_var_names; +} + +bool OutputUsedByControlFlow(ir::Node *node, + const std::unordered_set &cf_vars) { + for (auto out_node : node->outputs) { + if (cf_vars.count(out_node->Name())) { + return true; + } + } + return false; +} +} // namespace + ConstantFoldingPass::ConstantFoldingPass() = default; void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const { @@ -69,6 +102,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const { "save", "quantize_linear", "dequantize_linear"}; + const auto cf_vars = GetControlFlowVarNames(graph); int folded_op_num = 0; auto op_node_sorted = framework::ir::TopologyVariantSort( @@ -78,7 +112,9 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const { if (std::find(blacklist.begin(), blacklist.end(), op_node->Name()) != blacklist.end()) continue; - + if (OutputUsedByControlFlow(op_node, cf_vars)) { + continue; + } bool input_persis = true; // map is used to record how many time a name string occurs in the whole // graph's nodes diff --git a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc index 50ba4fa6ce110..4faebacb5f55c 100644 --- a/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_bn_fuse_pass.cc @@ -421,7 +421,8 @@ void ConvBNFusePass::ApplyImpl(ir::Graph* graph) const { // without MKL-DNN fuse conv+bn into conv+elementwise_add if (is_mkldnn) { if (conv->Op()->Type() == "conv2d" || - conv->Op()->Type() == "depthwise_conv2d") { + conv->Op()->Type() == "depthwise_conv2d" || + conv->Op()->Type() == "conv2d_transpose") { ConvertToFusedOp(conv->Op()); } if (mkldnn_with_bias) { @@ -816,6 +817,48 @@ ConvTransposeBNFusePass::ConvTransposeBNFusePass() { // NOLINT .AddAttr("data_format") .IsStringIn({"NCHW", "AnyLayout"}) .End(); + + AddOpCompat(OpCompat("conv2d_transpose_bias")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("output_padding") + .IsType>() + .IsOptional() + .End() + .AddAttr("output_size") + .IsType>() + .IsOptional() + .End() + .AddAttr("groups") + .IsNumEQ(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "AnyLayout"}) + .End(); } ConvTransposeEltwiseAddBNFusePass:: diff --git 
a/paddle/fluid/framework/ir/cutlass_teller.h b/paddle/fluid/framework/ir/cutlass_teller.h index 3d50544ede13b..2bc829e2fc8e9 100644 --- a/paddle/fluid/framework/ir/cutlass_teller.h +++ b/paddle/fluid/framework/ir/cutlass_teller.h @@ -1,5 +1,5 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at @@ -20,8 +20,9 @@ namespace framework { namespace ir { typedef enum { - cba, - cbaa, + cba, // This servers for conv_elementwise_add_fuse_pass + cbaa, // This servers for conv_elementwise_add2_act_fuse_pass + cbaele, // This servers for conv2d_fusion_cutlass_elementwise } CutlassFusionType; class CutlassTeller { @@ -33,6 +34,7 @@ class CutlassTeller { #if defined(PADDLE_WITH_CUTLASS) // Determine this NCHW conv2d + bias can be fused with activation by cutlass? + // This servers for conv_elementwise_add_fuse_pass. // will not set or change any attribute in op_desc bool CbaCanSupport(OpDesc *op_desc, Scope *scope, @@ -85,7 +87,8 @@ class CutlassTeller { } // Determine this NCHW conv2d + bias + elewise_add + act can be fused by - // cutlass? will not set or change any attribute in op_desc + // cutlass?, this is for conv_elementwise_add_fuse_pass + // will not set or change any attribute in op_desc bool CbaaCanSupport(OpDesc *op_desc, Scope *scope, std::string act_type, @@ -136,6 +139,69 @@ class CutlassTeller { return true; } + // Determine this NCHW conv2d_fusion + elewise_op + act1 can be fused by + // cutlass? + // This servers for conv2d_fusion_cutlass_elementwise. + // will not set or change any attribute in op_desc + bool CbaeleCanSupport(OpDesc *op_desc, + Scope *scope, + std::string ele_type, + std::string act1_type, + int device_id) { + auto strides = op_desc->GetAttrIfExists>("strides"); + auto dilations = op_desc->GetAttrIfExists>("dilations"); + CHECK_EQ(strides.size() == 2UL, true); + CHECK_EQ(dilations.size() == 2UL, true); + int stride_h = strides[0]; + int stride_w = strides[1]; + int dilation_h = dilations[0]; + int dilation_w = dilations[1]; + auto act_type = op_desc->GetAttrIfExists("activation"); + + // Do not allow conv2d_fusion already have residual input. + if (op_desc->Input("ResidualData").size() >= 1) { + return false; + } + + auto filter_names = op_desc->Input("Filter"); + + for (const auto &filter_name : filter_names) { + auto *filter_var = scope->FindLocalVar(filter_name); + const auto &filter_tensor = filter_var->Get(); + CHECK_EQ(filter_tensor.dims().size() == 4UL, true); + auto groups = op_desc->GetAttrIfExists("groups"); + int oc = filter_tensor.dims()[0]; + int kc = filter_tensor.dims()[1]; + int kh = filter_tensor.dims()[2]; + int kw = filter_tensor.dims()[3]; + + // For convience, we only support EXPLICIT + auto padding_algorithm = + op_desc->GetAttrIfExists("padding_algorithm"); + if (padding_algorithm != "EXPLICIT") { + return false; + } + + if (!Conv2dCanSupport(oc, + kc, + kh, + kw, + stride_h, + stride_w, + dilation_h, + dilation_w, + groups, + act_type, + device_id, + CutlassFusionType::cbaele, + act1_type, + ele_type)) { + return false; + } + } + return true; + } + // Determine whether this conv can be fused with the activation by cutlass // backend. 
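// --- Editor's note (illustrative sketch, not part of the patch) -------------
// The new `cbaele` fusion type above keys its support check on a composed
// string "<activation>_<elementwise_op>_<activation1>" looked up in
// cbaele_act_set ({"identity_elementwise_add_identity",
// "swish_elementwise_add_identity"}). The minimal helper below reproduces just
// that lookup with local names so the matching rule is easy to see; it is not
// the teller itself.
#include <string>
#include <unordered_set>

inline bool DemoCbaeleKeySupported(const std::string& act,
                                   const std::string& elementwise_op,
                                   const std::string& act1) {
  static const std::unordered_set<std::string> kSupported{
      "identity_elementwise_add_identity",
      "swish_elementwise_add_identity",
  };
  return kSupported.count(act + "_" + elementwise_op + "_" + act1) > 0;
}
// Example: DemoCbaeleKeySupported("swish", "elementwise_add", "identity")
// is true, while DemoCbaeleKeySupported("relu", "elementwise_add", "identity")
// is false.
// ----------------------------------------------------------------------------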
bool Conv2dCanSupport(int oc, @@ -149,7 +215,10 @@ class CutlassTeller { int groups, std::string activation, int device_id, - CutlassFusionType fuse_type) { + CutlassFusionType fuse_type, + // below two are used by cbaele + std::string activation1 = "identity", + std::string elemenstwise_type = "elementwise_add") { int sm_version = platform::GetGPUComputeCapability(device_id); int ic = kc * groups; if (!cutlass_sm.count(sm_version)) { @@ -173,6 +242,14 @@ class CutlassTeller { !cbaa_act_set.count(activation)) { return false; } + + // conv + bias + act + elementwise_op + if (fuse_type == CutlassFusionType::cbaele && + !cbaele_act_set.count(activation + "_" + elemenstwise_type + "_" + + activation1)) { + return false; + } + } else if (groups == ic && ic == oc) { // return false; // conv2d_depthwise not support residual input @@ -250,6 +327,14 @@ class CutlassTeller { return false; } + bool CbaeleCanSupport(OpDesc *op_desc, + Scope *scope, + std::string ele_type, + std::string act1_type, + int device_id) { + return false; + } + bool Conv2dCanSupport(int oc, int kc, int kh, @@ -261,7 +346,10 @@ class CutlassTeller { int groups, std::string activation, int device_id, - CutlassFusionType fuse_type) { + CutlassFusionType fuse_type, + // below two are used by cbaele + std::string activation1 = "identity", + std::string elemenstwise_type = "elementwise_add") { return false; } std::unordered_set CbaAct(int device_id) { return {}; } @@ -270,6 +358,9 @@ class CutlassTeller { static const int CUTLASS_NHWC_ALIGNMENT = 8; const std::unordered_set cutlass_sm = { 75, + 80, + 85, + 86, }; const std::unordered_set cba_act_set = { "relu", "swish", "identity", "leaky_relu", "sigmoid"}; @@ -278,6 +369,10 @@ class CutlassTeller { const std::unordered_set cdba_act_set = { "identity", "relu", "swish", "sigmoid"}; const std::unordered_set cbaa_act_set = {"relu"}; + const std::unordered_set cbaele_act_set = { + "identity_elementwise_add_identity", + "swish_elementwise_add_identity", + }; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc index cfe644a61ea51..3bd051c597179 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.cc @@ -73,7 +73,7 @@ DeleteQuantDequantFilterOpPass::DeleteQuantDequantFilterOpPass() { } // Delete quant_dequant_op, then quantize and dequantize weight void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { - const std::string pattern_name = "delete_quantdequant_filter_op_pattern"; + const std::string pattern_name = "delete_quant_dequant_filter_op_pattern"; FusePassBase::Init(pattern_name, graph); GraphPatternDetector gpd; @@ -141,7 +141,7 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { "the received is %d", quant_axis)); - // To Do @Wangzheee: use "OutScale" to quantdequant + // To Do @Wangzheee: use "OutScale" to quant_dequant /*auto scales_name = quant_dequant_op->Op()->Output("OutScale"); PADDLE_ENFORCE_EQ(scales_name.size(), 1, platform::errors::InvalidArgument( diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc index 7358a82c6ca3c..b8a5dfdaa9465 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.cc @@ -86,7 +86,7 @@ 
DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() { } // Delete quantize_linear_op dequantize_linear_op, then add input_scales void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { - const std::string pattern_name = "delete_quantdequant_linear_op_pattern"; + const std::string pattern_name = "delete_quant_dequant_linear_op_pattern"; FusePassBase::Init(pattern_name, graph); GraphPatternDetector gpd; @@ -124,14 +124,18 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { return; } */ - std::unordered_set nodes2rm = {}; - - // delete Scale and ZeroPoint tensor in scope + // Scale and ZeroPoint tensor should be removed in save_optimized_model_pass std::vector vars2rm = {}; vars2rm.emplace_back(quantize_linear_op->Op()->Input("Scale")[0]); vars2rm.emplace_back(quantize_linear_op->Op()->Input("ZeroPoint")[0]); vars2rm.emplace_back(dequantize_linear_op->Op()->Input("Scale")[0]); vars2rm.emplace_back(dequantize_linear_op->Op()->Input("ZeroPoint")[0]); + auto& scale_and_zero_point_param = g->GetOrInit>( + framework::ir::kScaleAndZeroPointParamAttr); + scale_and_zero_point_param.insert( + scale_and_zero_point_param.end(), vars2rm.begin(), vars2rm.end()); + + std::unordered_set nodes2rm = {}; // Get input scale from tensor const phi::DenseTensor& input_scale_tensor = @@ -182,13 +186,6 @@ void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const { nodes2rm.insert(dequantize_linear_op); nodes2rm.insert(dequantize_linear_op_out); GraphSafeRemoveNodes(graph, nodes2rm); - - for (auto& var_name : vars2rm) { - if (scope->FindVar(var_name)) { - scope->EraseVars({var_name}); - } - } - found_count++; }; gpd(graph, handler); diff --git a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc index ebb0ed9d00dc1..2a7071d54843d 100644 --- a/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc +++ b/paddle/fluid/framework/ir/delete_quant_dequant_op_pass.cc @@ -32,21 +32,21 @@ namespace ir { GET_IR_NODE(quant_dequant_op_out); void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const { - const std::string pattern_name = "delete_quantdequant_op_pattern"; + const std::string pattern_name = "delete_quant_dequant_op_pattern"; FusePassBase::Init(pattern_name, graph); GraphPatternDetector gpd; - std::string quantdequant_types = + std::string quant_dequant_types = "fake_quantize_dequantize_moving_average_abs_max"; auto* input_node = gpd.mutable_pattern() ->NewNode("input_node") - ->assert_is_op_input(quantdequant_types, "X") + ->assert_is_op_input(quant_dequant_types, "X") ->AsInput(); patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(), pattern_name); - pattern(input_node, quantdequant_types); + pattern(input_node, quant_dequant_types); auto* scope = param_scope(); int found_count = 0; diff --git a/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc b/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc index 7cea0e9f30ce8..48332f10094fa 100644 --- a/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc +++ b/paddle/fluid/framework/ir/delete_remove_padding_recover_padding_pass.cc @@ -66,14 +66,16 @@ void DeleteRemovePaddingRecoverPaddingPass::ApplyImpl(ir::Graph *graph) const { std::unordered_set del_node_set; bool delete_recover_padding = true; - for (size_t i = 0; i < recover_padding_out->outputs.size(); ++i) { + for (size_t i = 0; i < recover_padding_out->outputs.size(); + ++i) { // NOLINT if 
(recover_padding_out->outputs[i]->Name() == "remove_padding") { // op_node auto *remove_padding_out_node = - recover_padding_out->outputs[i]->outputs[0]; // var_node - auto *out_op_node = remove_padding_out_node->outputs[0]; // op_node + recover_padding_out->outputs[i]->outputs[0]; // NOLINT // var_node + auto *out_op_node = + remove_padding_out_node->outputs[0]; // NOLINT // op_node IR_NODE_LINK_TO(recover_padding_input, out_op_node); - del_node_set.insert(recover_padding_out->outputs[i]); + del_node_set.insert(recover_padding_out->outputs[i]); // NOLINT del_node_set.insert(remove_padding_out_node); out_op_node->Op()->RenameInput(remove_padding_out_node->Name(), recover_padding_input->Name()); diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc index 583e51dc931d2..cf38ab2993d3f 100644 --- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/string/pretty_log.h" +#include "paddle/utils/string/pretty_log.h" namespace paddle { namespace framework { class Scope; diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 78e6ea14e43fc..edbd052e3256d 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/string/pretty_log.h" +#include "paddle/utils/string/pretty_log.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc index 9cb8ce260683f..15c5b0b379b13 100644 --- a/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc +++ b/paddle/fluid/framework/ir/fuse_optimizer_ops_pass/fuse_adam_op_pass.cc @@ -233,13 +233,13 @@ class FuseAdamOpPass : public FuseOptimizerOpPass { scale_ops.reserve(beta_name.size()); for (size_t i = 0; i < adam_ops.size(); ++i) { auto &beta_1_pow_name = beta_name[i]; - auto beta_pow_iter = std::find_if( - adam_ops[i]->inputs.begin(), - adam_ops[i]->inputs.end(), - [&beta_name, &beta_1_pow_name](ir::Node *var_node) -> bool { - return var_node->Var() && - var_node->Var()->Name() == beta_1_pow_name; - }); + auto beta_pow_iter = + std::find_if(adam_ops[i]->inputs.begin(), + adam_ops[i]->inputs.end(), + [&beta_1_pow_name](ir::Node *var_node) -> bool { + return var_node->Var() && + var_node->Var()->Name() == beta_1_pow_name; + }); PADDLE_ENFORCE_NE(beta_pow_iter, adam_ops[i]->inputs.end(), platform::errors::NotFound( diff --git a/paddle/fluid/framework/ir/fuse_pass_base.h b/paddle/fluid/framework/ir/fuse_pass_base.h index bc5fc2a16d393..d8522f1aeaabe 100644 --- a/paddle/fluid/framework/ir/fuse_pass_base.h +++ b/paddle/fluid/framework/ir/fuse_pass_base.h @@ -40,6 +40,11 @@ static const char kFuseStatisAttr[] = "__fuse_statis__"; // allocation. static const char kRepetitiveParamAttr[] = "__repetitive_param__"; +// scale and zero point of the quantized/dequantized op should be removed in +// save_optimized_model_pass. 
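// --- Editor's note (illustrative sketch, not part of the patch) -------------
// The delete_quant_dequant_linear_op_pass hunk above no longer erases the
// Scale / ZeroPoint variables from the scope inside the pass; it only appends
// their names to the graph attribute kScaleAndZeroPointParamAttr so that
// save_optimized_model_pass can remove them after the optimized model is
// written. The sketch below shows that "record now, delete later" idea with a
// plain container standing in for the graph attribute; the names are
// hypothetical.
#include <string>
#include <vector>

struct DemoGraphAttrs {
  // Stand-in for Graph::GetOrInit<std::vector<std::string>>(attr_name).
  std::vector<std::string> scale_and_zero_point_vars;
};

inline void DemoRecordForLaterRemoval(DemoGraphAttrs* attrs,
                                      const std::vector<std::string>& names) {
  // Append instead of erasing immediately; a later pass consumes this list.
  attrs->scale_and_zero_point_vars.insert(
      attrs->scale_and_zero_point_vars.end(), names.begin(), names.end());
}
// ----------------------------------------------------------------------------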
+static const char kScaleAndZeroPointParamAttr[] = + "__scale_and_zero_point_param__"; + enum FuseOptions { DO_NOT_FUSE, // fusing will not be done FUSE_NATIVE, // fusing will be done without MKL-DNN diff --git a/paddle/fluid/framework/ir/fusion_group/code_generator.cc b/paddle/fluid/framework/ir/fusion_group/code_generator.cc index e59c495f2dd8d..2e5c2b5be4ac3 100644 --- a/paddle/fluid/framework/ir/fusion_group/code_generator.cc +++ b/paddle/fluid/framework/ir/fusion_group/code_generator.cc @@ -173,12 +173,10 @@ std::string CodeGenerator::Generate( std::string func_name, const std::vector& expressions) { // TODO(liuyiqun): Check whether all expressions are elementwise operations. - std::set input_ids = std::move(DistilInputIds(expressions)); - std::set output_ids = std::move(DistilOutputIds(expressions)); - std::set intermediate_output_ids = - std::move(DistilIntermediateIds(expressions)); - std::unordered_map dtypes = - std::move(DistilDtypes(expressions)); + std::set input_ids = DistilInputIds(expressions); + std::set output_ids = DistilOutputIds(expressions); + std::set intermediate_output_ids = DistilIntermediateIds(expressions); + std::unordered_map dtypes = DistilDtypes(expressions); TemplateVariable template_var; template_var.Add("func_name", func_name); template_var.Add( diff --git a/paddle/fluid/framework/ir/generate_pass.h b/paddle/fluid/framework/ir/generate_pass.h index 3a9d0f1efa71e..9f1ff68c1850a 100644 --- a/paddle/fluid/framework/ir/generate_pass.h +++ b/paddle/fluid/framework/ir/generate_pass.h @@ -51,7 +51,8 @@ class OpHelper; class SubgraphHelper; // VarHelper is used to represent a variable node. -struct VarHelper { +class VarHelper { + public: enum class Type { kInput, kOutput }; explicit VarHelper(const char* name); diff --git a/paddle/fluid/framework/ir/generate_pass_tester.cc b/paddle/fluid/framework/ir/generate_pass_tester.cc index 760e1e8ce4ef8..f0f9330259fff 100644 --- a/paddle/fluid/framework/ir/generate_pass_tester.cc +++ b/paddle/fluid/framework/ir/generate_pass_tester.cc @@ -25,15 +25,14 @@ REGISTER_GENERATE_PASS(generate_fc_fuse) { VLOG(3) << "exec lambda func."; auto mul = OP_(mul)({{"X", x}, {"Y", y}}).Out("Out"); auto ewadd = OP_(elementwise_add)({{"X", mul}, {"Y", z}}).Out("Out"); - if (with_relu) { + if (with_relu) { // NOLINT return OP_(relu)({"X", ewadd}).Out("Out"); } else { return ewadd; } }; // replace - SUBGRAPH_(replace) = [subgraph = &replace, with_relu]( - VAR_(x), VAR_(y), VAR_(z)) { + SUBGRAPH_(replace) = [subgraph = &replace](VAR_(x), VAR_(y), VAR_(z)) { auto& fc = OP_(fc)({{"Input", x}, {"W", y}, {"Bias", z}}); return fc.Out("Out"); }; diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 67f2eae2be5e6..53e2697daa868 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -134,11 +134,10 @@ bool VarDescIsConsistency(const Graph &graph) { } for (auto &iter : var_name2node_set) { auto &first_node = *iter.second.begin(); - bool is_persistable = std::any_of(iter.second.begin(), - iter.second.end(), - [&first_node](const ir::Node *node) { - return node->Var()->Persistable(); - }); + bool is_persistable = std::any_of( + iter.second.begin(), iter.second.end(), [](const ir::Node *node) { + return node->Var()->Persistable(); + }); if (is_persistable) { bool is_consistency = std::all_of(iter.second.begin(), diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 
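The code_generator.cc hunk above drops std::move around values returned by DistilInputIds and friends. A short standalone example of why that wrapper is redundant (MakeIds is a placeholder function):

#include <set>
#include <utility>

static std::set<int> MakeIds() { return {1, 2, 3}; }

int main() {
  // A by-value function result is already an rvalue, so initialization
  // from it moves (or is elided) without help.
  std::set<int> a = MakeIds();
  // Wrapping the call in std::move adds nothing and can prevent copy
  // elision; compilers flag it with warnings such as -Wpessimizing-move.
  std::set<int> b = std::move(MakeIds());
  (void)a;
  (void)b;
  return 0;
}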
df804cf0d4f7b..3910e7586e35c 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -18,7 +18,7 @@ #include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/string/pretty_log.h" +#include "paddle/utils/string/pretty_log.h" namespace paddle { namespace framework { @@ -781,8 +781,7 @@ void GraphSafeRemoveNodes( for (auto *node : nodes) { if (saved_nodes != nullptr) { // prevent unique_ptr node from being released - saved_nodes->insert( - std::move(graph->RemoveNode(const_cast(node)))); + saved_nodes->insert(graph->RemoveNode(const_cast(node))); } else { graph->RemoveNode(const_cast(node)); } @@ -3519,22 +3518,22 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) { } void patterns::DeleteQuantDequantOpPattern::operator()( - PDNode *input_node, const std::string &quantdequant_types) { + PDNode *input_node, const std::string &quant_dequant_types) { auto quant_dequant_op_inscale = pattern->NewNode(quant_dequant_op_inscale_repr()) - ->assert_is_op_input(quantdequant_types, "InScale") + ->assert_is_op_input(quant_dequant_types, "InScale") ->AsInput(); auto quant_dequant_op = pattern->NewNode(quant_dequant_op_repr()) - ->assert_is_op(quantdequant_types); + ->assert_is_op(quant_dequant_types); auto quant_dequant_op_out = pattern->NewNode(quant_dequant_op_out_repr()) - ->assert_is_op_output(quantdequant_types, "Out") + ->assert_is_op_output(quant_dequant_types, "Out") ->AsOutput(); auto quant_dequant_op_outscale = pattern->NewNode(quant_dequant_op_outscale_repr()) - ->assert_is_op_output(quantdequant_types, "OutScale") + ->assert_is_op_output(quant_dequant_types, "OutScale") ->AsOutput(); quant_dequant_op->LinksFrom({quant_dequant_op_inscale, input_node}); diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 22d88e96b2852..4eac3440a4514 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -1869,9 +1869,9 @@ struct DeleteDropoutOpPattern : public PatternBase { struct DeleteQuantDequantOpPattern : public PatternBase { DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {} + : PatternBase(pattern, name_scope, "delete_quant_dequant_op_pattern") {} - void operator()(PDNode* input_node, const std::string& quantdequant_types); + void operator()(PDNode* input_node, const std::string& quant_dequant_types); PATTERN_DECL_NODE(quant_dequant_op_inscale); PATTERN_DECL_NODE(quant_dequant_op); @@ -1883,7 +1883,7 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase { DeleteQuantDequantFilterOpPattern(PDPattern* pattern, const std::string& name_scope) : PatternBase( - pattern, name_scope, "delete_quantdequant_filter_op_pattern") {} + pattern, name_scope, "delete_quant_dequant_filter_op_pattern") {} void operator()(); diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc index b8ad98113a3a4..4654abe6eb48d 100644 --- a/paddle/fluid/framework/ir/graph_test.cc +++ b/paddle/fluid/framework/ir/graph_test.cc @@ -38,7 +38,7 @@ class NOP : public OperatorBase { class SumOpMaker : public OpProtoAndCheckerMaker { public: - void Make() { + void Make() override { AddInput("X", "").AsDuplicable(); AddOutput("Out", "").AsDuplicable(); 
AddComment(""); @@ -60,7 +60,7 @@ class SumOpVarTypeInference : public VarTypeInference { class DummyOpMaker : public OpProtoAndCheckerMaker { public: - void Make() { + void Make() override { AddInput("X", "").AsDuplicable(); AddOutput("Out", "").AsDuplicable(); AddComment(""); diff --git a/paddle/fluid/framework/ir/identity_op_clean_pass.cc b/paddle/fluid/framework/ir/identity_op_clean_pass.cc index ab9df0ae4abee..55316c1b82310 100644 --- a/paddle/fluid/framework/ir/identity_op_clean_pass.cc +++ b/paddle/fluid/framework/ir/identity_op_clean_pass.cc @@ -70,7 +70,7 @@ FindUselessOpPattern::FindUselessOpPattern(PDPattern* pattern, auto in_dtype = x->Op()->GetAttrIfExists("in_dtype"); auto out_dtype = x->Op()->GetAttrIfExists("out_dtype"); return in_dtype == out_dtype; - } else if (op_type == "c_identity") { + } else if (op_type == "c_identity") { // NOLINT return true; } else if (op_type == "assign") { const auto& in_name = x->Op()->Input("X")[0]; diff --git a/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc b/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc index 56323c1605136..afaaefcc4ae98 100644 --- a/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/layer_norm_fuse_pass.cc @@ -21,8 +21,8 @@ #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/string/pretty_log.h" -#include "paddle/fluid/string/printf.h" +#include "paddle/utils/string/pretty_log.h" +#include "paddle/utils/string/printf.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/ir/lock_free_optimize_pass.h b/paddle/fluid/framework/ir/lock_free_optimize_pass.h index 0ca3b8585fb13..f36b7162fcf06 100644 --- a/paddle/fluid/framework/ir/lock_free_optimize_pass.h +++ b/paddle/fluid/framework/ir/lock_free_optimize_pass.h @@ -19,7 +19,7 @@ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/pass.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc index ac05579e4fa46..5431e62fe4220 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.cc @@ -20,7 +20,7 @@ #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/multi_devices_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc index d9ea00e3935cc..f48897674143a 100644 --- a/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc +++ b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass.cc @@ -22,7 +22,7 @@ #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" #include "paddle/fluid/operators/cinn/cinn_launch_op.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace 
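The override keywords added above let the compiler verify that each Make really overrides the base-class virtual. A self-contained illustration (the base class name is a stand-in, not the Paddle class):

struct OpMakerBase {
  virtual void Make() {}
  virtual ~OpMakerBase() = default;
};

struct SumOpMakerLike : OpMakerBase {
  void Make() override {}  // OK: signature matches the base
  // void Make() const override {}  // would not compile: nothing to override,
  //                                // instead of silently adding a new virtual
};

int main() { return 0; }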
paddle::framework::ir { diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc b/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc deleted file mode 100644 index 1f78e293a21a3..0000000000000 --- a/paddle/fluid/framework/ir/memory_optimize_pass/share_varinfo_into_cinn_pass_test.cc +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include - -#include "gtest/gtest.h" -#include "paddle/fluid/framework/details/computation_op_handle.h" -#include "paddle/fluid/framework/details/eager_deletion_op_handle.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" -#include "paddle/fluid/framework/ir/pass.h" -#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" -#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" -#include "paddle/fluid/framework/parallel_executor.h" -#include "paddle/fluid/framework/program_desc.h" - -USE_OP_ITSELF(mul); -USE_OP_ITSELF(elementwise_add); - -USE_OP_ITSELF(cinn_launch); -PD_DECLARE_KERNEL(cinn_launch, CPU, ALL_LAYOUT); -#ifdef PADDLE_WITH_CUDA -PD_DECLARE_KERNEL(cinn_launch, GPU, ALL_LAYOUT); -#endif - -namespace paddle::framework { - -using Name2VarInfoMap = - std::unordered_map>; - -static ProgramDesc BuildProgramInsideCinnLaunchOp() { - ProgramDesc program; - auto* block = program.MutableBlock(0); - block->Var("var1"); - block->Var("var2"); - block->Var("var3"); - block->Var("var4"); - block->Var("var5"); - - auto add_op = - std::unique_ptr(new OpDesc("elementwise_add", - {{"X", {"var1"}}, {"Y", {"var2"}}}, - {{"Out", {"var3"}}}, - {})); - block->AppendAllocatedOp(std::move(add_op)); - auto mul_op = std::unique_ptr(new OpDesc( - "mul", {{"X", {"var3"}}, {"Y", {"var4"}}}, {{"Out", {"var5"}}}, {})); - block->AppendAllocatedOp(std::move(mul_op)); - return program; -} - -static ProgramDesc BuildProgramWithCinnLaunchOp(int64_t compilation_key) { - // create a cinn_launch op - ProgramDesc program; - auto* block = program.MutableBlock(0); - block->Var("var1"); - block->Var("var2"); - block->Var("var4"); - block->Var("var5"); - - auto cinn_launch_op = std::unique_ptr( - new OpDesc("cinn_launch", - {{"X", {"var1", "var2", "var4"}}}, - {{"Out", {"var5"}}}, - {{"compilation_key", compilation_key}})); - block->AppendAllocatedOp(std::move(cinn_launch_op)); - return program; -} - -struct TestPassContext { - explicit TestPassContext(const ProgramDesc& program) { - graph = std::make_unique(program); - details::BuildStrategy build_strategy; - details::ExecutionStrategy exec_strategy; - exec_strategy.use_device_ = paddle::platform::kCUDA; - executor.reset(new ParallelExecutor(platform::CUDAPlace(0), - &scope, - exec_strategy, - build_strategy, - graph.get())); - } - - Scope scope; - std::unique_ptr graph; - std::unique_ptr executor; -}; - -TEST(ShareMemInfoToSubGraphPassTest, 
test_main_graph_share_varinfo) { - // add a subgraph to CinnCompiler - auto subgraph = std::make_unique(BuildProgramInsideCinnLaunchOp()); - subgraph->GetOrInit( - paddle2cinn::kMemOptVarInfoFromMainGraph); - auto compilation_key = - paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph)); - - // build test data and apply pass - auto context = std::make_unique( - BuildProgramWithCinnLaunchOp(compilation_key)); - - // check result - const ir::Graph& result_subgraph = - paddle2cinn::CinnCompiler::GetInstance()->FindGraph(compilation_key); - const auto& dst_varinfo_map = result_subgraph.Get( - paddle2cinn::kMemOptVarInfoFromMainGraph); - ASSERT_EQ(dst_varinfo_map.size(), 4); - EXPECT_EQ(dst_varinfo_map.count("var1"), 1); - EXPECT_EQ(dst_varinfo_map.count("var5"), 1); - EXPECT_EQ(dst_varinfo_map.at("var1").use_count(), 2); - EXPECT_EQ(dst_varinfo_map.at("var5").use_count(), 2); -} - -TEST(ShareMemInfoToSubGraphPassTest, test_subgraph_take_varinfo) { - // build test data and apply pass - auto context = - std::make_unique(BuildProgramInsideCinnLaunchOp()); - auto& varinfo_map_shared = context->graph->GetOrInit( - paddle2cinn::kMemOptVarInfoFromMainGraph); - varinfo_map_shared = { - {"var1", std::make_shared("var1", 1)}, - {"var2", std::make_shared("var2", 2)}, - }; - - ir::MemOptVarInfoMapList varinfo_maps(1); - auto& dst_varinfo_map = varinfo_maps.front(); - dst_varinfo_map = {{"var1", std::make_shared("var1", 1)}, - {"var2", std::make_shared("var2", 1)}, - {"var3", std::make_shared("var3", 1)}, - {"var4", std::make_shared("var4", 1)}, - {"var5", std::make_shared("var5", 1)}}; - auto share_pass = - ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass"); - share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &varinfo_maps); - share_pass->Apply(context->graph.get()); - - // check result - ASSERT_NE(dst_varinfo_map.at("var1")->ParentHolder(), nullptr); - ASSERT_NE(dst_varinfo_map.at("var2")->ParentHolder(), nullptr); - ASSERT_EQ(dst_varinfo_map.at("var3")->ParentHolder(), nullptr); - ASSERT_EQ(dst_varinfo_map.at("var4")->ParentHolder(), nullptr); - ASSERT_EQ(dst_varinfo_map.at("var5")->ParentHolder(), nullptr); -} - -} // namespace paddle::framework diff --git a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc b/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc deleted file mode 100644 index eeec6fd8788d4..0000000000000 --- a/paddle/fluid/framework/ir/memory_optimize_pass/test_reference_count_pass_last_lived_ops.cc +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "gtest/gtest.h" -#include "paddle/common/flags.h" -#include "paddle/fluid/framework/details/multi_devices_helper.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_helper.h" -#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h" -#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h" -#include "paddle/fluid/framework/parallel_executor.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/phi/core/kernel_registry.h" - -COMMON_DECLARE_double(eager_delete_tensor_gb); - -namespace paddle { -namespace framework { -namespace p = paddle::platform; - -static std::vector CreatePlaces(size_t num, bool use_cuda) { - std::vector result; - result.reserve(num); - for (size_t i = 0; i < num; ++i) { - if (use_cuda) { - result.emplace_back(platform::CUDAPlace(static_cast(i))); - } else { - result.emplace_back(platform::CPUPlace()); - } - } - return result; -} - -static void NewVar(BlockDesc *block, - const std::string &name, - const std::vector &shape) { - auto *var_desc = block->Var(name); - var_desc->SetShape(shape); -} - -static void AppendOp(BlockDesc *block, - const std::string &type, - VariableNameMap inputs, - VariableNameMap outputs, - AttributeMap attrs) { - auto &op_info = OpInfoMap::Instance().Get(type); - if (op_info.Checker()) { - op_info.Checker()->Check(&attrs); - } - - auto *op = block->AppendOp(); - op->SetType(type); - for (auto &pair : inputs) { - op->SetInput(pair.first, pair.second); - } - - for (auto &pair : outputs) { - op->SetOutput(pair.first, pair.second); - for (auto &var_name : pair.second) { - if (!block->FindVarRecursive(var_name)) { - NewVar(block, var_name, {}); - } - } - } - - op->SetAttrMap(attrs); - op->InferVarType(block); - op->InferShape(*block); -} - -class ReferenceCountPassTestHelper { - public: - ReferenceCountPassTestHelper(const ProgramDesc &program, bool use_cuda) - : graph_(program) { - details::BuildStrategy build_strategy; - build_strategy.enable_inplace_ = false; - build_strategy.memory_optimize_ = false; - FLAGS_eager_delete_tensor_gb = -1; - - details::ExecutionStrategy exec_strategy; - exec_strategy.use_device_ = use_cuda ? 
p::kCUDA : p::kCPU; - - executor_ = std::make_unique(CreatePlaces(1, use_cuda), - std::vector(), - "", - &scope_, - std::vector(), - exec_strategy, - build_strategy, - &graph_); - - auto ref_cnt_pass = - ir::PassRegistry::Instance().Get("reference_count_pass"); - ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); - ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars_); - ref_cnt_pass->Apply(&const_cast(executor_->Graph())); - } - - bool IsLastLivedOps(const std::string &name, - std::vector ops) const { - std::sort(ops.begin(), ops.end()); - return LastLivedOpTypes(name) == ops; - } - - std::vector LastLivedOps(const std::string &name) const { - auto &ops = last_live_ops_of_vars_[0].at(name).ops(); - std::vector ret; - ret.reserve(ops.size()); - for (auto *op : ops) { - ret.emplace_back(op->GetOp()); - } - return ret; - } - - private: - std::vector LastLivedOpTypes(const std::string &name) const { - auto iter = last_live_ops_of_vars_[0].find(name); - std::vector ret; - if (iter != last_live_ops_of_vars_[0].end()) { - for (auto *op : iter->second.ops()) { - ret.emplace_back(op->GetOp()->Type()); - } - } - std::sort(ret.begin(), ret.end()); - return ret; - } - - private: - ir::Graph graph_; - Scope scope_; - std::unique_ptr executor_; - - ir::MemOptVarInfoMapList mem_opt_var_infos_; - std::vector last_live_ops_of_vars_; -}; - -TEST(test_reference_count_pass, test_no_need_buffer_var_shrink) { - ProgramDesc program; - auto *block = program.MutableBlock(0); - std::vector shape{{3, 4, 5}}; - - /** - * The network is: - * - * x0 = fluid.layer.data(...) - * x1 = scale(x0, scale=1) - * x2 = scale(x1, scale=2) - * x3 = elementwise_mul(x1, x2) - * scale(x3, out=x1, scale=3) # produce a new version of x1 - * x4, x5 = elementwise_add_grad(dout=x3, x=x2, y=x1) - * x6 = elementwise_mul(x4, x5) - * x7 = elementwise_add(x5, x5) - */ - std::string x0 = "x0"; - std::string x1 = "x1"; - std::string x2 = "x2"; - std::string x3 = "x3"; - std::string x4 = "x4"; - std::string x5 = "x5"; - std::string x6 = "x6"; - std::string x7 = "x7"; - - NewVar(block, x0, shape); - AppendOp(block, "scale", {{"X", {x0}}}, {{"Out", {x1}}}, {{"scale", 1.0f}}); - AppendOp(block, "scale", {{"X", {x1}}}, {{"Out", {x2}}}, {{"scale", 2.0f}}); - AppendOp(block, - "elementwise_mul", - {{"X", {x1}}, {"Y", {x2}}}, - {{"Out", {x3}}}, - {}); - AppendOp(block, "scale", {{"X", {x3}}}, {{"Out", {x1}}}, {{"scale", 3.0f}}); - AppendOp(block, - "elementwise_add_grad", - {{GradVarName("Out"), {x3}}, {"X", {x2}}, {"Y", {x1}}}, - {{GradVarName("X"), {x4}}, {GradVarName("Y"), {x5}}}, - {}); - AppendOp(block, - "elementwise_mul", - {{"X", {x4}}, {"Y", {x5}}}, - {{"Out", {x6}}}, - {}); - AppendOp(block, - "elementwise_add", - {{"X", {x5}}, {"Y", {x5}}}, - {{"Out", {x7}}}, - {}); - - std::vector use_cuda_list{false}; -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - use_cuda_list.push_back(true); -#endif - for (auto use_cuda : use_cuda_list) { - ReferenceCountPassTestHelper helper(program, use_cuda); - ASSERT_TRUE(helper.IsLastLivedOps(x0, {"scale"})); - ASSERT_EQ(PADDLE_GET_CONST(float, - helper.LastLivedOps(x0)[0]->Attrs().at("scale")), - 1.0f); - - ASSERT_TRUE(helper.IsLastLivedOps(x1, {"scale"})); - ASSERT_EQ(PADDLE_GET_CONST(float, - helper.LastLivedOps(x1)[0]->Attrs().at("scale")), - 3.0f); - - ASSERT_TRUE(helper.IsLastLivedOps(x2, {"elementwise_mul"})); - ASSERT_TRUE(helper.IsLastLivedOps(x3, {"elementwise_add_grad"})); - - ASSERT_TRUE(helper.IsLastLivedOps(x4, {"elementwise_mul"})); - 
ASSERT_TRUE( - helper.IsLastLivedOps(x5, {"elementwise_mul", "elementwise_add"})); - - ASSERT_TRUE(helper.IsLastLivedOps(x6, {"elementwise_mul"})); - ASSERT_TRUE(helper.IsLastLivedOps(x7, {"elementwise_add"})); - } -} - -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc index 0f0d385569083..c09a2d1ffbb8d 100644 --- a/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc @@ -161,7 +161,7 @@ class ComputePropagateScalesMkldnnPassTest : public testing::Test { begin(wh[i]), end(wh[i]), wh_tensor->mutable_data(phi::CPUPlace()) + i * wh[0].size()); - if (type == "gru") { + if (type == "gru") { // NOLINT ComputeGruWeightScales( graph, &scope, wx_name, wh_name, &var_quant_scales); } else { diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index b2903a1337f3f..0aa71c3df5fb5 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -153,6 +153,48 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); + AddOpCompat(OpCompat("conv2d_transpose_bias")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("output_padding") + .IsType>() + .IsOptional() + .End() + .AddAttr("output_size") + .IsType>() + .IsOptional() + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("elementwise_add")) .AddInput("X") .IsTensor() diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h index d4fb89f091c87..4fb8418686299 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h @@ -50,7 +50,7 @@ class Conv2DTransposeBiasFusePass : public ConvBiasFusePass { public: Conv2DTransposeBiasFusePass(); std::string type() const override { return "conv2d_transpose"; } - std::string fused_type() const override { return "conv2d_transpose"; } + std::string fused_type() const override { return "conv2d_transpose_bias"; } }; class Conv3DBiasFusePass : public ConvBiasFusePass { diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc index dfd838895aeb4..951d064364ce3 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass_tester.cc @@ -73,9 +73,9 @@ void MainTest(const ProgramDesc& prog, auto graph = std::make_unique(prog); auto pass = PassRegistry::Instance().Get("cpu_bfloat16_pass"); - int original_nodes_num = graph->Nodes().size(); + int original_nodes_num = static_cast(graph->Nodes().size()); 
graph.reset(pass->Apply(graph.release())); - int current_nodes_num = graph->Nodes().size(); + int current_nodes_num = static_cast(graph->Nodes().size()); int quantize_nodes_count = 0; int dequantize_nodes_count = 0; diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index 2f1e7e8a53865..0e9c452455de3 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -94,8 +94,8 @@ void CPUQuantizePass::QuantizeInput(Graph* g, "Var(%s) isn't the input of the %s operator.", input_name, op->Op()->Type())); - unsigned max = is_input_unsigned ? U8_MAX : S8_MAX; - float scale = scale_to_one * max; + unsigned max = is_input_unsigned ? U8_MAX : S8_MAX; // NOLINT + float scale = static_cast(scale_to_one) * max; // Create quantize output variable VarDesc quantize_out_desc(patterns::PDNodeName("quantize", "out")); @@ -175,12 +175,13 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, double scale_out = GetScaleValueForNode(output); unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX; - float scale = scale_out * max; + float scale = static_cast(scale_out) * max; for (size_t var_id = 0; var_id < unique_var_names.size(); var_id++) { auto index = -1; for (size_t it = 0; it < inputs.size(); it++) { - if (inputs[it]->Name() == unique_var_names[var_id]) index = it; + if (inputs[it]->Name() == unique_var_names[var_id]) + index = static_cast(it); } if (index == -1) { @@ -249,7 +250,7 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, output_name, op->Op()->Type())); unsigned max = is_unsigned ? U8_MAX : S8_MAX; - float scale = scale_to_one * max; + float scale = static_cast(scale_to_one) * max; // Create dequantize input variable VarDesc dequantize_in_desc(patterns::PDNodeName("dequantize", "in")); @@ -298,12 +299,13 @@ void CPUQuantizePass::DequantizeOutputs(Graph* g, std::vector dequantize_in_nodes(outputs.size()); unsigned max = is_unsigned ? 
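The CPUQuantizePass hunks above make the double-to-float narrowing in the scale computation explicit. A standalone sketch of the arithmetic, assuming the usual int8 saturation limits of 255 (U8_MAX) and 127 (S8_MAX); it mirrors the idea, not the exact Paddle code:

#include <cstdio>

constexpr unsigned kU8Max = 255;
constexpr unsigned kS8Max = 127;

float QuantScale(double scale_to_one, bool is_unsigned) {
  const unsigned max = is_unsigned ? kU8Max : kS8Max;
  // Narrow deliberately: the pass stores scales as float, so cast instead of
  // relying on an implicit double -> float conversion.
  return static_cast<float>(scale_to_one) * static_cast<float>(max);
}

int main() {
  std::printf("%f\n", QuantScale(1.0 / 6.35, false));  // 20.0 for signed int8
  return 0;
}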
U8_MAX : S8_MAX; - float scale = scale_to_one * max; + float scale = static_cast(scale_to_one) * max; for (size_t var_id = 0; var_id < var_names.size(); var_id++) { auto index = -1; for (size_t it = 0; it < outputs.size(); it++) { - if (outputs[it]->Name() == var_names[var_id]) index = it; + if (outputs[it]->Name() == var_names[var_id]) + index = static_cast(it); } if (index == -1) { diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc index bad886ae40cdf..c7e15e24216aa 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass_tester.cc @@ -61,7 +61,7 @@ void SetOp(ProgramDesc* prog, op->SetOutput("Output", {outputs[0]}); } else if (type == "pool2d" || type == "fused_transpose" || type == "reshape2" || type == "nearest_interp" || - type == "nearest_interp_v2") { + type == "nearest_interp_v2" || type == "dropout") { op->SetInput("X", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); } else if (type == "slice") { @@ -70,9 +70,6 @@ void SetOp(ProgramDesc* prog, } else if (type == "split") { op->SetInput("X", {inputs[0]}); op->SetOutput("Out", {outputs}); - } else if (type == "dropout") { - op->SetInput("X", {inputs[0]}); - op->SetOutput("Out", {outputs[0]}); } else if (type == "fc") { op->SetInput("Input", {inputs[0]}); if (inputs.size() > 1) op->SetInput("W", {inputs[1]}); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc index d2c6d981c3a2e..7d4429a2eb7f2 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass_tester.cc @@ -41,7 +41,7 @@ void SetOp(ProgramDesc* prog, if (type != "dropout" && type != "quantize" && type != "dequantize") { op->SetAttr("mkldnn_data_type", mkldnn_data_type); } - if (type == "pool2d") { + if (type == "pool2d") { // NOLINT op->SetInput("X", {inputs[0]}); op->SetOutput("Out", {outputs[0]}); if (!scale.empty()) op->SetAttr("Scale_in", scale[0]); @@ -120,8 +120,9 @@ ProgramDesc BuildConvRequantProgramDesc(bool use_mkldnn, float scale_out, float scale_in) { ProgramDesc prog; - for (auto& v : std::initializer_list( - {"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) { + const std::vector values = { + "a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"}; + for (auto& v : values) { auto* var = prog.MutableBlock(0)->Var(v); if (v.find("w") == 0 || v.find("b") == 0) { var->SetPersistable(true); @@ -240,7 +241,7 @@ ProgramDesc BuildOpRequantProgramDesc(bool use_mkldnn, {"h"}, use_mkldnn, {matmul_scale, requant_scale3}); - SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, {use_mkldnn}); + SetOp(&prog, "concat", "Concat", {"c", "f", "h"}, {"g"}, use_mkldnn); return prog; } @@ -683,7 +684,7 @@ ProgramDesc BuildRequantOpProgramDesc(bool use_mkldnn, {"h"}, use_mkldnn, {op_scale_in, op_scale_out}); - SetOp(&prog, "concat", "Concat", {"b", "e", "h"}, {"i"}, {use_mkldnn}); + SetOp(&prog, "concat", "Concat", {"b", "e", "h"}, {"i"}, use_mkldnn); return prog; } diff --git a/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc index 44856c086dc93..fde7fb07b9108 100644 --- a/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/int8_scale_calculation_mkldnn_pass_tester.cc @@ 
-70,14 +70,7 @@ ProgramDesc BuildProgramDesc(bool convWithExistingBias, } } - if (convWithExistingBias) { - SetOp(&prog, - "conv2d", - "conv", - std::vector({"c", "weights", "conv_bias"}), - std::vector({"f"}), - scale_weights); - } else if (scale_weights.size() > 1) { + if (convWithExistingBias || scale_weights.size() > 1) { SetOp(&prog, "conv2d", "conv", diff --git a/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h index 0443c935abf93..6260f379ca2e1 100644 --- a/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h +++ b/paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h @@ -147,6 +147,7 @@ static void GetInfoFromTheTmpOp(ir::Graph* graph, inline void ConvertToFusedOp(OpDesc* op) { const std::map fused_ops = { {"conv2d", "fused_conv2d"}, + {"conv2d_transpose", "conv2d_transpose_bias"}, {"depthwise_conv2d", "fused_conv2d"}, {"elementwise_add", "fused_elementwise_add"}, {"elementwise_sub", "fused_elementwise_sub"}, diff --git a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc index 72b07fc8934de..bad1f4597f4a2 100755 --- a/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/params_quantization_mkldnn_pass_tester.cc @@ -39,8 +39,8 @@ struct Data { const std::vector& getData() const { return data; } private: - const std::vector shape; - const std::vector data; + const std::vector shape{}; + const std::vector data{}; }; struct TestScope { diff --git a/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc index 09bebfaec99c3..5d5edb83a9134 100644 --- a/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.cc @@ -15,7 +15,7 @@ #include "paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/string/pretty_log.h" +#include "paddle/utils/string/pretty_log.h" namespace paddle { namespace framework { @@ -137,7 +137,7 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize( dequant_op->Op()->HasAttr("Scale") ? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Scale")) : 1; - float reorder_scale = 1.0 / scale; + float reorder_scale = static_cast(1.0) / scale; float shift = dequant_op->Op()->HasAttr("Shift") ? 
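ConvertToFusedOp above gains a conv2d_transpose -> conv2d_transpose_bias entry. The mechanism is a plain lookup-with-fallback rename; a standalone sketch in which the function returns the mapped name rather than mutating an OpDesc:

#include <iostream>
#include <map>
#include <string>

std::string ToFusedOp(const std::string& op_type) {
  static const std::map<std::string, std::string> fused_ops = {
      {"conv2d", "fused_conv2d"},
      {"conv2d_transpose", "conv2d_transpose_bias"},
      {"elementwise_add", "fused_elementwise_add"},
  };
  auto it = fused_ops.find(op_type);
  return it == fused_ops.end() ? op_type : it->second;  // unchanged if unmapped
}

int main() {
  std::cout << ToFusedOp("conv2d_transpose") << "\n";  // conv2d_transpose_bias
  std::cout << ToFusedOp("relu") << "\n";              // relu (no mapping)
  return 0;
}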
PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Shift")) diff --git a/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc index 13612d9024628..e02b167a19e3b 100644 --- a/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc @@ -17,8 +17,8 @@ #include #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/string/pretty_log.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/utils/string/pretty_log.h" #define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); #define GET_NODES \ diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/fix_op_run_order_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/fix_op_run_order_pass.cc index 2a81b73751d3b..d7d18f6e8469c 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/fix_op_run_order_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/fix_op_run_order_pass.cc @@ -22,7 +22,7 @@ #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/memory_optimize_pass/op_graph_view.h" #include "paddle/fluid/framework/ir/pass.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc index 295ef57cfdfea..cc20f52180871 100644 --- a/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc +++ b/paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc @@ -933,7 +933,7 @@ bool ReduceSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result, void ReduceSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { if (UseGPU()) { - if (strategy_.fuse_broadcast_ops_ == true) { + if (strategy_.fuse_broadcast_ops_ == true) { // NOLINT CreateFusedBroadcastOp(result, bcast_var_name_set_); } else { for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { @@ -1193,7 +1193,7 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const { node->Op()->Type())); // Create fetch_barrier op handle to enable output on all devices. // **NOTE** fetch_barrier should output variables list same as recv op does. 
- if (node->Op()->Type() == "fetch_barrier") { + if (node->Op()->Type() == "fetch_barrier") { // NOLINT result->Get(kGraphOps).emplace_back( new details::FetchBarrierOpHandle( result->CreateOpNode(node->Op()), local_scopes_, places_)); @@ -1354,7 +1354,7 @@ void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const { strategy_.reduce_ == details::BuildStrategy::ReduceStrategy::kReduce) { return; } - if (strategy_.fuse_broadcast_ops_ == true) { + if (strategy_.fuse_broadcast_ops_ == true) { // NOLINT CreateFusedBroadcastOp(result, bcast_var_name_set_); } else { for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) { diff --git a/paddle/fluid/framework/ir/pass_test.cc b/paddle/fluid/framework/ir/pass_test.cc index 2d13a912d6cca..4c3d19f51e73f 100644 --- a/paddle/fluid/framework/ir/pass_test.cc +++ b/paddle/fluid/framework/ir/pass_test.cc @@ -43,7 +43,7 @@ void BuildCircleGraph(Graph* g) { class TestPass : public Pass { protected: - void ApplyImpl(ir::Graph* graph) const { + void ApplyImpl(ir::Graph* graph) const override { graph->Set("copy_test_pass_attr", new int); graph->Set("copy_test_graph_attr", new int); @@ -226,7 +226,7 @@ TEST(PassTest, TestPassAttrCheckConvertAllBlocks) { class TestPassWithDefault : public Pass { protected: - void ApplyImpl(ir::Graph* graph) const { + void ApplyImpl(ir::Graph* graph) const override { graph->Set("copy_default_attr", new int); int test_pass_attr = this->Get("default_attr"); diff --git a/paddle/fluid/framework/ir/quantize_helper.cc b/paddle/fluid/framework/ir/quantize_helper.cc index fa72f4caf4433..c4b06651f1bbb 100644 --- a/paddle/fluid/framework/ir/quantize_helper.cc +++ b/paddle/fluid/framework/ir/quantize_helper.cc @@ -27,8 +27,8 @@ void SaveQuantInfoInTheGraph( if (!graph->Has(flag)) { graph->Set(flag, new bool(true)); } - for (auto iter = info_map.begin(); iter != info_map.end(); ++iter) { - graph->Set(iter->first + suffix, new std::vector(iter->second)); + for (const auto& iter : info_map) { + graph->Set(iter.first + suffix, new std::vector(iter.second)); } } diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc index 704f59bbace67..028089c11687f 100644 --- a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.cc @@ -155,14 +155,19 @@ void FusedTokenPrune::operator()() { void ElementWise::operator()() { // Create nodes for elementwise. auto* elementwise_input = pattern->NewNode(elementwise_input_repr()) - ->assert_is_op_input("elementwise_add", "X"); + ->assert_is_op_input("elementwise_add", "X") + ->assert_var_not_persistable(); + auto* elementwise_weight = pattern->NewNode(elementwise_weight_repr()) + ->assert_is_op_input("elementwise_add", "Y") + ->assert_is_persistable_var(); auto* elementwise_op = pattern->NewNode(elementwise_op_repr())->assert_is_op("elementwise_add"); auto* elementwise_out = pattern->NewNode(elementwise_out_repr()) ->assert_is_op_output("elementwise_add"); // Add links for elementwise op. 
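The quantize_helper.cc hunk above replaces an explicit iterator loop with a range-based for; an equivalent standalone form:

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  const std::map<std::string, std::vector<float>> info_map = {
      {"conv1_out", {0.12f}}, {"fc_out", {0.08f}}};
  // const auto& binds each key/value pair without copying; structured
  // bindings would work equally well here.
  for (const auto& iter : info_map) {
    std::cout << iter.first << " -> " << iter.second.front() << "\n";
  }
  return 0;
}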
- elementwise_op->LinksFrom({elementwise_input}).LinksTo({elementwise_out}); + elementwise_op->LinksFrom({elementwise_input, elementwise_weight}) + .LinksTo({elementwise_out}); } } // namespace patterns diff --git a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h index 6df73301b1c32..af7be0f2faf4a 100644 --- a/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h +++ b/paddle/fluid/framework/ir/remove_padding_recover_padding_pass.h @@ -126,6 +126,7 @@ struct ElementWise : public PatternBase { void operator()(); PATTERN_DECL_NODE(elementwise_input); + PATTERN_DECL_NODE(elementwise_weight); PATTERN_DECL_NODE(elementwise_op); PATTERN_DECL_NODE(elementwise_out); }; diff --git a/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc b/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc index 35e1fe74948f3..9097eb6572521 100644 --- a/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc +++ b/paddle/fluid/framework/ir/split_layernorm_to_math_ops_pass.cc @@ -21,8 +21,8 @@ #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/string/pretty_log.h" -#include "paddle/fluid/string/printf.h" +#include "paddle/utils/string/pretty_log.h" +#include "paddle/utils/string/printf.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/ir/transfer_layout_elim_pass.cc b/paddle/fluid/framework/ir/transfer_layout_elim_pass.cc index 3a9a2c81889ee..ac3441eb7e737 100644 --- a/paddle/fluid/framework/ir/transfer_layout_elim_pass.cc +++ b/paddle/fluid/framework/ir/transfer_layout_elim_pass.cc @@ -239,7 +239,7 @@ void TransferLayoutElimPass::ApplyImpl(ir::Graph *graph) const { FusePassBase::Init(pattern_name, graph); auto transfer_format = [&](std::string data_format) -> std::string { - if (data_format == "NCHW") { + if (data_format == "NCHW") { // NOLINT return "NHWC"; } else if (data_format == "NHWC") { return "NCHW"; diff --git a/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc b/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc index 6e12933f0f4d5..6bc9cb324d80d 100644 --- a/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc +++ b/paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.cc @@ -201,7 +201,7 @@ TrtDeleteWeightQuantDequantLinearOpPass:: void TrtDeleteWeightQuantDequantLinearOpPass::ApplyImpl( ir::Graph* graph) const { const std::string pattern_name = - "delete_weight_quantdequant_linear_op_pattern"; + "delete_weight_quant_dequant_linear_op_pattern"; FusePassBase::Init(pattern_name, graph); GraphPatternDetector gpd; @@ -231,13 +231,17 @@ void TrtDeleteWeightQuantDequantLinearOpPass::ApplyImpl( return; } */ - std::unordered_set nodes2rm = {}; - - // delete Scale and ZeroPoint tensor in scope + // Scale and ZeroPoint tensor should be removed in save_optimized_model_pass std::vector vars2rm = {}; vars2rm.emplace_back(weight_dequantize_linear_op->Op()->Input("Scale")[0]); vars2rm.emplace_back( weight_dequantize_linear_op->Op()->Input("ZeroPoint")[0]); + auto& scale_and_zero_point_param = g->GetOrInit>( + framework::ir::kScaleAndZeroPointParamAttr); + scale_and_zero_point_param.insert( + scale_and_zero_point_param.end(), vars2rm.begin(), vars2rm.end()); + + std::unordered_set nodes2rm = {}; int bit_length = PADDLE_GET_CONST( int, weight_dequantize_linear_op->Op()->GetAttr("bit_length")); 
@@ -363,13 +367,6 @@ void TrtDeleteWeightQuantDequantLinearOpPass::ApplyImpl( } GraphSafeRemoveNodes(graph, nodes2rm); - - for (auto& var_name : vars2rm) { - if (scope->FindVar(var_name)) { - scope->EraseVars({var_name}); - } - } - found_count++; }; gpd(graph, handler); diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index 81f96f2fc33f4..0708218dbd07c 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -218,7 +218,8 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { } new_desc.SetAttr("begin_norm_axis", begin_norm_axis); } - int32_t hidden_size = layer_norm_scale->Var()->GetShape()[0]; + int32_t hidden_size = + static_cast(layer_norm_scale->Var()->GetShape()[0]); new_desc.SetAttr("hidden_size", hidden_size); auto fused_node = graph->CreateOpNode(&new_desc); // OpDesc will be copied. diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc index e20320e29a959..fa75f29ae9187 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc @@ -25,7 +25,9 @@ namespace ir { namespace patterns { struct AdaptiveSeqlenPatternV1 : public PatternBase { - AdaptiveSeqlenPatternV1(PDPattern* pattern, const std::string& name_scope); + AdaptiveSeqlenPatternV1(PDPattern* pattern, + const std::string& name_scope, + const std::string& matmul_type); // declare operator node's name PATTERN_DECL_NODE(embedding_xpu); @@ -44,7 +46,8 @@ struct AdaptiveSeqlenPatternV1 : public PatternBase { }; AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern, - const std::string& name_scope) + const std::string& name_scope, + const std::string& matmul_type) : PatternBase(pattern, name_scope, name_scope) { auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr()) ->assert_is_op("embedding_with_eltwise_add_xpu"); @@ -59,11 +62,11 @@ AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern, ->assert_is_op_input("multi_encoder_xpu", "x"); auto* mask = pattern->NewNode(mask_repr()) - ->assert_is_op_input("matmul", "X") - ->assert_is_op_input("matmul", "Y"); - auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul"); + ->assert_is_op_input(matmul_type, "X") + ->assert_is_op_input(matmul_type, "Y"); + auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op(matmul_type); auto* matmul_out = pattern->NewNode(matmul_out_repr()) - ->assert_is_op_output("matmul", "Out") + ->assert_is_op_output(matmul_type, "Out") ->assert_is_op_input("scale", "X"); auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); auto* scale_out = pattern->NewNode(scale_out_repr()) @@ -88,9 +91,10 @@ AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern, } // namespace patterns int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV1( - ir::Graph* graph) const { + ir::Graph* graph, const std::string& matmul_type) const { GraphPatternDetector gpd; - patterns::AdaptiveSeqlenPatternV1 pattern(gpd.mutable_pattern(), name_scope_); + patterns::AdaptiveSeqlenPatternV1 pattern( + gpd.mutable_pattern(), name_scope_, matmul_type); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -143,7 +147,9 @@ int 
MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV1( namespace patterns { struct AdaptiveSeqlenPatternV2 : public PatternBase { - AdaptiveSeqlenPatternV2(PDPattern* pattern, const std::string& name_scope); + AdaptiveSeqlenPatternV2(PDPattern* pattern, + const std::string& name_scope, + const std::string& matmul_type); // declare operator node's name PATTERN_DECL_NODE(embedding_xpu); @@ -172,7 +178,8 @@ struct AdaptiveSeqlenPatternV2 : public PatternBase { }; AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern, - const std::string& name_scope) + const std::string& name_scope, + const std::string& matmul_type) : PatternBase(pattern, name_scope, name_scope) { auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr()) ->assert_is_op("embedding_with_eltwise_add_xpu"); @@ -201,11 +208,11 @@ AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern, pattern->NewNode(unsqueeze_0_repr())->assert_is_op("unsqueeze2"); auto* unsqueeze_0_out = pattern->NewNode(unsqueeze_0_out_repr()) ->assert_is_op_output("unsqueeze2", "Out") - ->assert_is_op_input("matmul_v2", "X") - ->assert_is_op_input("matmul_v2", "Y"); - auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul_v2"); + ->assert_is_op_input(matmul_type, "X") + ->assert_is_op_input(matmul_type, "Y"); + auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op(matmul_type); auto* matmul_out = pattern->NewNode(matmul_out_repr()) - ->assert_is_op_output("matmul_v2", "Out") + ->assert_is_op_output(matmul_type, "Out") ->assert_is_op_input("scale", "X"); auto* scale_0 = pattern->NewNode(scale_0_repr())->assert_is_op("scale"); auto* scale_0_out = pattern->NewNode(scale_0_out_repr()) @@ -244,9 +251,10 @@ AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern, } // namespace patterns int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV2( - ir::Graph* graph) const { + ir::Graph* graph, const std::string& matmul_type) const { GraphPatternDetector gpd; - patterns::AdaptiveSeqlenPatternV2 pattern(gpd.mutable_pattern(), name_scope_); + patterns::AdaptiveSeqlenPatternV2 pattern( + gpd.mutable_pattern(), name_scope_, matmul_type); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -324,9 +332,13 @@ void MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); + std::vector matmul_types{"matmul", "matmul_v2"}; + int found_subgraph_count = 0; + for (auto& matmul_type : matmul_types) { + found_subgraph_count += ApplyAdaptiveSeqlenPassV1(graph, matmul_type); + found_subgraph_count += ApplyAdaptiveSeqlenPassV2(graph, matmul_type); + } - int found_subgraph_count = ApplyAdaptiveSeqlenPassV1(graph); - found_subgraph_count += ApplyAdaptiveSeqlenPassV2(graph); AddStatis(found_subgraph_count); } diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h index 22910c2120530..ea3b52bf35a24 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h @@ -76,7 +76,8 @@ class MultiEncoderXPUAdaptiveSeqlenFusePass : public FusePassBase { | out_var* */ - int ApplyAdaptiveSeqlenPassV1(ir::Graph* graph) const; + int ApplyAdaptiveSeqlenPassV1(ir::Graph* graph, + const std::string& matmul_type) 
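The adaptive-seqlen pass above is now parameterized by the matmul op type and applied once per variant. The driver shape, reduced to a standalone sketch in which ApplyPass stands in for ApplyAdaptiveSeqlenPassV1/V2:

#include <iostream>
#include <string>
#include <vector>

int ApplyPass(const std::string& matmul_type) {
  // A real pass would build a pattern keyed on matmul_type and rewrite
  // matching subgraphs; here we only report which variant was requested.
  std::cout << "matching ops of type " << matmul_type << "\n";
  return 0;  // number of fused subgraphs
}

int main() {
  int found_subgraph_count = 0;
  const std::vector<std::string> matmul_types = {"matmul", "matmul_v2"};
  for (const auto& matmul_type : matmul_types) {
    found_subgraph_count += ApplyPass(matmul_type);
  }
  std::cout << "total: " << found_subgraph_count << "\n";
  return 0;
}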
const; /* adaptive seqlen V2, before: @@ -132,7 +133,8 @@ class MultiEncoderXPUAdaptiveSeqlenFusePass : public FusePassBase { | out_var* */ - int ApplyAdaptiveSeqlenPassV2(ir::Graph* graph) const; + int ApplyAdaptiveSeqlenPassV2(ir::Graph* graph, + const std::string& matmul_type) const; private: const std::string name_scope_{"multi_encoder_xpu_adaptive_seqlen_fuse_pass"}; diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc index 8e126df64ad41..e7a5acac2bae2 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc @@ -38,7 +38,8 @@ struct SingleEncoderXPUPattern : public PatternBase { bool norm_before, bool with_q_scale, bool with_mask, - bool is_smooth_quant); + bool is_smooth_quant, + const std::string& relative_type); // declare operator node's name // If norm_before, use ln_0 & ln_1. @@ -141,6 +142,16 @@ struct SingleEncoderXPUPattern : public PatternBase { PATTERN_DECL_NODE(smooth_scale_1_out); PATTERN_DECL_NODE(smooth_scale_2_out); + // roformer_relative_embedding_xpu + PATTERN_DECL_NODE(q_relative_emb); + PATTERN_DECL_NODE(q_cos_embedding); + PATTERN_DECL_NODE(q_sin_embedding); + PATTERN_DECL_NODE(q_relative_emb_out); + PATTERN_DECL_NODE(k_relative_emb); + PATTERN_DECL_NODE(k_cos_embedding); + PATTERN_DECL_NODE(k_sin_embedding); + PATTERN_DECL_NODE(k_relative_emb_out); + private: std::string act_type_; std::string matmul_type_0_; @@ -150,6 +161,7 @@ struct SingleEncoderXPUPattern : public PatternBase { bool with_q_scale_{false}; bool with_mask_{true}; bool is_smooth_quant_{false}; + std::string relative_type_ = ""; }; SingleEncoderXPUPattern::SingleEncoderXPUPattern( @@ -162,7 +174,8 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( bool norm_before, bool with_q_scale, bool with_mask, - bool is_smooth_quant) + bool is_smooth_quant, + const std::string& relative_type) : PatternBase(pattern, name_scope, name_scope), act_type_(act_type), matmul_type_0_(matmul_type_0), @@ -171,7 +184,8 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( norm_before_(norm_before), with_q_scale_(with_q_scale), with_mask_(with_mask), - is_smooth_quant_(is_smooth_quant) { + is_smooth_quant_(is_smooth_quant), + relative_type_(relative_type) { // layer_norm 0 PDNode* ln_0_x = pattern->NewNode(ln_0_x_repr()); PDNode* ln_0_bias = nullptr; @@ -244,14 +258,38 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( ->assert_var_not_persistable(); PDNode* q_scale = nullptr; PDNode* q_scale_out = nullptr; + std::string target_op_type = matmul_type_1_; if (with_q_scale_) { q_scale = pattern->NewNode(q_scale_repr())->assert_is_op("scale"); q_scale_out = pattern->NewNode(q_scale_out_repr()) ->assert_is_op_output("scale", "Out") ->assert_is_op_input(matmul_type_1_, "X") ->assert_var_not_persistable(); + target_op_type = "scale"; } else { - q_transpose_out->assert_is_op_input(matmul_type_1_, "X"); + if (relative_type_.empty()) { + q_transpose_out->assert_is_op_input(target_op_type, "X"); + } else { + q_transpose_out->assert_is_op_input(relative_type_, "x"); + } + } + PDNode* q_relative_emb = nullptr; + PDNode* q_cos_embedding = nullptr; + PDNode* q_sin_embedding = nullptr; + PDNode* q_relative_emb_out = nullptr; + if (relative_type_ == "roformer_relative_embedding_xpu") { + VLOG(3) << "build q_relative_emb"; + q_relative_emb = + pattern->NewNode(q_relative_emb_repr())->assert_is_op(relative_type_); + q_sin_embedding = 
pattern->NewNode(q_sin_embedding_repr()) + ->assert_is_op_input(relative_type_, "sin_emb") + ->AsInput(); + q_cos_embedding = pattern->NewNode(q_cos_embedding_repr()) + ->assert_is_op_input(relative_type_, "cos_emb") + ->AsInput(); + q_relative_emb_out = pattern->NewNode(q_relative_emb_out_repr()) + ->assert_is_op_output(relative_type_, "out") + ->assert_is_op_input(target_op_type, "X"); } // k: matmul + add + reshape + transpose @@ -279,9 +317,23 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( pattern->NewNode(k_transpose_repr())->assert_is_op("transpose2"); auto* k_transpose_out = pattern->NewNode(k_transpose_out_repr()) ->assert_is_op_output("transpose2", "Out") - ->assert_is_op_input(matmul_type_1_, "Y") ->assert_var_not_persistable(); + PDNode* k_relative_emb = nullptr; + PDNode* k_sin_embedding = q_sin_embedding; + PDNode* k_cos_embedding = q_cos_embedding; + PDNode* k_relative_emb_out = nullptr; + if (relative_type_.empty()) { + k_transpose_out->assert_is_op_input(matmul_type_1_, "Y"); + } else if (relative_type_ == "roformer_relative_embedding_xpu") { + VLOG(3) << "build k_relative_emb"; + k_transpose_out->assert_is_op_input(relative_type_, "x"); + k_relative_emb = + pattern->NewNode(k_relative_emb_repr())->assert_is_op(relative_type_); + k_relative_emb_out = pattern->NewNode(k_relative_emb_out_repr()) + ->assert_is_op_output(relative_type_, "out") + ->assert_is_op_input(matmul_type_1_, "Y"); + } // qk: matmul + add + softmax auto* qk_matmul = pattern->NewNode(qk_matmul_repr())->assert_is_op(matmul_type_1_); @@ -482,18 +534,31 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( q_add->LinksFrom({q_matmul_out, q_add_bias}).LinksTo({q_add_out}); q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out}); q_transpose->LinksFrom({q_reshape_out}).LinksTo({q_transpose_out}); - PDNode* qk_matmul_x = q_transpose_out; + PDNode* last_node = q_transpose_out; + if (relative_type_ == "roformer_relative_embedding_xpu") { + VLOG(3) << "build q_relative_emb link"; + q_relative_emb->LinksFrom({last_node, q_sin_embedding, q_cos_embedding}) + .LinksTo({q_relative_emb_out}); + last_node = q_relative_emb_out; + } if (with_q_scale_) { - q_scale->LinksFrom({q_transpose_out}).LinksTo({q_scale_out}); - qk_matmul_x = q_scale_out; + q_scale->LinksFrom({last_node}).LinksTo({q_scale_out}); + last_node = q_scale_out; } + PDNode* qk_matmul_x = last_node; k_matmul->LinksFrom({q_matmul_x, k_matmul_w}).LinksTo({k_matmul_out}); k_add->LinksFrom({k_matmul_out, k_add_bias}).LinksTo({k_add_out}); k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out}); k_transpose->LinksFrom({k_reshape_out}).LinksTo({k_transpose_out}); - - qk_matmul->LinksFrom({qk_matmul_x, k_transpose_out}).LinksTo({qk_matmul_out}); + last_node = k_transpose_out; + if (relative_type_ == "roformer_relative_embedding_xpu") { + VLOG(3) << "build k_relative_emb link"; + k_relative_emb->LinksFrom({last_node, k_sin_embedding, k_cos_embedding}) + .LinksTo({k_relative_emb_out}); + last_node = k_relative_emb_out; + } + qk_matmul->LinksFrom({qk_matmul_x, last_node}).LinksTo({qk_matmul_out}); PDNode* qk_softmax_x = qk_matmul_out; if (with_mask_) { qk_add->LinksFrom({qk_matmul_out, qk_add_mask}).LinksTo({qk_add_out}); @@ -571,7 +636,8 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const { pattern_param.norm_before, pattern_param.with_q_scale, pattern_param.with_mask, - pattern_param.is_smooth_quant); + pattern_param.is_smooth_quant, + pattern_param.relative_type); while (ApplyMultiEncoderXPUFuse(graph)) { multi_encoder_fused_counts++; } @@ 
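The q_relative_emb and k_relative_emb nodes matched above correspond to RoFormer-style rotary position embeddings applied to the query and key branches before the QK matmul. A compact reference of the underlying rotation, assuming sin/cos values are stored per channel; the fused XPU kernel's actual layout and packing may differ, so treat this as a conceptual sketch only:

#include <cstddef>
#include <vector>

// q: one token's head_dim values; sin_emb/cos_emb: that position's embeddings.
std::vector<float> ApplyRotary(const std::vector<float>& q,
                               const std::vector<float>& sin_emb,
                               const std::vector<float>& cos_emb) {
  std::vector<float> out(q.size());
  for (std::size_t i = 0; i + 1 < q.size(); i += 2) {
    // Rotate each (even, odd) channel pair by the position-dependent angle.
    out[i] = q[i] * cos_emb[i] - q[i + 1] * sin_emb[i];
    out[i + 1] = q[i] * sin_emb[i + 1] + q[i + 1] * cos_emb[i + 1];
  }
  return out;
}

int main() {
  std::vector<float> q = {1.0f, 0.0f}, sin_emb = {0.0f, 0.0f},
                     cos_emb = {1.0f, 1.0f};
  return ApplyRotary(q, sin_emb, cos_emb)[0] == 1.0f ? 0 : 1;  // angle 0: identity
}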
-950,7 +1016,8 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( bool norm_before, bool with_q_scale, bool with_mask, - bool is_smooth_quant) const { + bool is_smooth_quant, + const std::string& relative_type) const { bool local_quant = false; if (std::getenv("XPU_LOCAL_QUANT")) { local_quant = atoi(std::getenv("XPU_LOCAL_QUANT")); @@ -965,7 +1032,8 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( norm_before, with_q_scale, with_mask, - is_smooth_quant); + is_smooth_quant, + relative_type); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -1068,6 +1136,16 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( GET_IR_NODE(smooth_scale_1_out); GET_IR_NODE(smooth_scale_2_out); + // roformer_relative_embedding_xpu + GET_IR_NODE(q_relative_emb); + GET_IR_NODE(q_cos_embedding); + GET_IR_NODE(q_sin_embedding); + GET_IR_NODE(q_relative_emb_out); + GET_IR_NODE(k_relative_emb); + GET_IR_NODE(k_cos_embedding); + GET_IR_NODE(k_sin_embedding); + GET_IR_NODE(k_relative_emb_out); + auto* block = q_matmul->Op()->Block(); auto* scope = param_scope(); auto weight_dtype = @@ -1275,6 +1353,24 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( op_desc.SetAttr("relative_type", static_cast(0)); op_desc.SetAttr("use_precision", use_precision); op_desc.SetAttr("is_per_channel", is_per_channel); + if (relative_type == "roformer_relative_embedding_xpu") { + // q/k share the rotary embedding + op_desc.SetInput("roformer_embedding", + {q_cos_embedding->Name(), q_sin_embedding->Name()}); + op_desc.SetAttr("relative_type", 1); + auto q_cos_emb_shape = q_cos_embedding->Var()->GetShape(); + CHECK_GE(static_cast(q_cos_emb_shape.size()), 2) + << q_cos_emb_shape.size(); + auto size_per_head = q_reshape_out->Var()->GetShape()[3]; + CHECK_EQ(size_per_head, q_cos_emb_shape[q_cos_emb_shape.size() - 1]); + int max_pos_len = q_cos_emb_shape[q_cos_emb_shape.size() - 2]; + VLOG(3) << "relative embedding max sequence len: " << max_pos_len; + op_desc.SetAttr("max_pos_len", max_pos_len); + } else { + op_desc.SetInput("roformer_embedding", {}); + op_desc.SetAttr("max_pos_len", 0); + } + // if quant,skip softmax,and use qk_matmul out_threshold as softmax_max auto softmax_max_name = qk_matmul->Op()->Output("Out")[0]; if (var_quant_scales.find(softmax_max_name) != var_quant_scales.end()) { @@ -1320,6 +1416,10 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( IR_NODE_LINK_TO(smooth_scale_1_weight, single_encoder_xpu); IR_NODE_LINK_TO(smooth_scale_2_weight, single_encoder_xpu); } + if (relative_type == "roformer_relative_embedding_xpu") { + IR_NODE_LINK_TO(q_cos_embedding, single_encoder_xpu); + IR_NODE_LINK_TO(q_sin_embedding, single_encoder_xpu); + } // Delete nodes std::unordered_set delete_nodes{ln_1, @@ -1405,6 +1505,12 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( delete_nodes.insert(smooth_scale_1_out); delete_nodes.insert(smooth_scale_2_out); } + if (relative_type == "roformer_relative_embedding_xpu") { + delete_nodes.insert(q_relative_emb); + delete_nodes.insert(q_relative_emb_out); + delete_nodes.insert(k_relative_emb); + delete_nodes.insert(k_relative_emb_out); + } GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; }; @@ -1453,7 +1559,8 @@ bool MultiEncoderXPUFusePass::ApplyMultiEncoderXPUFuse(ir::Graph* graph) const { "fc_bias", "ln_scale", "ln_bias", - "smooth_scale_weight"}; + "smooth_scale_weight", + "roformer_embedding"}; std::map> arg_names_map; std::string mask_name = 
single_encoders[0]->Op()->Inputs().count("mask") > 0 ? single_encoders[0]->Op()->Inputs().at("mask")[0] @@ -1556,6 +1663,11 @@ bool MultiEncoderXPUFusePass::ApplyMultiEncoderXPUFuse(ir::Graph* graph) const { quant_types.end(), per_quant_types.begin(), per_quant_types.end()); } op_desc.SetAttr("quant_types", quant_types); + if (single_encoders[0]->Op()->HasAttr("max_pos_len")) { + op_desc.SetAttr("max_pos_len", + PADDLE_GET_CONST( + int, single_encoders[0]->Op()->GetAttr("max_pos_len"))); + } op_desc.SetOutput("out", {out_name}); op_desc.SetOutput("x_fp16", {x_fp16_name}); op_desc.SetOutput("out_fp16", {out_fp16_name}); @@ -1642,15 +1754,157 @@ std::vector MultiEncoderXPUFusePass::GeneratePatternParams() const { return std::vector{ // Params are arranged in alphabetic order - {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true, false}, - {"gelu", "matmul_v2", "matmul_v2", "matmul_v2", false, true, true, false}, - {"gelu", "mul", "matmul", "matmul", false, true, true, false}, - {"relu", "mul", "matmul", "matmul", false, true, true, false}, - - {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true, true}, - {"gelu", "matmul_v2", "matmul_v2", "matmul_v2", false, true, true, true}, - {"gelu", "mul", "matmul", "matmul", false, true, true, true}, - {"relu", "mul", "matmul", "matmul", false, true, true, true}, + {"gelu", + "matmul_v2", + "matmul", + "matmul_v2", + false, + false, + true, + false, + ""}, + {"gelu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + false, + ""}, + {"gelu", "mul", "matmul", "matmul", false, true, true, false, ""}, + {"relu", "mul", "matmul", "matmul", false, true, true, false, ""}, + {"relu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + false, + ""}, + + {"gelu", + "matmul_v2", + "matmul", + "matmul_v2", + false, + false, + true, + true, + ""}, + {"gelu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + true, + ""}, + {"gelu", "mul", "matmul", "matmul", false, true, true, true, ""}, + {"relu", "mul", "matmul", "matmul", false, true, true, true, ""}, + {"relu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + true, + ""}, + + {"gelu", + "matmul_v2", + "matmul", + "matmul_v2", + false, + false, + true, + false, + "roformer_relative_embedding_xpu"}, + {"gelu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + false, + "roformer_relative_embedding_xpu"}, + {"gelu", + "mul", + "matmul", + "matmul", + false, + true, + true, + false, + "roformer_relative_embedding_xpu"}, + {"relu", + "mul", + "matmul", + "matmul", + false, + true, + true, + false, + "roformer_relative_embedding_xpu"}, + {"relu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + false, + "roformer_relative_embedding_xpu"}, + + {"gelu", + "matmul_v2", + "matmul", + "matmul_v2", + false, + false, + true, + true, + "roformer_relative_embedding_xpu"}, + {"gelu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + true, + "roformer_relative_embedding_xpu"}, + {"gelu", + "mul", + "matmul", + "matmul", + false, + true, + true, + true, + "roformer_relative_embedding_xpu"}, + {"relu", + "mul", + "matmul", + "matmul", + false, + true, + true, + true, + "roformer_relative_embedding_xpu"}, + {"relu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + true, + "roformer_relative_embedding_xpu"}, }; } diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h 
b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h index 6c45838073af6..238f7d8d419c5 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h @@ -129,6 +129,7 @@ struct PatternParam { bool with_q_scale; bool with_mask; bool is_smooth_quant; + std::string relative_type; }; class MultiEncoderXPUFusePass : public FusePassBase { @@ -144,7 +145,8 @@ class MultiEncoderXPUFusePass : public FusePassBase { bool norm_before, bool with_q_scale, bool with_mask, - bool is_smooth_quant) const; + bool is_smooth_quant, + const std::string& relative_type) const; bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const; diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.cc b/paddle/fluid/framework/ir/xpu/pass_utils.cc index b0853690c065a..1509509b32a15 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.cc +++ b/paddle/fluid/framework/ir/xpu/pass_utils.cc @@ -91,7 +91,7 @@ std::vector FindOpNodeByInputName(Graph* graph, template std::string IntTypeToString() { - LOG(FATAL) << "Not support type."; + PADDLE_THROW(phi::errors::InvalidArgument("Not support type.")); return ""; } diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc index cdefbb5ca682c..c30d27cf398c5 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.cc +++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc @@ -248,7 +248,7 @@ static void QuantFP32ToIntX(const float* src_ptr, T* dst_ptr, float max_val, int numel) { - LOG(FATAL) << "Not support."; + PADDLE_THROW(phi::errors::Unimplemented("Not support.")); } template <> @@ -290,8 +290,9 @@ void ConvertWithQuant(phi::DenseTensor* weight, phi::DenseTensor* scale_max, bool transpose, bool per_channel_quant) { - LOG(FATAL) << "Not support for Tcpu is " - << phi::CppTypeToDataType::Type(); + std::stringstream ss; + ss << "Not support for Tcpu is " << phi::CppTypeToDataType::Type(); + PADDLE_THROW(phi::errors::Fatal(ss.str())); } template < @@ -440,8 +441,8 @@ void ConvertWithoutQuant(phi::DenseTensor* weight, QuantFP32ToIntX( weight_data, cpu_ctx->Alloc(weight), max_val, size); } else { - LOG(FATAL) - << "Only support float<->int31, int8<->int8 and int16<->int16 convert."; + PADDLE_THROW(phi::errors::InvalidArgument( + "Only support float<->int31, int8<->int8 and int16<->int16 convert.")); } } diff --git a/paddle/fluid/framework/ir/xpu/roformer_relative_pos_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/roformer_relative_pos_fuse_pass.cc new file mode 100644 index 0000000000000..2c50c77cad8d7 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/roformer_relative_pos_fuse_pass.cc @@ -0,0 +1,301 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
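For readers unfamiliar with what this new pass targets: roformer_relative_pos_fuse_pass collapses the split / scale / concat plus elementwise_mul / elementwise_add subgraph described further below into a single roformer_relative_embedding_xpu op, which is the same rotary position embedding that the multi_encoder_xpu_fuse_pass hunks above wire into the q/k branches through the shared sin_emb/cos_emb inputs and the max_pos_len attribute. The following is a minimal, framework-free C++ sketch of the per-position math being fused; the function and variable names are illustrative only (not Paddle APIs), and an even head size is assumed.

// Rotary (RoFormer) relative position embedding applied to one
// (position, head) slice of length d, with d assumed even:
//   rotate_half(x) = concat(-x[d/2:], x[:d/2])   // split + scale(-1) + concat
//   out            = x * cos_emb + rotate_half(x) * sin_emb
#include <cstddef>
#include <vector>

std::vector<float> RoformerRelativeEmbedding(const std::vector<float>& x,
                                             const std::vector<float>& sin_emb,
                                             const std::vector<float>& cos_emb) {
  const std::size_t d = x.size();
  const std::size_t half = d / 2;
  std::vector<float> rotated(d);
  for (std::size_t i = 0; i < half; ++i) {
    rotated[i] = -x[half + i];  // second half, negated by the scale(-1) node
    rotated[half + i] = x[i];   // first half, moved to the back by the concat
  }
  std::vector<float> out(d);
  for (std::size_t i = 0; i < d; ++i) {
    // the two elementwise_mul nodes followed by the final elementwise_add
    out[i] = x[i] * cos_emb[i] + rotated[i] * sin_emb[i];
  }
  return out;
}

In the fused op the sin/cos tables are first sliced to the current sequence length, which is why the pattern below also matches the shape + slice chain feeding sin_emb and cos_emb.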
+ +#include +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/quantize_helper.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { +/* +fuse block in vis model to reformer_relative_pos_xpu op +------------------------------------------------------ */ +/* support xpu roformer relative pos */ +/* x --------------- */ +/* | \ | */ +/* | \ | */ +/* split shape | */ +/* / | \ | */ +/* / | \ | */ +/* | scale slice | */ +/* \ | / \ | */ +/* \ | / \ | */ +/* concat slice slice | */ +/* | / \ | */ +/* | / \ | */ +/* elementwise_mul elementwise_mul */ +/* | / */ +/* | / */ +/* elementwise_add */ +/* | */ +/* | */ +/* out */ +/*-------------------------------------------*/ +/* After the pass apply: */ +/* x */ +/* cos_emb | sin_emb */ +/* \ | / */ +/* xpu_roformer_relative */ +/* | */ +/* | */ +/* out */ +/*-------------------------------------------*/ + +struct RoformerRelativePosXPUPattern : public PatternBase { + RoformerRelativePosXPUPattern(PDPattern* pattern, + const std::string& name_scope); + // declare operator node's name + PATTERN_DECL_NODE(split); + PATTERN_DECL_NODE(scale); + PATTERN_DECL_NODE(concat); + PATTERN_DECL_NODE(mul1); + + PATTERN_DECL_NODE(shape); + PATTERN_DECL_NODE(slice1); + PATTERN_DECL_NODE(slice_sin); + PATTERN_DECL_NODE(slice_cos); + + PATTERN_DECL_NODE(mul2); + PATTERN_DECL_NODE(add); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(sin_emb); + PATTERN_DECL_NODE(cos_emb); + PATTERN_DECL_NODE(split_out1); + PATTERN_DECL_NODE(split_out2); + PATTERN_DECL_NODE(scale_out); + PATTERN_DECL_NODE(concat_out); + PATTERN_DECL_NODE(mul1_out); + PATTERN_DECL_NODE(shape_out); + PATTERN_DECL_NODE(slice1_out); + PATTERN_DECL_NODE(slice_sin_out); + PATTERN_DECL_NODE(slice_cos_out); + PATTERN_DECL_NODE(mul2_out); + PATTERN_DECL_NODE(add_out); +}; + +RoformerRelativePosXPUPattern::RoformerRelativePosXPUPattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* x = pattern->NewNode(x_repr()) + ->assert_is_op_input("split", "X") + ->assert_is_op_input("elementwise_mul", "X") + ->assert_is_op_input("shape", "Input") + ->AsInput(); + + auto* split = pattern->NewNode(split_repr()) + ->assert_is_op("split") + ->assert_op_attr("axis", 3) + ->assert_op_attr("num", 2); // do we really need it + + auto* split_out1 = pattern->NewNode(split_out1_repr()) + ->assert_is_op_input("scale", "X") + ->assert_is_op_nth_output("split", "Out", 1); + auto* split_out2 = pattern->NewNode(split_out2_repr()) + ->assert_is_op_nth_input("concat", "X", 1) + ->assert_is_op_nth_output("split", "Out", 0); + split->LinksFrom({x}).LinksTo({split_out1, split_out2}); + + auto* scale = pattern->NewNode(scale_repr()) + ->assert_is_op("scale") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + auto scale = op_desc->GetAttrIfExists("scale"); + return (std::fabs(scale + 1.0) < 1e-5); + }); + auto* scale_out = 
pattern->NewNode(scale_out_repr()) + ->assert_is_op_input("concat", "X") + ->assert_is_op_output("scale", "Out"); + scale->LinksFrom({split_out1}).LinksTo({scale_out}); + auto* concat = pattern->NewNode(concat_repr())->assert_is_op("concat"); + auto* concat_out = pattern->NewNode(concat_out_repr()) + ->assert_is_op_input("elementwise_mul", "X") + ->assert_is_op_output("concat", "Out"); + concat->LinksFrom({scale_out, split_out2}).LinksTo({concat_out}); + auto* shape = pattern->NewNode(shape_repr())->assert_is_op("shape"); + auto* shape_out = pattern->NewNode(shape_out_repr()) + ->assert_is_op_input("slice", "Input") + ->assert_is_op_output("shape", "Out"); + shape->LinksFrom({x}).LinksTo({shape_out}); + auto* slice1 = pattern->NewNode(slice1_repr())->assert_is_op("slice"); + auto* slice1_out = pattern->NewNode(slice1_out_repr()) + ->assert_is_op_input("slice", "EndsTensorList") + ->assert_is_op_output("slice", "Out"); + slice1->LinksFrom({shape_out}).LinksTo({slice1_out}); + auto* sin_emb = pattern->NewNode(sin_emb_repr()) + ->assert_is_op_input("slice", "Input") + ->AsInput(); + auto* cos_emb = pattern->NewNode(cos_emb_repr()) + ->assert_is_op_input("slice", "Input") + ->AsInput(); + auto* slice_sin = pattern->NewNode(slice_sin_repr())->assert_is_op("slice"); + auto* slice_sin_out = pattern->NewNode(slice_sin_out_repr()) + ->assert_is_op_input("elementwise_mul", "Y") + ->assert_is_op_output("slice", "Out"); + slice_sin->LinksFrom({sin_emb, slice1_out}).LinksTo({slice_sin_out}); + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("elementwise_mul"); + auto* mul1_out = pattern->NewNode(mul1_out_repr()) + ->assert_is_op_input("elementwise_add", "Y") + ->assert_is_op_output("elementwise_mul", "Out"); + mul1->LinksFrom({concat_out, slice_sin_out}).LinksTo({mul1_out}); + auto* add = pattern->NewNode(add_repr())->assert_is_op("elementwise_add"); + auto* add_out = pattern->NewNode(add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->AsOutput(); + auto* slice_cos = pattern->NewNode(slice_cos_repr())->assert_is_op("slice"); + auto* slice_cos_out = pattern->NewNode(slice_cos_out_repr()) + ->assert_is_op_input("elementwise_mul", "Y") + ->assert_is_op_output("slice", "Out"); + slice_cos->LinksFrom({cos_emb, slice1_out}).LinksTo({slice_cos_out}); + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("elementwise_mul"); + auto* mul2_out = pattern->NewNode(mul2_out_repr()) + ->assert_is_op_input("elementwise_add", "X") + ->assert_is_op_output("elementwise_mul", "Out"); + mul2->LinksFrom({x, slice_cos_out}).LinksTo({mul2_out}); + add->LinksFrom({mul2_out, mul1_out}).LinksTo({add_out}); +} + +} // namespace patterns + +class RoformerRelativePosFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + const std::string name_scope_{"roformer_relative_pos_fuse_pass"}; +}; + +void RoformerRelativePosFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + GraphPatternDetector gpd; + patterns::RoformerRelativePosXPUPattern pattern(gpd.mutable_pattern(), + name_scope_); + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle RoformerRelativePosFusePass fuse"; + /* declare operator node's name */ + // declare variable node's name + GET_IR_NODE(split); + GET_IR_NODE(scale); + GET_IR_NODE(concat); + GET_IR_NODE(mul1); + 
GET_IR_NODE(shape); + GET_IR_NODE(slice1); + GET_IR_NODE(slice_sin); + GET_IR_NODE(slice_cos); + GET_IR_NODE(mul2); + GET_IR_NODE(add); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(sin_emb); + GET_IR_NODE(cos_emb); + GET_IR_NODE(split_out1); + GET_IR_NODE(split_out2); + GET_IR_NODE(scale_out); + GET_IR_NODE(concat_out); + GET_IR_NODE(mul1_out); + GET_IR_NODE(shape_out); + GET_IR_NODE(slice1_out); + GET_IR_NODE(slice_sin_out); + GET_IR_NODE(slice_cos_out); + GET_IR_NODE(mul2_out); + GET_IR_NODE(add_out); + auto* block = add->Op()->Block(); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + // Generate roformer_relative_embedding_xpu fused op + framework::OpDesc fused_op_desc(block); + fused_op_desc.SetType("roformer_relative_embedding_xpu"); + // set attrs for fused op + fused_op_desc.SetInput("x", {x->Name()}); + fused_op_desc.SetInput("sin_emb", {sin_emb->Name()}); + fused_op_desc.SetInput("cos_emb", {cos_emb->Name()}); + + fused_op_desc.SetOutput("out", {add_out->Name()}); + fused_op_desc.SetAttr("max_pos_len", + static_cast(cos_emb->Var()->GetShape()[2])); + + // relink fused op + auto* fused_op = graph->CreateOpNode(&fused_op_desc); + IR_NODE_LINK_TO(x, fused_op); + IR_NODE_LINK_TO(sin_emb, fused_op); + IR_NODE_LINK_TO(cos_emb, fused_op); + IR_NODE_LINK_TO(fused_op, add_out); + // delete useless node + std::unordered_set delete_nodes = {split, + scale, + concat, + mul1, + shape, + slice1, + slice_sin, + slice_cos, + mul2, + add, + split_out1, + split_out2, + scale_out, + concat_out, + shape_out, + slice1_out, + slice_sin_out, + slice_cos_out, + mul2_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(roformer_relative_pos_fuse_pass, + paddle::framework::ir::RoformerRelativePosFusePass); + +REGISTER_PASS_CAPABILITY(roformer_relative_pos_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "roformer_relative_embedding_xpu", 0)); diff --git a/paddle/fluid/framework/ir/xpu/squeeze_excitation_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/squeeze_excitation_fuse_pass.cc index 8009529854c9d..f75e87601b05f 100644 --- a/paddle/fluid/framework/ir/xpu/squeeze_excitation_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/squeeze_excitation_fuse_pass.cc @@ -310,9 +310,10 @@ int SqueezeExcitationFusePass::ApplyImpl(ir::Graph* graph, if (mul_1_w_dims[0] != mul_2_w_dims[1] || mul_1_w_dims[1] != mul_2_w_dims[0] || mul_1_w_len != mul_1_w_dims[0] * mul_1_w_dims[1]) { - LOG(FATAL) << "Error: Dims of excitation mul1 weight is: " << mul_1_w_dims - << ", but get dims of excitation mul2 weight is: " - << mul_2_w_dims; + std::stringstream ss; + ss << "Error: Dims of excitation mul1 weight is: " << mul_1_w_dims + << ", but get dims of excitation mul2 weight is: " << mul_2_w_dims; + PADDLE_THROW(phi::errors::InvalidArgument(ss.str())); } std::vector encode_filter_int16; encode_filter_int16.resize(mul_1_w_len + mul_2_w_len); diff --git a/paddle/fluid/framework/lod_tensor.h b/paddle/fluid/framework/lod_tensor.h index 9556430787153..a691c4ae74f29 100644 --- a/paddle/fluid/framework/lod_tensor.h +++ b/paddle/fluid/framework/lod_tensor.h @@ -27,17 +27,19 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/place.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/mixed_vector.h" +#include "paddle/utils/test_macros.h" namespace paddle { namespace framework { // Split phi::DenseTensor and copy to each place specified in places. -std::vector SplitLoDTensor( +TEST_API std::vector SplitLoDTensor( const phi::DenseTensor& src, const std::vector places); -void MergeLoDTensor(phi::DenseTensor* target, - const std::vector& lod_tensors, - platform::Place dst_place); +TEST_API void MergeLoDTensor( + phi::DenseTensor* target, + const std::vector& lod_tensors, + platform::Place dst_place); /* * LoD is short for Level of Details. @@ -65,7 +67,7 @@ LoD SliceInLevel(const LoD& in, /* * Transform an LoD from relative offsets to absolute offsets. */ -LoD ToAbsOffset(const LoD& in); +TEST_API LoD ToAbsOffset(const LoD& in); TEST_API bool operator==(const LoD& a, const LoD& b); @@ -85,7 +87,7 @@ TEST_API bool operator==(const LoD& a, const LoD& b); * tensor_height>0. */ -bool CheckLoD(const LoD& in, int tensor_height = -1); +TEST_API bool CheckLoD(const LoD& in, int tensor_height = -1); /* * Check whether this absolute lod's format is valid. * @@ -99,7 +101,7 @@ bool CheckLoD(const LoD& in, int tensor_height = -1); * same(the height of underlying tensor) or `tensor_height` if * tensor_height>0. */ -bool CheckAbsLoD(const LoD& in, int tensor_height = -1); +TEST_API bool CheckAbsLoD(const LoD& in, int tensor_height = -1); /* * Expand the `source` to fit the LoD of `lod`. For example, a `source` @@ -162,7 +164,7 @@ phi::DenseTensor LodExpand(const phi::DenseTensor& source, // Returns: // LoD = [[1, 4], [2, 4, 2, 3, 2]] // pair = {11, 24} -std::pair> GetSubLoDAndAbsoluteOffset( +TEST_API std::pair> GetSubLoDAndAbsoluteOffset( const LoD& lod, size_t start_idx, size_t end_idx, size_t start_level); /* @@ -182,7 +184,7 @@ void DeserializeFromStream(std::istream& is, const size_t& seek, const std::vector& shape); -LoD ConvertToOffsetBasedLoD(const LoD& length_lod); +TEST_API LoD ConvertToOffsetBasedLoD(const LoD& length_lod); void SerializeToStream(std::ostream& os, const phi::DenseTensor& tensor); diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index 5dae6c1c84514..d3b74fb00c1c5 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -234,6 +234,20 @@ void NaiveExecutor::RegisterInputHook(const HookFunc &hookfunc) { } } +void NaiveExecutor::RegisterOutputHook(const PirHookFunc &hookfunc) { + pir_output_hookfuncs_.push_back(hookfunc); + if (interpreter_core_) { + interpreter_core_->SetOutputHooks(pir_output_hookfuncs_); + } +} + +void NaiveExecutor::RegisterInputHook(const PirHookFunc &hookfunc) { + pir_input_hookfuncs_.push_back(hookfunc); + if (interpreter_core_) { + interpreter_core_->SetInputHooks(pir_input_hookfuncs_); + } +} + void NaiveExecutor::MakeReusePlan( const std::unordered_map &reuse_table) { std::unordered_map> clusters; diff --git a/paddle/fluid/framework/naive_executor.h b/paddle/fluid/framework/naive_executor.h index d36e3042b0b72..47f58924de144 100644 --- a/paddle/fluid/framework/naive_executor.h +++ b/paddle/fluid/framework/naive_executor.h @@ -45,6 +45,9 @@ class NaiveExecutor { public: using HookFunc = std::function; + using PirHookFunc = + std::function; + explicit NaiveExecutor(const platform::Place& place) : place_(place) {} ~NaiveExecutor(); @@ -94,6 +97,8 @@ class NaiveExecutor { void RegisterOutputHook(const HookFunc& hookfunc); void 
RegisterInputHook(const HookFunc& hookfunc); + void RegisterOutputHook(const PirHookFunc& hookfunc); + void RegisterInputHook(const PirHookFunc& hookfunc); private: void CreateOps(const ProgramDesc& desc, int block_id); @@ -107,6 +112,9 @@ class NaiveExecutor { std::vector output_hookfuncs_; std::vector input_hookfuncs_; + std::vector pir_output_hookfuncs_; + std::vector pir_input_hookfuncs_; + // Record information that tensor_a should ShareBufferWith tensor_b. std::unordered_map> reuse_cache_; diff --git a/paddle/fluid/framework/new_executor/CMakeLists.txt b/paddle/fluid/framework/new_executor/CMakeLists.txt index d00949a22ad82..d06fdd8c4c7cd 100644 --- a/paddle/fluid/framework/new_executor/CMakeLists.txt +++ b/paddle/fluid/framework/new_executor/CMakeLists.txt @@ -1,6 +1,6 @@ file(GLOB_RECURSE standalone_executor_srcs "*.cc") -if(NOT (WITH_CINN AND NOT CINN_ONLY)) +if(NOT (WITH_CINN)) list(REMOVE_ITEM standalone_executor_srcs ${CMAKE_CURRENT_SOURCE_DIR}/instruction/cinn_jit_instruction.cc) endif() @@ -26,7 +26,7 @@ set(standalone_executor_deps device_event_base framework_proto) -if(WITH_CINN AND NOT CINN_ONLY) +if(WITH_CINN) set(standalone_executor_deps ${standalone_executor_deps} cinn_runtime_dialect diff --git a/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc b/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc index 166853e2b18da..0d73e2d3fede9 100644 --- a/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc +++ b/paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.cc @@ -32,14 +32,14 @@ CreateInterpreterCoreGarbageCollector( const platform::Place& place, const std::vector>& vec_instruction) { if (platform::is_gpu_place(place)) { - if (IsInterpretercoreFastGCEnabled()) { + if (IsInterpretercoreFastGCEnabled()) { // NOLINT return std::unique_ptr( new InterpreterCoreFastGarbageCollector()); } else { return std::unique_ptr( new InterpreterCoreEventGarbageCollector(vec_instruction)); } - } else if (platform::is_xpu_place(place)) { + } else if (platform::is_xpu_place(place)) { // NOLINT // Because there is no multi-stream on XPU device, fast GC can // be used. // Previously, XPU used no_event GC. But `Wait` in no_event GC @@ -62,14 +62,14 @@ CreateInterpreterCoreGarbageCollector( const platform::Place& place, const std::vector& vec_instruction) { if (platform::is_gpu_place(place)) { - if (IsInterpretercoreFastGCEnabled()) { + if (IsInterpretercoreFastGCEnabled()) { // NOLINT return std::unique_ptr( new InterpreterCoreFastGarbageCollector()); } else { return std::unique_ptr( new InterpreterCoreEventGarbageCollector(vec_instruction)); } - } else if (platform::is_xpu_place(place)) { + } else if (platform::is_xpu_place(place)) { // NOLINT // Because there is no multi-stream on XPU device, fast GC can // be used. // Previously, XPU used no_event GC. 
But `Wait` in no_event GC diff --git a/paddle/fluid/framework/new_executor/garbage_collector/no_event_garbage_collector.cc b/paddle/fluid/framework/new_executor/garbage_collector/no_event_garbage_collector.cc index 3b7ebc18f36da..d236e740679dd 100644 --- a/paddle/fluid/framework/new_executor/garbage_collector/no_event_garbage_collector.cc +++ b/paddle/fluid/framework/new_executor/garbage_collector/no_event_garbage_collector.cc @@ -49,9 +49,10 @@ void InterpreterCoreNoEventGarbageCollector::Add( if (var->IsType()) { Add(var->GetMutable()->MoveMemoryHolder(), ctx); - } else if (var->IsType< - operators::reader:: - OrderedMultiDeviceLoDTensorBlockingQueueHolder>()) { + } else if ( + var->IsType< + operators::reader:: + OrderedMultiDeviceLoDTensorBlockingQueueHolder>()) { // NOLINT // TODO(xiongkun03) in old executor, this type of variable is not support // eager deletion. so we just leave it here ? } else if (var->IsType()) { diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc index 3708c255d59e4..83b7149ac7da2 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.cc @@ -40,6 +40,7 @@ class CinnJitInstruction::FnPtrImpl { : cinn_kernel_info_(cinn_kernel_info) {} void Run(const std::vector& kernel_args, void* stream) { + VLOG(6) << "Start Run: " << cinn_kernel_info_.fn_name; func_args_.clear(); // 1. Convert the phi::DenseTensor type to cinn_pod_value_t @@ -65,11 +66,13 @@ class CinnJitInstruction::FnPtrImpl { // 3. Launch host kernel ((lower_func_ptr_g)cinn_kernel_info_.fn_ptr)( static_cast(func_args_.data()), func_args_.size(), stream); + VLOG(6) << "End Run: " << cinn_kernel_info_.fn_name; } void InferShape(const std::vector& kernel_args, int32_t input_tensor_size, int32_t output_tensor_size) { + VLOG(6) << "Start InferShape: " << cinn_kernel_info_.fn_name; func_args_.clear(); // 1. Convert the phi::DenseTensor type to cinn_pod_value_t @@ -113,6 +116,7 @@ class CinnJitInstruction::FnPtrImpl { kernel_args[input_tensor_size + i]->Resize(dim); free(output_tensor_shapes[i]); } + VLOG(6) << "End InferShape: " << cinn_kernel_info_.fn_name; } private: @@ -163,6 +167,12 @@ CinnJitInstruction::CinnJitInstruction( result.type().dyn_cast(); tensor->set_type( paddle::dialect::TransToPhiDataType(alloc_tensor_type.dtype())); + for (size_t j = 0; j < alloc_tensor_type.dims().size(); ++j) { + if (alloc_tensor_type.dims()[j] < 0) { + need_update_shape = true; + continue; + } + } tensor->Resize(alloc_tensor_type.dims()); } } @@ -173,7 +183,7 @@ void CinnJitInstruction::Run() { auto stream = gpu_ctx->stream(); - if (FLAGS_cinn_bucket_compile) { + if (FLAGS_cinn_bucket_compile && need_update_shape) { fn_ptr_impl_->InferShape( tensor_args_, input_tensor_size, output_tensor_size); } @@ -184,8 +194,8 @@ void CinnJitInstruction::Run() { // 2. 
exexute kernel fn_ptr_impl_->Run(tensor_args_, static_cast(stream)); #else - VLOG(phi::FATAL) << "Not Supported: cinn jit instruction currently does not " - "support non-CUDA kernel"; + VLOG(0) << "Not Supported: cinn jit instruction currently does not " + "support non-CUDA kernel"; #endif } diff --git a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h index 5f744f4229d91..dadcae371471b 100644 --- a/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/cinn_jit_instruction.h @@ -52,6 +52,7 @@ class CinnJitInstruction : public InstructionBase { int32_t input_tensor_size; int32_t output_tensor_size; + bool need_update_shape{false}; std::vector tensor_args_; ::pir::Operation* op_{nullptr}; // not owned diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc index db8ef9f2de7bf..0730ef34f140b 100644 --- a/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.cc @@ -198,6 +198,16 @@ IfInstruction::~IfInstruction() { } } +void IfInstruction::SetOutputHooks(const std::vector& hookfuncs) { + true_branch_inter_->SetOutputHooks(hookfuncs); + false_branch_inter_->SetOutputHooks(hookfuncs); +} + +void IfInstruction::SetInputHooks(const std::vector& hookfuncs) { + true_branch_inter_->SetInputHooks(hookfuncs); + false_branch_inter_->SetInputHooks(hookfuncs); +} + void IfInstruction::Run() { bool cond = true; if (cond_var_->IsType()) { diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h index cf0de0fc3581f..7667c9128a8a7 100644 --- a/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/if_instruction.h @@ -48,6 +48,10 @@ class IfInstruction : public InstructionBase { PirInterpreter* FalseBranchInterpreter() const { return false_branch_inter_; } + void SetOutputHooks(const std::vector& hookfuncs); + + void SetInputHooks(const std::vector& hookfuncs); + private: ::pir::Operation* op_; diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.cc index d3c025e9ebbcd..ec0970cd26e34 100644 --- a/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/tuple_pop_instruction.cc @@ -28,8 +28,8 @@ TuplePopInstruction::TuplePopInstruction(size_t id, : InstructionBase(id, place), op_(op), value_exe_info_(value_exe_info) { tuple_pop_op_ = op->dyn_cast(); VLOG(6) << "construct tuple_pop instruction for: " << tuple_pop_op_->name(); - auto stack_value = tuple_pop_op_.container(); - auto var_array = value_exe_info_->GetVarByValue(stack_value); + auto outlet_value = tuple_pop_op_.outlet(); + auto var_array = value_exe_info_->GetVarByValue(outlet_value); stack_element_var_array_ = var_array->GetMutable(); std::unordered_map> inputs; diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc index 
ae8b0d1df2eee..e4cc8568bbf88 100644 --- a/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.cc @@ -240,6 +240,16 @@ void WhileInstruction::ShareDatasToOutputs() { } } +void WhileInstruction::SetOutputHooks( + const std::vector& hookfuncs) { + body_inter_->SetOutputHooks(hookfuncs); +} + +void WhileInstruction::SetInputHooks( + const std::vector& hookfuncs) { + body_inter_->SetInputHooks(hookfuncs); +} + void WhileInstruction::Run() { #ifdef PADDLE_WITH_DNNL // Executor on being destroyed clears oneDNN cache and resets diff --git a/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h index 849d4ec4d184d..b6f729a784f5a 100644 --- a/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h +++ b/paddle/fluid/framework/new_executor/instruction/control_flow/while_instruction.h @@ -50,6 +50,10 @@ class WhileInstruction : public InstructionBase { PirInterpreter* BodyInterpreter() const { return body_inter_.get(); } + void SetOutputHooks(const std::vector& hookfuncs); + + void SetInputHooks(const std::vector& hookfuncs); + private: // 'output' = 'input' void ShareInputsToOutputs(); diff --git a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc index 683d1bd95dcb8..d5366c40e8d15 100644 --- a/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/custom_kernel_instruction.cc @@ -85,7 +85,7 @@ void CustomKernelInstruction::BuildCustomContext( input_name2id_map_[t] = input_index; input_index++; input_ptrs_.emplace_back(nullptr); - custom_kernel_ctx_.EmplaceBackInput(std::move(paddle::Tensor())); + custom_kernel_ctx_.EmplaceBackInput(paddle::Tensor()); } VLOG(8) << "ctx->EmplaceBackInput : an optional input " << t; continue; @@ -280,8 +280,7 @@ void CustomKernelInstruction::BuildCustomContext( out_name)); VLOG(3) << "Custom Operator: BuildContext - inplace optional outputs : " << out_name << " is None."; - cache_out_ptrs_.emplace_back(nullptr); - custom_kernel_ctx_.EmplaceBackOutput(std::move(paddle::Tensor())); + custom_kernel_ctx_.EmplaceBackOutput(paddle::Tensor()); VLOG(8) << "ctx->EmplaceBackOutput : an optional output"; continue; diff --git a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc index c44c8e8be84d3..098c77346778b 100644 --- a/paddle/fluid/framework/new_executor/instruction/instruction_util.cc +++ b/paddle/fluid/framework/new_executor/instruction/instruction_util.cc @@ -281,7 +281,9 @@ std::unordered_set GetInternalInputs(pir::Block* block) { } if (op.isa()) { auto tuple_pop_op = op.dyn_cast(); - inner_inputs.insert(tuple_pop_op.container()); + if (tuple_pop_op.has_container()) { + inner_inputs.insert(tuple_pop_op.container()); + } } for (size_t i = 0; i < op.num_operands(); ++i) { inner_inputs.insert(op.operand_source(i)); diff --git a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.cc b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.cc index aa3df67535747..18b5e5a573b1d 100644 --- a/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.cc +++ 
b/paddle/fluid/framework/new_executor/instruction/onednn/onednn_instruction.cc @@ -94,6 +94,8 @@ static phi::Attribute ConvertPirAttribute2RuntimeAttribute( phi::DataType dtype = attr.dyn_cast().data(); return dtype; + } else if (attr_type_name == "paddle::dialect::ScalarAttribute") { + return attr.dyn_cast().data(); } else { PADDLE_THROW(phi::errors::Unimplemented( "ConvertPirAttribute2RuntimeAttribute not support [%s] ", @@ -245,16 +247,16 @@ OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction( } VLOG(6) << "finish process infer meta context"; - auto kernel_name = + auto kernel_name_ = op_attributes.at("kernel_name").dyn_cast().AsString(); - auto kernel_key = op_attributes.at("kernel_key") - .dyn_cast() - .data(); + auto kernel_key_ = op_attributes.at("kernel_key") + .dyn_cast() + .data(); phi_kernel_ = new phi::Kernel( - phi::KernelFactory::Instance().SelectKernel(kernel_name, kernel_key)); + phi::KernelFactory::Instance().SelectKernel(kernel_name_, kernel_key_)); PADDLE_ENFORCE_EQ( - phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name); + phi_kernel_->IsValid(), true, "not found kernel for [%s]", kernel_name_); VLOG(6) << "finish process select kernel"; BuildPhiContextstream_priority(); op_func_node.scheduling_priority_ = dist_attr->scheduling_priority(); - // set mannual event information + // set manual event information op_func_node.force_record_event_ = dist_attr->force_record_event(); op_func_node.events_to_wait_ = dist_attr->events_to_wait(); op_func_node.event_to_record_ = dist_attr->event_to_record(); @@ -1342,6 +1342,7 @@ void PrintValuesAndVariables( GetOriginOutputNames(op_name); // 1. output string + VLOG(10) << "Generate output string ..."; std::string ret_value_str = "Value : ("; std::string ret_variable_str = "Variable: ("; if (!op.results().empty()) { @@ -1387,10 +1388,12 @@ void PrintValuesAndVariables( ret_variable_str += ") = "; // 2. op name + VLOG(10) << "Generate op name ..."; ret_value_str += op_name; ret_variable_str += op_name; // 3. 
input string + VLOG(10) << "Generate input string ..."; ret_value_str += "("; ret_variable_str += "("; if (!op.operands().empty()) { diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h index 6d5e408a2e573..c78277769c84c 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.h @@ -48,7 +48,7 @@ namespace interpreter { class AsyncWorkQueue { public: AsyncWorkQueue(size_t host_num_threads, - size_t deivce_num_threads, + size_t device_num_threads, EventsWaiter* waiter); // void WaitEmpty() { queue_group_->WaitQueueGroupEmpty(); } diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc index e3839b863aa0d..131f756bdb1d3 100644 --- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc +++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc @@ -498,7 +498,7 @@ void RunWhileBlockPreStaticBuild(const framework::Scope& scope, const framework::VariableNameMap& output_var_names = item->Outputs(); for (auto& ipt : input_var_names) { for (const std::string& var_name : ipt.second) { - if (operators::StrInVaraiableNameMap(var_name, output_var_names)) { + if (operators::StrInVariableNameMap(var_name, output_var_names)) { no_copy_var_names.insert(var_name); } } diff --git a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc index c485bc7d11c6c..abc39c7ec1e03 100644 --- a/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc +++ b/paddle/fluid/framework/new_executor/interpreter/stream_analyzer.cc @@ -123,8 +123,8 @@ void StreamAnalyzer::ConstructEvents(std::vector* instructions) { } } } - // NOTE(lizhiyu): The mannual event only support the program_interpreter to - // annalyze the streams across the sub_programs. construct mannual events to + // NOTE(lizhiyu): The manual event only support the program_interpreter to + // analyze the streams across the sub_programs. 
construct manual events to // record for (auto& instruction : *instructions) { // create extra event to record @@ -158,11 +158,11 @@ void StreamAnalyzer::ConstructEvents(std::vector* instructions) { instruction.AddEventToRecord(device_event, platform::kCUDA /*unused*/); (*program_force_events_to_wait_)[op_func_node->event_to_record_] = instruction.EventToRecord(); - VLOG(6) << "Create mannual event: " << op_func_node->event_to_record_ + VLOG(6) << "Create manual event: " << op_func_node->event_to_record_ << " for the operator: " << instruction.OpBase()->Type(); } } - // add extra mannual events + // add extra manual events if (!(op_func_node->events_to_wait_.empty())) { for (auto event_name : op_func_node->events_to_wait_) { PADDLE_ENFORCE_NE( @@ -608,10 +608,10 @@ void shrink_event_info( } } - for (size_t unnecessary_wiater_instr_id : unnecessary_waiter_instr_ids) { + for (size_t unnecessary_waiter_instr_id : unnecessary_waiter_instr_ids) { VLOG(8) << "Shrink event : " << recorder_instr_id << " -> " - << unnecessary_wiater_instr_id; - waiter_recorder_map[unnecessary_wiater_instr_id].erase( + << unnecessary_waiter_instr_id; + waiter_recorder_map[unnecessary_waiter_instr_id].erase( recorder_instr_id); } } @@ -738,8 +738,8 @@ void PirStreamAnalyzer::ConstructEvents( } } } - // NOTE(lizhiyu): The mannual event only support the program_interpreter to - // annalyze the streams across the sub_programs. construct mannual events to + // NOTE(lizhiyu): The manual event only support the program_interpreter to + // analyze the streams across the sub_programs. construct manual events to // record for (auto& instr : instructions) { // create extra event to record @@ -770,11 +770,11 @@ void PirStreamAnalyzer::ConstructEvents( instr->AddEventToRecord(device_event, platform::kCUDA /*unused*/); (*program_force_events_to_wait_)[instr->EventToRecordInfo()] = instr->EventToRecord(); - VLOG(6) << "Create mannual event: " << instr->EventToRecordInfo() + VLOG(6) << "Create manual event: " << instr->EventToRecordInfo() << " for the operator: " << instr->Name(); } } - // add extra mannual events + // add extra manual events if (!(instr->EventsToWaitInfo().empty())) { for (auto event_name : instr->EventsToWaitInfo()) { PADDLE_ENFORCE_NE( diff --git a/paddle/fluid/framework/new_executor/interpreter_base_impl.h b/paddle/fluid/framework/new_executor/interpreter_base_impl.h index e99a02f37136e..1d9bac63d7c15 100644 --- a/paddle/fluid/framework/new_executor/interpreter_base_impl.h +++ b/paddle/fluid/framework/new_executor/interpreter_base_impl.h @@ -104,6 +104,10 @@ class InterpreterBaseImpl { virtual void SetInputHooks(const std::vector& hookfuncs) = 0; + virtual void SetOutputHooks(const std::vector& hookfuncs) = 0; + + virtual void SetInputHooks(const std::vector& hookfuncs) = 0; + virtual std::shared_ptr> GetDependencyCount() const = 0; virtual bool IsSharedResultsBuild() const = 0; diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc index 61151373b2a29..7bf78eed8b04e 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.cc +++ b/paddle/fluid/framework/new_executor/interpretercore.cc @@ -139,6 +139,15 @@ void InterpreterCore::SetOutputHooks(const std::vector& hookfuncs) { impl_->SetOutputHooks(hookfuncs); } +void InterpreterCore::SetInputHooks(const std::vector& hookfuncs) { + impl_->SetInputHooks(hookfuncs); +} + +void InterpreterCore::SetOutputHooks( + const std::vector& hookfuncs) { + impl_->SetOutputHooks(hookfuncs); +} + void
InterpreterCore::Build( const std::vector& feed_names, std::vector* op_func_nodes) { diff --git a/paddle/fluid/framework/new_executor/interpretercore.h b/paddle/fluid/framework/new_executor/interpretercore.h index f2b4426b8ebb2..39ad549a78455 100644 --- a/paddle/fluid/framework/new_executor/interpretercore.h +++ b/paddle/fluid/framework/new_executor/interpretercore.h @@ -14,6 +14,7 @@ #pragma once #include "paddle/fluid/framework/new_executor/interpreter_base_impl.h" +#include "paddle/fluid/framework/new_executor/new_executor_defs.h" PD_DECLARE_bool(new_executor_use_local_scope); @@ -88,6 +89,10 @@ class InterpreterCore { void SetInputHooks(const std::vector& hookfuncs); + void SetOutputHooks(const std::vector& hookfuncs); + + void SetInputHooks(const std::vector& hookfuncs); + void Build(const std::vector& feed_names, std::vector* op_func_nodes); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.cc b/paddle/fluid/framework/new_executor/new_executor_defs.cc index b3ec52029bb5b..6c9e5b4a877d5 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.cc +++ b/paddle/fluid/framework/new_executor/new_executor_defs.cc @@ -94,7 +94,7 @@ void VariableScope::AddVar(const std::string& name, auto id = VarSize(); name2id_[name] = static_cast(id); vec_meta_info_.emplace_back(0, var_desc); - if (local_scope_ != nullptr) { + if (local_scope_ != nullptr) { // NOLINT var_list_.push_back(local_scope_->FindVar(name)); } else { var_list_.push_back(scope_->FindVar(name)); diff --git a/paddle/fluid/framework/new_executor/new_executor_defs.h b/paddle/fluid/framework/new_executor/new_executor_defs.h index c416b151aef03..79619828980aa 100644 --- a/paddle/fluid/framework/new_executor/new_executor_defs.h +++ b/paddle/fluid/framework/new_executor/new_executor_defs.h @@ -40,9 +40,13 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm); namespace paddle { namespace framework { +class InstructionBase; +class ValueExecutionInfo; using OpKernelComputeFunc = std::function; using HookFunc = std::function; +using PirHookFunc = + std::function; using SchedulingPriority = int64_t; diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index 1e2fa3269bb41..0eabcceeeb981 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -480,18 +480,9 @@ void HandleForSpecialOp(pir::Operation* op, auto shape = op->attribute("shape"); auto dim = phi::make_ddim(shape.data().GetData()); auto dtype = op->attribute("dtype"); - auto place = op->attribute("place").data(); - if (place.GetType() == phi::AllocationType::UNDEFINED) { - place = phi::CPUPlace(); - } if (!common::contain_unknown_dim(dim)) { phi::DenseTensorMeta meta(dtype.data(), dim); t->set_meta(meta); - auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place); - dev_ctx->Alloc(t, dtype.data()); - VLOG(10) << "[Alloc var]: " - << op->attribute("name") << " " - << t->initialized(); } } } @@ -556,10 +547,10 @@ void HandleForSpecialOp(pir::Operation* op, auto value = op->operand_source(0); Scope* scope = const_cast(value_exe_info->GetScope()); - if (auto bool_atttr = + if (auto bool_attr = value.attribute(kAttrIsPersistable)) { - if (bool_atttr.data()) { - VLOG(6) << "Handle for builtin.shadow_ouptut persistable value:" + if (bool_attr.data()) { + VLOG(6) << "Handle for builtin.shadow_output persistable value:" << var_name; scope = 
const_cast(value_exe_info->GetScope()->root()); } @@ -753,7 +744,7 @@ void BuildScope(const pir::Block& block, Variable* var = value_exe_info->GetScope()->FindVar(kwarg.first); PADDLE_ENFORCE(var, paddle::platform::errors::InvalidArgument( - "The variable %s shoud exist", kwarg.first)); + "The variable %s should exist", kwarg.first)); value_exe_info->Add(kwarg.second, kwarg.first); } @@ -951,27 +942,27 @@ std::shared_ptr BuildOperatorBase( } attr_map[legacy_arg_name] = vec_int; } else if (array_list[0].isa()) { - std::vector vec_int64; + std::vector vec_int64; for (auto attribute : array_list) { vec_int64.push_back( attribute.dyn_cast().data()); // NOLINT } attr_map[legacy_arg_name] = vec_int64; } else if (array_list[0].isa()) { - std::vector vec_bool; + std::vector vec_bool; for (auto attribute : array_list) { vec_bool.push_back(attribute.dyn_cast().data()); } attr_map[legacy_arg_name] = vec_bool; } else if (array_list[0].isa()) { - std::vector vec_float; + std::vector vec_float; for (auto attribute : array_list) { vec_float.push_back( attribute.dyn_cast().data()); // NOLINT } attr_map[legacy_arg_name] = vec_float; } else if (array_list[0].isa()) { - std::vector vec_double; + std::vector vec_double; for (auto attribute : array_list) { vec_double.push_back( attribute.dyn_cast().data()); // NOLINT diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.cc b/paddle/fluid/framework/new_executor/pir_interpreter.cc index 236f18dfb223c..c2b234d8d667f 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.cc +++ b/paddle/fluid/framework/new_executor/pir_interpreter.cc @@ -81,6 +81,7 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm); COMMON_DECLARE_bool(enable_pir_in_executor); COMMON_DECLARE_bool(enable_pir_in_executor_trace_run); +COMMON_DECLARE_int32(low_precision_op_list); #define CREATE_INSTR(instr_name) \ vec_instruction_base_.emplace_back(std::make_unique( \ @@ -89,6 +90,21 @@ COMMON_DECLARE_bool(enable_pir_in_executor_trace_run); namespace paddle { namespace framework { +void RecordLowPrecisionOp(const InstructionBase* instr_node) { + if (FLAGS_low_precision_op_list) { + std::string op_name = instr_node->Name(); + ::pir::Operation* op = instr_node->Operation(); + if (op->HasAttribute("kernel_key")) { + phi::KernelKey kernel_key = + op->attribute("kernel_key") + .dyn_cast() + .data(); + phi::KernelFactory::Instance().AddToLowPrecisionKernelList( + op_name, kernel_key.dtype()); + } + } +} + PirInterpreter::PirInterpreter(const platform::Place& place, const std::vector& fetch_var_names, const ::pir::Block* ir_block, @@ -145,7 +161,7 @@ PirInterpreter::PirInterpreter(const platform::Place& place, << std::chrono::high_resolution_clock::now().time_since_epoch().count(); BuildScope(*ir_block_, ss.str(), value_exe_info_.get()); -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) calculate_stream_timer_ = std::make_unique(place); #endif } @@ -299,7 +315,7 @@ void PirInterpreter::ShareBuildResultsFrom(const InterpreterBaseImpl& src) { std::tuple PirInterpreter::InterpreterRunTime() { double start_time = 0, end_time = 0; -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) start_time = calculate_stream_timer_->StartTime(); end_time = calculate_stream_timer_->EndTime(); #endif @@ -337,7 +353,7 @@ std::shared_ptr PirInterpreter::GetWorkQueue() { void PirInterpreter::PrepareForCUDAGraphCapture() { if (!FLAGS_new_executor_use_cuda_graph) return; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || 
defined(PADDLE_WITH_HIP) PADDLE_ENFORCE_EQ( platform::IsCUDAGraphCapturing(), false, @@ -362,7 +378,7 @@ void PirInterpreter::PrepareForCUDAGraphCapture() { void PirInterpreter::CheckCUDAGraphBeforeRun( const std::vector& feed_names) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::IsCUDAGraphCapturing()) { PADDLE_ENFORCE_EQ( feed_names.empty(), @@ -439,10 +455,12 @@ void PirInterpreter::UpdateNcclOpNum() { static std::set nccl_op_set = { "pd_op.c_softmax_with_cross_entropy", "pd_op.c_allgather", + "pd_op.c_allreduce_avg", "pd_op.c_allreduce_max", "pd_op.c_allreduce_min", "pd_op.c_allreduce_sum", "pd_op.c_allreduce_prod", + "pd_op.c_reduce_avg", "pd_op.c_reduce_max", "pd_op.c_reduce_min", "pd_op.c_reduce_prod", @@ -509,10 +527,12 @@ void PirInterpreter::UpdateNcclOpNum() { "pd_op.reduce_grad", "pd_op.c_softmax_with_cross_entropy_", "pd_op.c_allgather_", + "pd_op.c_allreduce_avg_", "pd_op.c_allreduce_max_", "pd_op.c_allreduce_min_", "pd_op.c_allreduce_sum_", "pd_op.c_allreduce_prod_", + "pd_op.c_reduce_avg_", "pd_op.c_reduce_max_", "pd_op.c_reduce_min_", "pd_op.c_reduce_prod_", @@ -702,9 +722,17 @@ void PirInterpreter::BuildInstruction() { continue; } } else if (op.dialect()->name() == "pd_op") { - if (op.isa()) { - vec_instruction_base_.emplace_back(std::make_unique( - op_idx++, place_, &op, value_exe_info_.get(), execution_config_)); + if (op.isa()) { // NOLINT + std::unique_ptr if_instr_ptr = + std::make_unique(op_idx++, + place_, + &op, + value_exe_info_.get(), + execution_config_); + if_instr_ptr->SetOutputHooks(pir_output_hookfuncs_); + if_instr_ptr->SetInputHooks(pir_input_hookfuncs_); + vec_instruction_base_.emplace_back(std::move(if_instr_ptr)); + sub_blocks_.insert( {&op.dyn_cast().true_block(), dynamic_cast(vec_instruction_base_.back().get()) @@ -722,8 +750,16 @@ void PirInterpreter::BuildInstruction() { vec_instruction_base_.back().get()) ->ForwardInterpreter()}); } else if (op.isa()) { - vec_instruction_base_.emplace_back(std::make_unique( - op_idx++, place_, &op, value_exe_info_.get(), execution_config_)); + std::unique_ptr while_instr_ptr = + std::make_unique(op_idx++, + place_, + &op, + value_exe_info_.get(), + execution_config_); + while_instr_ptr->SetOutputHooks(pir_output_hookfuncs_); + while_instr_ptr->SetInputHooks(pir_input_hookfuncs_); + vec_instruction_base_.emplace_back(std::move(while_instr_ptr)); + sub_blocks_.insert( {&op.dyn_cast().body(), dynamic_cast(vec_instruction_base_.back().get()) @@ -751,7 +787,7 @@ void PirInterpreter::BuildInstruction() { } VLOG(6) << "process " << op_name; - if (op.isa()) { + if (op.isa()) { // NOLINT CREATE_INSTR(LegacyKernelInstruction); } else { CREATE_INSTR(PhiKernelInstruction); @@ -861,7 +897,7 @@ std::string PirInterpreter::DebugValueInfo() { for (auto kv : value_exe_info_->GetValue2VarName()) { PADDLE_ENFORCE((bool)kv.first, platform::errors::PreconditionNotMet( - "vlaue(%s) should not be nullptr", kv.second)); + "var(%s) should not be nullptr", kv.second)); PADDLE_ENFORCE(value_exe_info_->HasVar(kv.second), platform::errors::PreconditionNotMet( "var(%s) should exist in var_name_2_id_", kv.second)); @@ -1720,7 +1756,7 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) { try { instr_node->WaitEvent(cur_place); -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (enable_job_schedule_profiler_) { std::string op_name = instr_node->Name(); ::pir::Operation* op = instr_node->Operation(); @@ -1731,6 +1767,9 @@ void 
PirInterpreter::RunInstructionBase(InstructionBase* instr_node) { } } #endif + + RecordLowPrecisionOp(instr_node); + VLOG(2) << "\nbegin: " << __func__ << " OP id:" << instr_node->Id() << " name:" << instr_node->Name() << " type:" << (instr_node->KernelType() == OpFuncType::kCpuSync @@ -1741,6 +1780,13 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) { << " runs on " << platform::GetCurrentThreadName() << "\n" << "Before: " << cur_place << " " << instr_node->DebugStringEx(scope_, value_exe_info_.get()); + + if (execution_config_.used_for_inference) { + for (auto& hook : pir_input_hookfuncs_) { + hook(instr_node, value_exe_info_.get(), scope_); + } + } + if (!instr_node->IsArtificial()) { instr_node->Run(); @@ -1766,9 +1812,16 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) { VLOG(4) << "done CheckGC"; memory::LogDeviceMemoryStats(cur_place, instr_node->Name()); } + + if (execution_config_.used_for_inference) { + for (auto& hook : pir_output_hookfuncs_) { + hook(instr_node, value_exe_info_.get(), scope_); + } + } + VLOG(5) << "after run kernel"; instr_node->RecordEvent(cur_place); -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (enable_job_schedule_profiler_) { if (instr_node->Id() == last_calculate_instr_id_ && calculate_stream_timer_->IsStarted()) { @@ -1785,13 +1838,13 @@ void PirInterpreter::RunInstructionBase(InstructionBase* instr_node) { framework::InsertCallStackInfo(op->name(), op_callstack_attr, &ex); LOG(WARNING) << " OP id:" << instr_node->Id() << " " << instr_node->Name() << " raises an EnforceNotMet exception " - << platform::demangle(typeid(ex).name()) << ", " << ex.what(); + << platform::demangle(typeid(ex).name()); exception_holder_.Catch(std::make_exception_ptr(std::move(ex))); } catch (platform::EOFException&) { exception_holder_.Catch(std::current_exception()); } catch (std::exception& ex) { LOG(WARNING) << instr_node->Name() << " raises an exception " - << platform::demangle(typeid(ex).name()) << ", " << ex.what(); + << platform::demangle(typeid(ex).name()); exception_holder_.Catch(std::current_exception()); } catch (...) 
{ LOG(WARNING) << instr_node->Name() << " raises an unknown exception"; diff --git a/paddle/fluid/framework/new_executor/pir_interpreter.h b/paddle/fluid/framework/new_executor/pir_interpreter.h index daf6351bb6723..9901dcf421cdc 100644 --- a/paddle/fluid/framework/new_executor/pir_interpreter.h +++ b/paddle/fluid/framework/new_executor/pir_interpreter.h @@ -18,7 +18,7 @@ #include "paddle/fluid/framework/new_executor/interpreter_base_impl.h" #include "paddle/pir/include/core/value.h" -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/phi/kernels/autotune/gpu_timer.h" #endif @@ -96,12 +96,16 @@ class PirInterpreter : public InterpreterBaseImpl { const platform::Place& GetPlace() const override { return place_; } - void SetOutputHooks(const std::vector& hookfuncs) override { - output_hookfuncs_ = hookfuncs; + void SetOutputHooks(const std::vector& hookfuncs) override {} + + void SetInputHooks(const std::vector& hookfuncs) override {} + + void SetOutputHooks(const std::vector& hookfuncs) override { + pir_output_hookfuncs_ = hookfuncs; } - void SetInputHooks(const std::vector& hookfuncs) override { - input_hookfuncs_ = hookfuncs; + void SetInputHooks(const std::vector& hookfuncs) override { + pir_input_hookfuncs_ = hookfuncs; } std::string GetNameByValue(::pir::Value value) const; @@ -200,8 +204,8 @@ class PirInterpreter : public InterpreterBaseImpl { int64_t onednn_op_num_{-1}; std::vector trace_execute_order_; - std::vector output_hookfuncs_; - std::vector input_hookfuncs_; + std::vector pir_output_hookfuncs_; + std::vector pir_input_hookfuncs_; /// ======================== /// /// For new ir /// @@ -274,7 +278,7 @@ class PirInterpreter : public InterpreterBaseImpl { // belongs to a parameter and cannot GC. 
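Note: the hook plumbing above stores PIR-specific callbacks via the new SetInputHooks/SetOutputHooks overloads and fires them around each instruction only when execution_config_.used_for_inference is set. A minimal, self-contained sketch of that pattern follows; the type names are stand-ins chosen to mirror how the hooks are invoked in RunInstructionBase (hook(instr, value_exe_info_.get(), scope_)), not Paddle's actual declarations.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Stand-ins for InstructionBase, ValueExecutionInfo and Scope.
struct Instruction { std::string name; };
struct ValueExecutionInfo {};
struct Scope {};

// Hook signature mirroring the call sites added in RunInstructionBase.
using PirHook = std::function<void(Instruction*, ValueExecutionInfo*, Scope*)>;

class MiniInterpreter {
 public:
  explicit MiniInterpreter(bool used_for_inference)
      : used_for_inference_(used_for_inference) {}

  void SetInputHooks(const std::vector<PirHook>& hooks) { input_hooks_ = hooks; }
  void SetOutputHooks(const std::vector<PirHook>& hooks) { output_hooks_ = hooks; }

  void RunInstruction(Instruction* instr) {
    // Input hooks run before the kernel, output hooks after it, and both are
    // skipped entirely unless the interpreter is used for inference.
    if (used_for_inference_) {
      for (auto& hook : input_hooks_) hook(instr, &value_info_, &scope_);
    }
    // ... the real interpreter would call instr->Run() here ...
    if (used_for_inference_) {
      for (auto& hook : output_hooks_) hook(instr, &value_info_, &scope_);
    }
  }

 private:
  bool used_for_inference_;
  std::vector<PirHook> input_hooks_;
  std::vector<PirHook> output_hooks_;
  ValueExecutionInfo value_info_;
  Scope scope_;
};

int main() {
  MiniInterpreter interp(/*used_for_inference=*/true);
  interp.SetOutputHooks({[](Instruction* instr, ValueExecutionInfo*, Scope*) {
    std::cout << "after " << instr->name << std::endl;  // e.g. dump or check outputs
  }});
  Instruction add{"pd_op.add"};
  interp.RunInstruction(&add);
  return 0;
}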
std::unordered_set parameter_var_names_; -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) std::unique_ptr calculate_stream_timer_; #endif size_t last_calculate_instr_id_; diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index 67a5c8c9d0b5b..8991fd9c3a22d 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -41,7 +41,7 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm); #endif -COMMON_DECLARE_bool(enable_host_event_recorder_hook); +PHI_DECLARE_bool(enable_host_event_recorder_hook); PD_DECLARE_bool(log_memory_stats); COMMON_DECLARE_string(static_runtime_data_save_path); COMMON_DECLARE_bool(save_static_runtime_data); @@ -191,7 +191,7 @@ FetchList ProgramInterpreter::Run(const std::vector& feed_names, if (fetch_var) { auto fetch_list = std::move(*fetch_var->GetMutable()); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::IsCUDAGraphCapturing()) { PADDLE_ENFORCE_EQ(fetch_list.empty(), true, @@ -269,7 +269,7 @@ FetchList ProgramInterpreter::Run( if (fetch_var) { auto fetch_list = std::move(*fetch_var->GetMutable()); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::IsCUDAGraphCapturing()) { PADDLE_ENFORCE_EQ(fetch_list.empty(), true, @@ -533,7 +533,7 @@ void ProgramInterpreter::BuildInplace() { void ProgramInterpreter::PrepareForCUDAGraphCapture() { if (!FLAGS_new_executor_use_cuda_graph) return; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PADDLE_ENFORCE_EQ( platform::IsCUDAGraphCapturing(), false, @@ -579,7 +579,7 @@ void ProgramInterpreter::PrepareForCUDAGraphCapture() { void ProgramInterpreter::CheckCUDAGraphBeforeRun( const std::vector& feed_names) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::IsCUDAGraphCapturing()) { PADDLE_ENFORCE_EQ( feed_names.empty(), @@ -862,7 +862,7 @@ void ProgramInterpreter::BuildOpFuncNode( auto& op_func_node = nodes[op_idx]; stream_analyzer_.SetForceEventsToWaitInfo(force_events_to_wait_); auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (FLAGS_new_executor_use_cuda_graph) { auto& op = op_func_node.operator_base_; auto& op_type = op->Type(); diff --git a/paddle/fluid/framework/new_executor/program_interpreter.h b/paddle/fluid/framework/new_executor/program_interpreter.h index 7e956249e22a3..94a8af8197d11 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.h +++ b/paddle/fluid/framework/new_executor/program_interpreter.h @@ -101,6 +101,10 @@ class ProgramInterpreter : public InterpreterBaseImpl { input_hookfuncs_ = hookfuncs; } + void SetOutputHooks(const std::vector& hookfuncs) override {} + + void SetInputHooks(const std::vector& hookfuncs) override {} + std::unordered_map>* GetForceEventsToWaitInfo() { return force_events_to_wait_; diff --git a/paddle/fluid/framework/new_executor/standalone_executor.cc b/paddle/fluid/framework/new_executor/standalone_executor.cc index 2bb0a7197774e..99d2b6a4b7fbc 100644 --- a/paddle/fluid/framework/new_executor/standalone_executor.cc +++ b/paddle/fluid/framework/new_executor/standalone_executor.cc @@ -23,7 +23,7 @@ #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" #include 
"paddle/fluid/ir_adaptor/translator/translate.h" -#include "paddle/fluid/pir/transforms/inplace_pass.h" +#include "paddle/fluid/pir/transforms/general/inplace_pass.h" #include "paddle/pir/include/core/program.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_manager.h" @@ -57,7 +57,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, const std::string& job_type = job->Type(); std::shared_ptr program = nullptr; std::shared_ptr<::pir::Program> ir_program = nullptr; - if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { + if (FLAGS_enable_pir_api || FLAGS_enable_pir_in_executor) { // NOLINT ir_program = plan_.IrProgram(job_type); } else { // NOTE (liuchenghao): std::make_shared will duplicate ProgramDesc object, @@ -119,7 +119,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, shared_program->block(), micro_batch_scopes_[micro_batch_id], execution_config)); - // Note(lizhiyu): Add mannual event info + // Note(lizhiyu): Add manual event info auto pir_inter = const_cast( static_cast(interpretercores_.back()->Impl())); pir_inter->SetForceEventsToWaitInfo( @@ -132,7 +132,7 @@ StandaloneExecutor::StandaloneExecutor(const platform::Place& place, execution_config)); interpretercores_.back()->SetCopyProgram(program); - // Note(lizhiyu): Add mannual event info + // Note(lizhiyu): Add manual event info auto prog_inter = const_cast( static_cast( interpretercores_.back()->Impl())); diff --git a/paddle/fluid/framework/op_compatible_info.cc b/paddle/fluid/framework/op_compatible_info.cc index ba71043771ff2..203d177bba916 100644 --- a/paddle/fluid/framework/op_compatible_info.cc +++ b/paddle/fluid/framework/op_compatible_info.cc @@ -16,7 +16,7 @@ #include "paddle/common/macros.h" #include "paddle/fluid/platform/init_phi.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" REGISTER_FILE_SYMBOLS(op_compatible_info); @@ -68,42 +68,48 @@ inline bool CompareVersion(const std::string& str_first, } void OpCompatibleMap::InitOpCompatibleMap() { - op_compatible_map_["sequence_pad"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["sequence_unpad"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; + op_compatible_map_["sequence_pad"] = {"1.6.0", + OpCompatibleType::definite_not}; + op_compatible_map_["sequence_unpad"] = {"1.6.0", + OpCompatibleType::definite_not}; op_compatible_map_["coalesce_tensor"] = {"1.6.0", - OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["crop_tensor"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; + OpCompatibleType::definite_not}; + op_compatible_map_["crop_tensor"] = {"1.6.0", OpCompatibleType::definite_not}; op_compatible_map_["deformable_conv"] = {"1.6.0", - OpCompatibleType::DEFIN_NOT}; + OpCompatibleType::definite_not}; op_compatible_map_["deformable_conv_v1"] = {"1.6.0", - OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["dpsgd"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["eye"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["fill_any_like"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["hard_swish"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["gather_nd"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["instance_norm"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; + OpCompatibleType::definite_not}; + op_compatible_map_["dpsgd"] = {"1.6.0", OpCompatibleType::definite_not}; + op_compatible_map_["eye"] = {"1.6.0", OpCompatibleType::definite_not}; + 
op_compatible_map_["fill_any_like"] = {"1.6.0", + OpCompatibleType::definite_not}; + op_compatible_map_["hard_swish"] = {"1.6.0", OpCompatibleType::definite_not}; + op_compatible_map_["gather_nd"] = {"1.6.0", OpCompatibleType::definite_not}; + op_compatible_map_["instance_norm"] = {"1.6.0", + OpCompatibleType::definite_not}; op_compatible_map_["lookup_table_v2"] = {"1.6.0", - OpCompatibleType::DEFIN_NOT}; + OpCompatibleType::definite_not}; op_compatible_map_["match_matrix_tensor"] = {"1.6.0", - OpCompatibleType::DEFIN_NOT}; + OpCompatibleType::definite_not}; op_compatible_map_["multiclass_nms2"] = {"1.6.0", - OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["one_hot_v2"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; + OpCompatibleType::definite_not}; + op_compatible_map_["one_hot_v2"] = {"1.6.0", OpCompatibleType::definite_not}; op_compatible_map_["pull_box_sparse"] = {"1.6.0", - OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["scatter_nd_add"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["shard_index"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["size"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["strided_slice"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; + OpCompatibleType::definite_not}; + op_compatible_map_["scatter_nd_add"] = {"1.6.0", + OpCompatibleType::definite_not}; + op_compatible_map_["shard_index"] = {"1.6.0", OpCompatibleType::definite_not}; + op_compatible_map_["size"] = {"1.6.0", OpCompatibleType::definite_not}; + op_compatible_map_["strided_slice"] = {"1.6.0", + OpCompatibleType::definite_not}; op_compatible_map_["trilinear_interp"] = {"1.6.0", - OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["unfold"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["unique"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; + OpCompatibleType::definite_not}; + op_compatible_map_["unfold"] = {"1.6.0", OpCompatibleType::definite_not}; + op_compatible_map_["unique"] = {"1.6.0", OpCompatibleType::definite_not}; op_compatible_map_["unique_with_counts"] = {"1.6.0", - OpCompatibleType::DEFIN_NOT}; - op_compatible_map_["var_conv_2d"] = {"1.6.0", OpCompatibleType::DEFIN_NOT}; + OpCompatibleType::definite_not}; + op_compatible_map_["var_conv_2d"] = {"1.6.0", OpCompatibleType::definite_not}; op_compatible_map_["reshape2"] = {"1.6.0", OpCompatibleType::possible}; op_compatible_map_["slice"] = {"1.6.0", OpCompatibleType::possible}; @@ -156,7 +162,7 @@ CompatibleInfo OpCompatibleMap::GetOpCompatibleInfo(std::string op_name) const { if (it != op_compatible_map_.end()) { return it->second; } else { - return {default_required_version_, OpCompatibleType::DEFIN_NOT}; + return {default_required_version_, OpCompatibleType::definite_not}; } } @@ -174,7 +180,7 @@ OpCompatibleType OpCompatibleMap::IsRequireMiniVersion( if (CompareVersion(str_current_version, default_required_version_)) { return OpCompatibleType::compatible; } else { - return OpCompatibleType::DEFIN_NOT; + return OpCompatibleType::definite_not; } } } diff --git a/paddle/fluid/framework/op_compatible_info.h b/paddle/fluid/framework/op_compatible_info.h index 6f86b8b64ed21..7256a92b5b457 100644 --- a/paddle/fluid/framework/op_compatible_info.h +++ b/paddle/fluid/framework/op_compatible_info.h @@ -28,7 +28,7 @@ class OpCompatibleMap; enum class OpCompatibleType { compatible = 0, // support previous version - DEFIN_NOT = 1, // definitely can't support previous version + definite_not = 1, // definitely can't support previous version possible = 2, // possible can support previous 
version, not sure bug_fix = 3, // bug fix, can't support previous version precision_change = 4 // precision change, may cause difference diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 99ccbbe50d241..fe10a16375f34 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -65,7 +65,7 @@ PD_DECLARE_bool(benchmark); COMMON_DECLARE_bool(check_nan_inf); PD_DECLARE_bool(enable_unused_var_check); COMMON_DECLARE_bool(run_kp_kernel); -COMMON_DECLARE_bool(enable_host_event_recorder_hook); +PHI_DECLARE_bool(enable_host_event_recorder_hook); namespace paddle { namespace framework { @@ -96,6 +96,12 @@ static DDim GetDimsDebug(const Scope& scope, } } else if (var->IsType()) { return DDim({static_cast(var->Get().size())}); + } else if (var->IsType()) { + const phi::SparseCooTensor& tensor = var->Get(); + return tensor.dims(); + } else if (var->IsType()) { + const phi::SparseCsrTensor& tensor = var->Get(); + return tensor.dims(); } else { return DDim({-1}); } @@ -128,6 +134,18 @@ static std::string GetDtype(const Scope& scope, const std::string& name) { } } else if (var->IsType()) { return "strings"; + } else if (var->IsType()) { + const phi::SparseCooTensor& tensor = var->Get(); + if (UNLIKELY(!tensor.initialized())) { + return ""; + } + return DataTypeToString(framework::TransToProtoVarType(tensor.dtype())); + } else if (var->IsType()) { + const phi::SparseCsrTensor& tensor = var->Get(); + if (UNLIKELY(!tensor.initialized())) { + return ""; + } + return DataTypeToString(framework::TransToProtoVarType(tensor.dtype())); } else { return ""; } @@ -1001,7 +1019,7 @@ OperatorBase::OperatorBase(const std::string& type, // as Input. for (auto& attr : FilterAttrVar(attrs)) { VLOG(3) << "found Attribute with Variable type: " << attr.first; - inputs_[attr.first] = std::move(AttrVarNames(attr.second)); + inputs_[attr.first] = AttrVarNames(attr.second); attrs_.erase(attr.first); } } @@ -1704,6 +1722,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope, all_kernels_must_compute_runtime_shape_ = true; const Scope* cur_scope = &scope; CheckWhetherPreparePhiData(Inputs(), Outputs(), scope); +#if defined(PADDLE_WITH_XPU) + if (std::getenv("XPU_NEED_PREPARE_PHI_DATA") != nullptr) { + need_prepare_phi_data_ = atoi(std::getenv("XPU_NEED_PREPARE_PHI_DATA")); + } +#endif if (!enable_cache_runtime_context_) { RuntimeContext ctx(Inputs(), Outputs(), scope); RunImpl(scope, place, &ctx); @@ -1754,12 +1777,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope, std::string phi_kernel_name; if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) { if (kernel_signature_ == nullptr || phi_kernel_ == nullptr) { - if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) { + if (phi::KernelFactory::Instance().HasStructuredKernel( + type_)) { // NOLINT kernel_signature_ = std::make_unique(type_.c_str()); } else { kernel_signature_ = std::make_unique( - std::move(GetExpectedPhiKernelArgs(exe_ctx))); + GetExpectedPhiKernelArgs(exe_ctx)); } VLOG(6) << *kernel_signature_.get(); @@ -1989,7 +2013,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, 1, platform::EventRole::kInnerOp); if (need_prepare_data_) { - if (fallback_to_cpu) { + if (fallback_to_cpu) { // NOLINT transfer_scope = PrepareData(scope, phi_cpu_kernel_key, &transfered_inplace_vars, @@ -2037,7 +2061,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, phi::KernelContext phi_kernel_context; if (enable_cache_runtime_context_ && !need_prepare_phi_data_ && 
!need_prepare_data_) { - // TODO(inference): Now we only suppor dense_tensor cache, we may be + // TODO(inference): Now we only support dense_tensor cache, we may be // support ScalarTensor, SparseTensor in future. bool all_dense_tensor_input_{true}; for (auto& iter : Inputs()) { @@ -2278,11 +2302,11 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( phi::KernelKey OperatorWithKernel::ChoosePhiKernel( const ExecutionContext& ctx) const { std::string phi_kernel_name; - if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) { + if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) { // NOLINT kernel_signature_ = std::make_unique(type_.c_str()); } else { - kernel_signature_ = std::make_unique( - std::move(GetExpectedPhiKernelArgs(ctx))); + kernel_signature_ = + std::make_unique(GetExpectedPhiKernelArgs(ctx)); } VLOG(6) << *kernel_signature_.get(); phi_kernel_name = kernel_signature_->name; @@ -2572,7 +2596,7 @@ Scope* OperatorWithKernel::PrepareData( // for some situation like InferShape(). // In this situation We cannot skip Var analysis, as // oneDNN shape of Var may differ from kNHWC Var - // In such situation corressponding resized Var + // In such situation corresponding resized Var // has to be created and registered if ((tensor_in->layout() == DataLayout::ONEDNN) && (var->IsType() == true) && @@ -3104,7 +3128,7 @@ static void SetDnnAttrIntoDeviceContext( case proto::AttrType::STRING: one_dnn_ctx->SetDnnAttr(attr_name, PADDLE_GET_CONST(std::string, attr)); break; - case proto::AttrType::INTS: + case proto::AttrType::INTS: // NOLINT one_dnn_ctx->SetDnnAttr(attr_name, PADDLE_GET_CONST(std::vector, attr)); break; @@ -3192,7 +3216,7 @@ void OperatorWithKernel::BuildPhiKernelContext( for (size_t i = 0; i < input_names.size(); ++i) { auto it = ctx.inputs.find(input_names[i]); - // calcute the start and end index of the input tensors + // calculate the start and end index of the input tensors size_t start_idx = (i == 0 ? 
0 : phi_kernel_context->InputRangeAt(i - 1).second); // deal with optional here @@ -3352,27 +3376,27 @@ void OperatorWithKernel::BuildPhiKernelContext( need_prepare_phi_data_ = true; auto& ins_vector = ctx.inputs.at(attr_names[i]); phi_kernel_context->EmplaceBackAttr( - std::move(framework::MakePhiScalarFromVar(*ins_vector.front()))); + framework::MakePhiScalarFromVar(*ins_vector.front())); } break; case phi::AttributeType::INT_ARRAY: if (attr_iter != Attrs().end()) { switch (AttrTypeID(attr_iter->second)) { - case proto::AttrType::INTS: - phi_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( - PADDLE_GET_CONST(std::vector, attr_iter->second)))); + case proto::AttrType::INTS: // NOLINT + phi_kernel_context->EmplaceBackAttr(phi::IntArray( + PADDLE_GET_CONST(std::vector, attr_iter->second))); break; case proto::AttrType::LONGS: - phi_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( - PADDLE_GET_CONST(std::vector, attr_iter->second)))); + phi_kernel_context->EmplaceBackAttr(phi::IntArray( + PADDLE_GET_CONST(std::vector, attr_iter->second))); break; case proto::AttrType::INT: - phi_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( - &PADDLE_GET_CONST(int32_t, attr_iter->second), 1))); + phi_kernel_context->EmplaceBackAttr(phi::IntArray( + &PADDLE_GET_CONST(int32_t, attr_iter->second), 1)); break; case proto::AttrType::LONG: - phi_kernel_context->EmplaceBackAttr(std::move(phi::IntArray( - &PADDLE_GET_CONST(int64_t, attr_iter->second), 1))); + phi_kernel_context->EmplaceBackAttr(phi::IntArray( + &PADDLE_GET_CONST(int64_t, attr_iter->second), 1)); break; default: PADDLE_THROW(platform::errors::Unimplemented( @@ -3384,11 +3408,11 @@ void OperatorWithKernel::BuildPhiKernelContext( need_prepare_phi_data_ = true; auto& ins_vector = ctx.inputs.at(attr_names[i]); if (ins_vector.size() == 1) { // ShapeTensor - phi_kernel_context->EmplaceBackAttr(std::move( - framework::MakePhiIntArrayFromVar(*ins_vector.front()))); + phi_kernel_context->EmplaceBackAttr( + framework::MakePhiIntArrayFromVar(*ins_vector.front())); } else { // ShapeTensorList phi_kernel_context->EmplaceBackAttr( - std::move(framework::MakePhiIntArrayFromVarList(ins_vector))); + framework::MakePhiIntArrayFromVarList(ins_vector)); } } break; @@ -3398,7 +3422,7 @@ void OperatorWithKernel::BuildPhiKernelContext( attr_iter, Attrs().end(), platform::errors::NotFound("(%s) is not found in AttributeMap when " - "buildind static KernelContext.", + "building static KernelContext.", attr_names[i])); switch (AttrTypeID(attr_iter->second)) { case proto::AttrType::INTS: { @@ -3472,7 +3496,7 @@ void OperatorWithKernel::BuildPhiKernelContext( RuntimeAttrs().end(), platform::errors::NotFound( "(%s) is not found in AttributeMap when " - "buildind static KernelContext.", + "building static KernelContext.", attr_names[i])); } @@ -3497,7 +3521,7 @@ void OperatorWithKernel::BuildPhiKernelContext( phi_kernel_context->EmplaceBackAttr( PADDLE_GET_CONST(int64_t, attr_iter->second)); break; - case phi::AttributeType::INT32S: + case phi::AttributeType::INT32S: // NOLINT phi_kernel_context->EmplaceBackAttr( PADDLE_GET_CONST(std::vector, attr_iter->second)); break; @@ -3536,7 +3560,7 @@ void OperatorWithKernel::BuildPhiKernelContext( attr_names[i])); } break; - case phi::AttributeType::FLOAT32S: + case phi::AttributeType::FLOAT32S: // NOLINT phi_kernel_context->EmplaceBackAttr( PADDLE_GET_CONST(std::vector, attr_iter->second)); break; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc 
b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index b25ebd671ea31..fc25f26692682 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -49,9 +49,9 @@ #include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/operators/cinn/cinn_launch_context.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/pir/include/core/program.h" #include "paddle/pir/include/core/value.h" +#include "paddle/utils/string/string_helper.h" COMMON_DECLARE_bool(enable_pe_launch_cinn); COMMON_DECLARE_bool(enable_cinn_auto_tune); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc index dc36f40d9c6a3..c5a838bc66f8f 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc @@ -169,11 +169,11 @@ bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr subgraph_ptr) { if (!consumer->substitute) { continue; } - // fast depency check. + // fast dependency check. if (IsDependencySimplify(producer, consumer, consumers)) { continue; } - // global depency check. + // global dependency check. if (IsDependency(producer, consumer, consumers)) { continue; } @@ -196,7 +196,7 @@ bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr subgraph_ptr) { producer->node_set.insert(candidate->node_set.begin(), candidate->node_set.end()); - // update bound for check depency + // update bound for check dependency producer->max_depth = std::max(producer->max_depth, candidate->max_depth); producer->min_depth = std::min(producer->min_depth, candidate->min_depth); @@ -219,7 +219,7 @@ bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr subgraph_ptr) { tmp->producers.erase(candidate); } - // remove candicate in producer/consumer + // remove candidate in producer/consumer producer->producers.erase(candidate); producer->consumers.erase(candidate); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h index e8ff3915c8511..7b02761b9e855 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h @@ -78,7 +78,7 @@ class CinnSubgraphDetector { // SubGraph Fusion void DoSubGraphFusion(); bool FuseSubGraph(CinnSubGraphPtr); - // check exist depency. + // check exist dependency. 
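Note: FuseSubGraph above only merges a producer/consumer pair after a fast dependency check (IsDependencySimplify) and, if that is inconclusive, a global one (IsDependency). Fusing is unsafe when the consumer is also reachable from the producer through some third subgraph, because merging the pair would then create a cycle. The sketch below is a generic illustration of that reachability test, not CINN's exact implementation.

#include <queue>
#include <unordered_set>

// Illustrative stand-in for CinnSubGraph: only the consumer edges matter here.
struct SubGraph {
  std::unordered_set<SubGraph*> consumers;  // downstream subgraphs
};

// Returns true if `target` is reachable from `start` without taking the
// direct start->target edge, i.e. fusing the two would create a cycle
// through an intermediate subgraph.
bool HasIndirectPath(SubGraph* start, SubGraph* target) {
  std::queue<SubGraph*> pending;
  std::unordered_set<SubGraph*> visited{start};
  for (SubGraph* c : start->consumers) {
    if (c != target) pending.push(c);  // skip the direct edge
  }
  while (!pending.empty()) {
    SubGraph* cur = pending.front();
    pending.pop();
    if (cur == target) return true;  // found producer -> X -> consumer
    if (!visited.insert(cur).second) continue;
    for (SubGraph* c : cur->consumers) pending.push(c);
  }
  return false;
}

// Usage: only fuse when no indirect dependency exists.
// if (!HasIndirectPath(producer, consumer)) { /* merge node sets, rewire edges */ }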
bool IsDependency(const CinnSubGraphPtr &, const CinnSubGraphPtr &, const std::unordered_set &); diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index 897e520813809..ccf2b718e535e 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -639,15 +639,15 @@ void InitP2P(const std::vector &places) { for (int i = 0; i < count; ++i) { for (int j = 0; j < count; ++j) { if (devices[i] == devices[j]) continue; - int can_acess = -1; + int can_access = -1; #ifdef PADDLE_WITH_HIP hipError_t ret = - hipDeviceCanAccessPeer(&can_acess, devices[i], devices[j]); - if (ret != hipSuccess || can_acess != 1) { + hipDeviceCanAccessPeer(&can_access, devices[i], devices[j]); + if (ret != hipSuccess || can_access != 1) { #else cudaError_t ret = - cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]); - if (ret != cudaSuccess || can_acess != 1) { + cudaDeviceCanAccessPeer(&can_access, devices[i], devices[j]); + if (ret != cudaSuccess || can_access != 1) { #endif LOG(WARNING) << "Cannot enable P2P access from " << devices[i] << " to " << devices[j]; @@ -1416,7 +1416,7 @@ void ParallelExecutor::PreludeToRun( platform::RecordEvent record_run( "ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1); VLOG(3) << "enter ParallelExecutor Run"; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::IsCUDAGraphCapturing()) { PADDLE_ENFORCE_EQ(fetch_tensors.empty(), true, @@ -1804,7 +1804,7 @@ const ir::Graph &ParallelExecutor::Graph() const { void ParallelExecutor::PrepareForCUDAGraphCapture(ir::Graph *graph) { const auto &build_strategy = member_->build_strategy_; if (!build_strategy.allow_cuda_graph_capture_) return; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PADDLE_ENFORCE_EQ( build_strategy.async_mode_, false, diff --git a/paddle/fluid/framework/phi_utils.cc b/paddle/fluid/framework/phi_utils.cc index 15727db9d0f5d..4b683f918009a 100644 --- a/paddle/fluid/framework/phi_utils.cc +++ b/paddle/fluid/framework/phi_utils.cc @@ -20,12 +20,12 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/selected_rows_utils.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/type_defs.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { @@ -243,7 +243,7 @@ void InitDefaultKernelSignatureMap() { paddle::framework::KernelArgsNameMakerByOpProto maker(op_proto); VLOG(10) << "Register `" << op_type << "` kernel signature:"; phi::DefaultKernelSignatureMap::Instance().Insert( - op_type, std::move(maker.GetKernelSignature())); + op_type, maker.GetKernelSignature()); } } }); diff --git a/paddle/fluid/framework/program_converter.cc b/paddle/fluid/framework/program_converter.cc index 48d45277dfffd..83bfdb264e681 100644 --- a/paddle/fluid/framework/program_converter.cc +++ b/paddle/fluid/framework/program_converter.cc @@ -282,7 +282,7 @@ void ConvertAssignValueOp(OpDesc* op) { } op->RemoveAttr("int64_values"); } - op->SetAttr("values", values); + if (!values.empty()) op->SetAttr("values", values); } void ConvertProgram(ProgramDesc* program) { diff --git a/paddle/fluid/framework/program_desc.cc b/paddle/fluid/framework/program_desc.cc index baf50d275c89f..512cdd9b38769 100644 --- a/paddle/fluid/framework/program_desc.cc +++ b/paddle/fluid/framework/program_desc.cc @@ -78,8 +78,8 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { // record all block desc's ptr from origin program old_block_desc.emplace_back(o.blocks_[i].get()); } - for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) { - auto all_ops = blocks_[block_id]->AllOps(); + for (size_t block_id = 0; block_id < blocks_.size(); ++block_id) { // NOLINT + auto all_ops = blocks_[block_id]->AllOps(); // NOLINT for (size_t op_id = 0; op_id < all_ops.size(); ++op_id) { auto &op = all_ops[op_id]; @@ -92,7 +92,7 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { block_desc) != old_block_desc.end()) { // The block is owned by the origin program. Just use id to get // the corresponding block. 
- int sub_block_id = o.Block(block_id) + int sub_block_id = o.Block(block_id) // NOLINT .Op(static_cast(op_id)) ->GetBlockAttrId(attr_name); op->SetBlockAttr(attr_name, MutableBlock(sub_block_id)); @@ -103,7 +103,7 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { op->SetBlockAttr(attr_name, block_desc); } } else if (op->GetAttrType(attr_name) == proto::AttrType::BLOCKS) { - std::vector sub_block_ids = o.Block(block_id) + std::vector sub_block_ids = o.Block(block_id) // NOLINT .Op(static_cast(op_id)) ->GetBlocksAttrIds(attr_name); std::vector block_descs; @@ -114,19 +114,20 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) { } else if (op->GetAttrType(attr_name, true) == proto::AttrType::VAR) { VarDesc *var_desc = PADDLE_GET_CONST(VarDesc *, op->GetAttr(attr_name, true)); - op->SetVarAttr(attr_name, - o.Block(block_id).FindVarRecursive(var_desc->Name())); + op->SetVarAttr( + attr_name, + o.Block(block_id).FindVarRecursive(var_desc->Name())); // NOLINT } else if (op->GetAttrType(attr_name, true) == proto::AttrType::VARS) { std::vector vars_desc = PADDLE_GET_CONST( std::vector, op->GetAttr(attr_name, true)); std::vector new_vars_desc; - std::transform( - vars_desc.begin(), - vars_desc.end(), - std::back_inserter(new_vars_desc), - [&](VarDesc *var_desc) { - return o.Block(block_id).FindVarRecursive(var_desc->Name()); - }); + std::transform(vars_desc.begin(), + vars_desc.end(), + std::back_inserter(new_vars_desc), + [&](VarDesc *var_desc) { + return o.Block(block_id).FindVarRecursive( + var_desc->Name()); // NOLINT + }); op->SetVarsAttr(attr_name, new_vars_desc); } } diff --git a/paddle/fluid/framework/ps_gpu_worker.cc b/paddle/fluid/framework/ps_gpu_worker.cc index 4cc03b95abc52..b0649563d8f9e 100644 --- a/paddle/fluid/framework/ps_gpu_worker.cc +++ b/paddle/fluid/framework/ps_gpu_worker.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/isfinite_op.h" #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/lodtensor_printer.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" #if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \ defined PADDLE_WITH_XPU_BKCL) && \ diff --git a/paddle/fluid/framework/reader.h b/paddle/fluid/framework/reader.h index f926829dc9bd4..8aef207f5da32 100644 --- a/paddle/fluid/framework/reader.h +++ b/paddle/fluid/framework/reader.h @@ -48,15 +48,15 @@ class ReaderBase { "and need_check_feed")); } - virtual void ReadNext(paddle::framework::LoDTensorArray* out); + TEST_API virtual void ReadNext(paddle::framework::LoDTensorArray* out); - virtual void Shutdown(); + TEST_API virtual void Shutdown(); - virtual void Start(); + TEST_API virtual void Start(); // Return the readers which are the end of decorating chain. Basically // they are readers just before read op. - std::unordered_set GetEndPoints(); + TEST_API std::unordered_set GetEndPoints(); // Returns the shapes of the fed variables const std::vector& Shapes() const { return shapes_; } @@ -70,7 +70,7 @@ class ReaderBase { // This function returns whether you have the check shape for this Reader. const std::vector& NeedCheckFeed() const { return need_check_feed_; } - virtual ~ReaderBase(); + TEST_API virtual ~ReaderBase(); protected: virtual void ReadNextImpl(paddle::framework::LoDTensorArray* out UNUSED) {} @@ -98,7 +98,7 @@ class ReaderBase { friend class DecoratedReader; // These methods can be only invoked inside DecoratedReader to record the // decorating chain. 
- void InsertDecoratedReader( + TEST_API void InsertDecoratedReader( const std::shared_ptr& decorated_reader); // A set of which readers that decorated this reader. std::vector> decorated_readers_; @@ -121,7 +121,7 @@ class DecoratedReader : public ReaderBase, reader_->InsertDecoratedReader(shared_from_this()); } - ~DecoratedReader(); + TEST_API ~DecoratedReader(); const std::shared_ptr& UnderlyingReader() const { return reader_; diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index 91d24cc70552c..19e09ab5edf8d 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -238,7 +238,7 @@ void SectionWorker::TrainFiles() { #endif } // max_memory_size >= 0 - if (schedule_mode_ == 0) { + if (schedule_mode_ == 0) { // NOLINT RunFThenB(gc); } else { Run1F1B(gc); diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index 49603b34255db..427d4be4558e9 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -78,13 +78,14 @@ class InferShapeContext { virtual DDim GetInputDim(const std::string &name) const = 0; virtual std::vector GetInputsDim(const std::string &name) const = 0; - virtual std::vector GetReaderDims(const std::string &name) const; + TEST_API virtual std::vector GetReaderDims( + const std::string &name) const; virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; virtual void SetOutputsDim(const std::string &name, const std::vector &dims) = 0; - virtual void SetReaderDims(const std::string &name, - const std::vector &dims); + TEST_API virtual void SetReaderDims(const std::string &name, + const std::vector &dims); virtual std::string GetInputNameByIdx(size_t idx) const = 0; virtual std::string GetOutputNameByIdx(size_t idx) const = 0; virtual AttrReader Attrs() const = 0; diff --git a/paddle/fluid/framework/string_array.cc b/paddle/fluid/framework/string_array.cc index 07e3f07294fae..e701a423abd82 100644 --- a/paddle/fluid/framework/string_array.cc +++ b/paddle/fluid/framework/string_array.cc @@ -47,7 +47,7 @@ void NFD(const std::string& s, std::string* ret) { char* result = reinterpret_cast( utf8proc_NFD(reinterpret_cast(s.c_str()))); if (result) { - *ret = std::move(std::string(result)); + *ret = std::string(result); free(result); // NOLINT } } diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index fafde716b7bba..bd869a0588067 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -710,8 +710,9 @@ void TensorFromStream(std::istream& is, PADDLE_THROW(platform::errors::Unimplemented( "XPUPlace is not supported when not compiled with XPU")); } else { - PADDLE_THROW(platform::errors::Unimplemented( - "CutomPlace is not supported when not compiled with CustomDevice")); + PADDLE_THROW( + platform::errors::Unimplemented("CustomPlace is not supported when " + "not compiled with CustomDevice")); } #endif } else { @@ -887,7 +888,8 @@ std::ostream& print_tensor(std::ostream& os, const phi::DenseTensor& tensor) { auto element_num = tensor.numel(); os << " - data: ["; - // Note: int8_t && uint8_t is typedf of char, ostream unable to print properly + // Note: int8_t && uint8_t is typedef of char, ostream unable to print + // properly if (typeid(int8_t) == typeid(T) || typeid(uint8_t) == typeid(T)) { if (element_num > 0) { os << signed(inspect[0]); diff --git a/paddle/fluid/framework/tensor_util.h 
b/paddle/fluid/framework/tensor_util.h index 96f3d71c132af..1e65c5f163584 100644 --- a/paddle/fluid/framework/tensor_util.h +++ b/paddle/fluid/framework/tensor_util.h @@ -53,12 +53,12 @@ class PrintOptions { PrintOptions() {} }; -void TensorToStream(std::ostream& os, - const phi::DenseTensor& tensor, - const platform::DeviceContext& dev_ctx); -void TensorFromStream(std::istream& is, - phi::DenseTensor* tensor, - const platform::DeviceContext& dev_ctx); +TEST_API void TensorToStream(std::ostream& os, + const phi::DenseTensor& tensor, + const platform::DeviceContext& dev_ctx); +TEST_API void TensorFromStream(std::istream& is, + phi::DenseTensor* tensor, + const platform::DeviceContext& dev_ctx); void TensorFromStream(std::istream& is, phi::DenseTensor* tensor, const platform::DeviceContext& dev_ctx, @@ -103,11 +103,12 @@ void TensorToVector(const phi::DenseTensor& src, const platform::DeviceContext& ctx, std::vector* dst); template -void TesnorToVector(const phi::DenseTensor& src, std::vector* dst); +void TensorToVector(const phi::DenseTensor& src, std::vector* dst); // convert dlpack's DLTensor to tensor -void TensorFromDLPack(const ::DLTensor& dl_tensor, phi::DenseTensor* dst); +TEST_API void TensorFromDLPack(const ::DLTensor& dl_tensor, + phi::DenseTensor* dst); void TensorFromDLPack(const DLManagedTensor* src, phi::DenseTensor* dst); // diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index af7fc63a2122a..97857781fa6c2 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -34,7 +34,7 @@ limitations under the License. */ #include "paddle/fluid/framework/trainer_desc.pb.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/operators/reader/blocking_queue.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/trainer_factory.cc b/paddle/fluid/framework/trainer_factory.cc index ba5dac4830aa1..81b2df6efc723 100644 --- a/paddle/fluid/framework/trainer_factory.cc +++ b/paddle/fluid/framework/trainer_factory.cc @@ -26,8 +26,8 @@ namespace framework { class TrainerBase; -typedef std::shared_ptr (*CreatetrainerFunction)(); -typedef std::unordered_map trainerMap; +typedef std::shared_ptr (*CreateTrainerFunction)(); +typedef std::unordered_map trainerMap; trainerMap g_trainer_map; #define REGISTER_TRAINER_CLASS(trainer_class) \ diff --git a/paddle/fluid/framework/var_type_traits.h b/paddle/fluid/framework/var_type_traits.h index 9bffd125a3f3d..3751118915e9a 100644 --- a/paddle/fluid/framework/var_type_traits.h +++ b/paddle/fluid/framework/var_type_traits.h @@ -97,8 +97,8 @@ namespace paddle { namespace framework { TEST_API const char *ToTypeName(int var_id); -const std::type_index &VarTraitIdToTypeIndex(int var_id); -int TypeIndexToVarTraitId(const std::type_index &type); +TEST_API const std::type_index &VarTraitIdToTypeIndex(int var_id); +TEST_API int TypeIndexToVarTraitId(const std::type_index &type); namespace detail { diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 86688213ef186..31ab7e1b1bcaa 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -52,7 +52,6 @@ cc_library( variable_helper op_registry var_helper) -add_subdirectory(jit) if(WITH_GPU) cc_library( layout_autotune @@ -73,7 +72,6 @@ cc_library( SRCS tracer.cc DEPS layer engine - program_desc_tracer amp denormal garbage_collector diff --git 
a/paddle/fluid/imperative/all_reduce.cc b/paddle/fluid/imperative/all_reduce.cc index c4bb42e4c22bb..f86bce962e021 100644 --- a/paddle/fluid/imperative/all_reduce.cc +++ b/paddle/fluid/imperative/all_reduce.cc @@ -32,7 +32,7 @@ #include "paddle/fluid/imperative/parallel_context.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace imperative { diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index 50df994014004..c2aab61851fb5 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -185,7 +185,7 @@ AmpOperators::GetMutableUnsupportedOps(const phi::DataType& data_type) { true, phi::errors::InvalidArgument( "The data_type mismatch. It should be FLOAT16 or BFLOAT16.")); - if (data_type == phi::DataType::FLOAT16) { + if (data_type == phi::DataType::FLOAT16) { // NOLINT return unsupported_fp16_ops_; } else { return unsupported_bf16_ops_; @@ -375,7 +375,8 @@ template NameVarMap AutoCastInputs(const std::string& op_type, const NameVarMap& ins) { NameVarMap new_ins(ins); - if (AmpOperators::Instance().GetMutableAllowOps()->count(op_type)) { + if (AmpOperators::Instance().GetMutableAllowOps()->count( + op_type)) { // NOLINT for (auto& pair : new_ins) { // NOTE(zhiqiu): batch_norm and layer_norm support only input x is fp16. if ((op_type == "batch_norm" || op_type == "layer_norm" || diff --git a/paddle/fluid/imperative/amp_utils.h b/paddle/fluid/imperative/amp_utils.h index 37dcd48359e34..3b961e5960c81 100644 --- a/paddle/fluid/imperative/amp_utils.h +++ b/paddle/fluid/imperative/amp_utils.h @@ -58,7 +58,7 @@ static inline phi::DataType GetPromoteType( "float16") { if (op_name == "fused_attention") { for (size_t i = 0; i < amp_tensors_vector.size(); i++) { - if (i != 3 || i != 4 || i != 9 || i != 10) { + if (i < 3 || (i > 4 && i < 9) || i > 10) { if (GetDataType(amp_tensors_vector[i][0]) == phi::DataType::FLOAT32) { dst_type = phi::DataType::FLOAT32; return dst_type; @@ -67,7 +67,7 @@ static inline phi::DataType GetPromoteType( } } else if (op_name == "fused_feedforward") { for (size_t i = 0; i < amp_tensors_vector.size(); i++) { - if (i != 7 || i != 8 || i != 9 || i != 10) { + if (i < 7 || i > 10) { if (GetDataType(amp_tensors_vector[i][0]) == phi::DataType::FLOAT32) { dst_type = phi::DataType::FLOAT32; return dst_type; diff --git a/paddle/fluid/imperative/bkcl_context.cc b/paddle/fluid/imperative/bkcl_context.cc index 7d6dace21cca2..328cd2bceeffd 100644 --- a/paddle/fluid/imperative/bkcl_context.cc +++ b/paddle/fluid/imperative/bkcl_context.cc @@ -27,8 +27,8 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/split.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/split.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace imperative { diff --git a/paddle/fluid/imperative/gloo_context.cc b/paddle/fluid/imperative/gloo_context.cc index 4e0df45e840f2..00e0fdb1b4ee7 100644 --- a/paddle/fluid/imperative/gloo_context.cc +++ b/paddle/fluid/imperative/gloo_context.cc @@ -19,8 +19,8 @@ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" -#include 
"paddle/fluid/string/split.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/split.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/imperative/gradient_accumulator.cc b/paddle/fluid/imperative/gradient_accumulator.cc index 8f4dfbbcdc977..d9c91a4c6b0a0 100644 --- a/paddle/fluid/imperative/gradient_accumulator.cc +++ b/paddle/fluid/imperative/gradient_accumulator.cc @@ -518,7 +518,7 @@ void VariableWrapperAdd(std::shared_ptr var, static platform::Place GetPlaceOfVar( const std::shared_ptr& var) { platform::Place place; - if (var->Var().IsType()) { + if (var->Var().IsType()) { // NOLINT place = var->Var().Get().place(); } else if (var->Var().IsType()) { place = var->Var().Get().place(); @@ -735,7 +735,7 @@ void SortedGradientAccumulator::SumGrad(std::shared_ptr var, } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (paddle::platform::is_gpu_place(place)) { + if (paddle::platform::is_gpu_place(place)) { // NOLINT // sum selected rows firstly for (auto& var_info : tmp_grad_vars_) { if (!var_info.var->Var().IsType()) { diff --git a/paddle/fluid/imperative/heter_ccl_context.cc b/paddle/fluid/imperative/heter_ccl_context.cc index 3f7f39c3f9002..37929dc6e9c8f 100644 --- a/paddle/fluid/imperative/heter_ccl_context.cc +++ b/paddle/fluid/imperative/heter_ccl_context.cc @@ -24,8 +24,8 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/split.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/split.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/imperative/jit/CMakeLists.txt b/paddle/fluid/imperative/jit/CMakeLists.txt deleted file mode 100644 index bcc1c0746b823..0000000000000 --- a/paddle/fluid/imperative/jit/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -cc_library( - op_desc_meta - SRCS op_desc_meta.cc - DEPS proto_desc layer) -cc_library( - program_desc_tracer - SRCS program_desc_tracer.cc - DEPS op_desc_meta) diff --git a/paddle/fluid/imperative/jit/op_desc_meta.cc b/paddle/fluid/imperative/jit/op_desc_meta.cc deleted file mode 100644 index 1488f999bca9b..0000000000000 --- a/paddle/fluid/imperative/jit/op_desc_meta.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/imperative/jit/op_desc_meta.h" - -namespace paddle { -namespace imperative { -namespace jit { - -OpDescMeta::OpDescMeta(const std::string &type, - const NameVarBaseMap &inputs, - const NameVarBaseMap &outputs, - const framework::AttributeMap &attrs) - : type_(type), attrs_(attrs) { - auto *proto = framework::OpInfoMap::Instance().GetNullable(type_); - if (proto && proto->Checker()) { - proto->Checker()->Check(&attrs_); - } - - for (auto &pair : inputs) { - inputs_[pair.first].assign(pair.second.begin(), pair.second.end()); - } - - for (auto &pair : outputs) { - outputs_[pair.first].assign(pair.second.begin(), pair.second.end()); - } -} - -const std::string &OpDescMeta::Type() const { return type_; } - -const WeakNameVarBaseMap &OpDescMeta::Inputs() const { return inputs_; } - -const WeakNameVarBaseMap &OpDescMeta::Outputs() const { return outputs_; } - -const framework::AttributeMap &OpDescMeta::Attrs() const { return attrs_; } - -} // namespace jit -} // namespace imperative -} // namespace paddle diff --git a/paddle/fluid/imperative/jit/op_desc_meta.h b/paddle/fluid/imperative/jit/op_desc_meta.h deleted file mode 100644 index c0463a628683b..0000000000000 --- a/paddle/fluid/imperative/jit/op_desc_meta.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/imperative/layer.h" -#include "paddle/fluid/imperative/type_defs.h" - -namespace paddle { -namespace imperative { -namespace jit { - -class OpDescMeta { - public: - OpDescMeta(const std::string &type, - const NameVarBaseMap &inputs, - const NameVarBaseMap &outputs, - const framework::AttributeMap &attrs); - - const std::string &Type() const; - - const WeakNameVarBaseMap &Inputs() const; - - const WeakNameVarBaseMap &Outputs() const; - - const framework::AttributeMap &Attrs() const; - - private: - std::string type_; - WeakNameVarBaseMap inputs_; - WeakNameVarBaseMap outputs_; - framework::AttributeMap attrs_; -}; - -} // namespace jit -} // namespace imperative -} // namespace paddle diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.cc b/paddle/fluid/imperative/jit/program_desc_tracer.cc deleted file mode 100644 index 86a38f3942aaa..0000000000000 --- a/paddle/fluid/imperative/jit/program_desc_tracer.cc +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/imperative/jit/program_desc_tracer.h" - -#include "paddle/fluid/framework/convert_utils.h" - -namespace paddle { -namespace imperative { -class VarBase; -} // namespace imperative -} // namespace paddle - -namespace paddle { -namespace imperative { -namespace jit { - -// A helper class to generate unique name for each non-persistable var -class UniqueBlockVarGenerator { - public: - UniqueBlockVarGenerator(const VarDescMetaMap &all_vars, - const VarBaseSet &non_exist_input_vars, - framework::BlockDesc *block); - - std::string NameOf(const std::weak_ptr &var, - const std::string &prefix); - - private: - void InsertNewVarInBlock(const std::weak_ptr &var, - const framework::VarDesc &ref_desc, - const std::string &name, - bool force_persistable = false); - - private: - const VarDescMetaMap &all_vars_; - framework::BlockDesc *block_; - std::unordered_map counter_; - - std::map, - std::string, - std::owner_less>> - var_to_name_; - std::unordered_set existing_names_; -}; - -UniqueBlockVarGenerator::UniqueBlockVarGenerator( - const VarDescMetaMap &all_vars, - const VarBaseSet &non_exist_input_vars, - framework::BlockDesc *block) - : all_vars_(all_vars), block_(block) { - for (auto &var_pair : all_vars_) { - auto *var_desc = var_pair.second.get(); - if (var_desc->Persistable()) { - InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name()); - } else if (non_exist_input_vars.count(var_pair.first.lock()) > 0) { - VLOG(10) << "Mark " << var_desc->Name() << " as persistable"; - InsertNewVarInBlock(var_pair.first, - *var_desc, - var_desc->Name(), - /*force_persistable=*/true); - } - } -} - -std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr &var, - const std::string &prefix) { - VLOG(3) << "Finding: " << var.lock()->Name(); - auto all_vars_iter = all_vars_.find(var); - PADDLE_ENFORCE_EQ(all_vars_iter != all_vars_.end(), - true, - platform::errors::NotFound( - "Variable is not found in UniqueBlockVarGenerator")); - - auto iter = var_to_name_.find(var); - if (iter != var_to_name_.end()) { - VLOG(5) << "Return existing var name " << iter->second; - return iter->second; - } else { - auto generate_unique_name = [this, &prefix] { - auto &cnt = counter_[prefix]; - do { - auto name = prefix + std::to_string(cnt++); - if (existing_names_.count(name) == 0) { - return name; - } - } while (cnt > 0); - PADDLE_THROW( - platform::errors::OutOfRange("Too many vars in the program")); - }; - - auto unique_name = generate_unique_name(); - VLOG(5) << "Generate new var name " << unique_name; - InsertNewVarInBlock(var, *(all_vars_iter->second), unique_name); - return unique_name; - } -} - -void UniqueBlockVarGenerator::InsertNewVarInBlock( - const std::weak_ptr &var, - const framework::VarDesc &var_desc, - const std::string &name, - bool force_persistable) { - var_to_name_[var] = name; - existing_names_.insert(name); - auto *new_var_desc = block_->Var(name); - *new_var_desc = var_desc; - new_var_desc->SetName(name); - if (force_persistable) { - new_var_desc->SetPersistable(true); - } -} - -bool ProgramDescTracer::ContainVar(const std::weak_ptr &var) const { - auto vars_iter = vars_.find(var); - bool ret = (vars_iter != vars_.end()); - if (!ret) { - VLOG(5) << "Can't found variable: " << var.lock()->Name(); - } - return ret; -} - -void ProgramDescTracer::InsertOp(const std::string &type, - const NameVarBaseMap &inputs, - const NameVarBaseMap &outputs, - const 
framework::AttributeMap &attrs) { - ops_.emplace_back(new OpDescMeta(type, inputs, outputs, attrs)); - auto &new_op = ops_.back(); - for (auto &pair : new_op->Inputs()) { - for (auto &var : pair.second) { - InsertVarIfNotExist(var.lock(), true); - } - } - - for (auto &pair : new_op->Outputs()) { - for (auto &var : pair.second) { - InsertVarIfNotExist(var.lock(), false); - } - } -} - -void ProgramDescTracer::InsertOp(const std::string &type, - const NameTensorMap &inputs, - const NameTensorMap &outputs, - const framework::AttributeMap &attrs) { - // TODO(jiabin): Support this later. -} - -TracedProgramTuple ProgramDescTracer::CreateProgramDesc( - const std::vector> &feed_vars, - const std::string &feed_prefix, - const std::vector> &fetch_vars, - const std::string &fetch_prefix, - const std::string &tmp_prefix) const { - std::unique_ptr prog(new framework::ProgramDesc()); - auto *block = prog->MutableBlock(0); - - auto non_exist_vars_copy = non_exist_input_vars_; - for (auto &feed_var : feed_vars) { - non_exist_vars_copy.erase(feed_var); - } - - UniqueBlockVarGenerator generator(vars_, non_exist_vars_copy, block); - - std::vector feed_var_names; - for (auto &feed_var : feed_vars) { - if (ContainVar(feed_var)) { - feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix)); - } - } - - std::vector fetch_var_names; - for (auto &fetch_var : fetch_vars) { - if (ContainVar(fetch_var)) { - fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix)); - } - } - - for (auto &op : ops_) { - auto *op_desc = block->AppendOp(); - op_desc->SetType(op->Type()); - op_desc->SetAttrMap(op->Attrs()); - - for (auto &pair : op->Inputs()) { - std::vector names; - names.reserve(pair.second.size()); - for (auto &var : pair.second) { - if (ContainVar(var)) { - names.emplace_back(generator.NameOf(var, tmp_prefix)); - } - } - - op_desc->SetInput(pair.first, names); - } - - for (auto &pair : op->Outputs()) { - std::vector names; - names.reserve(pair.second.size()); - for (auto &var : pair.second) { - if (ContainVar(var)) { - names.emplace_back(generator.NameOf(var, tmp_prefix)); - } - } - - op_desc->SetOutput(pair.first, names); - } - } - - prog->Flush(); - - std::vector> persistable_vars( - non_exist_vars_copy.begin(), non_exist_vars_copy.end()); - for (auto &pair : vars_) { - if (pair.second->Persistable()) { - auto var = pair.first.lock(); - PADDLE_ENFORCE_NOT_NULL( - var, - platform::errors::NotFound("Persistable var %s does not exist", - pair.second->Name())); - persistable_vars.emplace_back(var); - } - } - return std::make_tuple(std::move(prog), - std::move(feed_var_names), - std::move(fetch_var_names), - std::move(persistable_vars)); -} - -void ProgramDescTracer::InsertVarIfNotExist( - const std::shared_ptr &new_var, bool is_input) { - PADDLE_ENFORCE_NOT_NULL( - new_var, - platform::errors::InvalidArgument("The variable to insert is NULL.")); - if (vars_.count(new_var) != 0) return; - - auto new_var_desc = new framework::VarDesc(""); - vars_[new_var].reset(new_var_desc); - - if (new_var->Persistable() || is_input) { - new_var_desc->SetName(new_var->Name()); - new_var_desc->SetPersistable(new_var->Persistable()); - if (!new_var->Persistable()) { - non_exist_input_vars_.insert(new_var); - } - } else { - new_var_desc->SetPersistable(false); - } - - const auto &inner_var = new_var->Var(); - PADDLE_ENFORCE_EQ(inner_var.IsInitialized(), - true, - platform::errors::InvalidArgument( - "The variable to insert is not initialized.")); - if (inner_var.IsType()) { - const auto &tensor = inner_var.Get(); 
- new_var_desc->SetType(framework::proto::VarType::LOD_TENSOR); - new_var_desc->SetShape(common::vectorize(tensor.dims())); - new_var_desc->SetLoDLevel(static_cast(tensor.lod().size())); - if (tensor.IsInitialized()) { - new_var_desc->SetDataType(framework::TransToProtoVarType(tensor.dtype())); - } else { - new_var_desc->SetDataType(framework::proto::VarType::FP32); - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "Not support variable type %s.", - framework::ToTypeName(inner_var.Type()))); - } -} - -void ProgramDescTracer::Reset() { - ops_.clear(); - vars_.clear(); - non_exist_input_vars_.clear(); -} - -} // namespace jit -} // namespace imperative -} // namespace paddle diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.h b/paddle/fluid/imperative/jit/program_desc_tracer.h deleted file mode 100644 index 24550bcf90041..0000000000000 --- a/paddle/fluid/imperative/jit/program_desc_tracer.h +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/imperative/jit/op_desc_meta.h" -#include "paddle/fluid/imperative/layer.h" -#include "paddle/fluid/imperative/type_defs.h" -#include "paddle/fluid/platform/macros.h" - -namespace paddle { -namespace imperative { -class VarBase; -} // namespace imperative -} // namespace paddle - -namespace paddle { -namespace imperative { -namespace jit { - -using VarDescMetaMap = std::map, - std::unique_ptr, - std::owner_less>>; - -using VarBaseSet = std::set, - std::owner_less>>; - -using TracedProgramTuple = - std::tuple /*program*/, - std::vector /*feed_var_names*/, - std::vector /*fetch_var_names*/, - std::vector> /*persistable_vars*/>; - -class ProgramDescTracer { - DISABLE_COPY_AND_ASSIGN(ProgramDescTracer); - - public: - ProgramDescTracer() = default; - - void InsertOp(const std::string &type, - const NameVarBaseMap &inputs, - const NameVarBaseMap &outputs, - const framework::AttributeMap &attrs); - - void InsertOp(const std::string &type, - const NameTensorMap &inputs, - const NameTensorMap &outputs, - const framework::AttributeMap &attrs); - - TracedProgramTuple CreateProgramDesc( - const std::vector> &feed_vars, - const std::string &feed_prefix, - const std::vector> &fetch_vars, - const std::string &fetch_prefix, - const std::string &tmp_prefix) const; - bool ContainVar(const std::weak_ptr &var) const; - void Reset(); - - private: - void InsertVarIfNotExist(const std::shared_ptr &new_var, - bool is_input); - - private: - std::vector> ops_; - VarDescMetaMap vars_; - VarBaseSet non_exist_input_vars_; -}; - -} // namespace jit -} // namespace imperative -} // namespace paddle diff --git a/paddle/fluid/imperative/layout_autotune.cc b/paddle/fluid/imperative/layout_autotune.cc index 006021488aa57..7836572b0c426 100644 --- a/paddle/fluid/imperative/layout_autotune.cc +++ 
b/paddle/fluid/imperative/layout_autotune.cc @@ -145,7 +145,7 @@ LayoutAutotuneGuard::LayoutAutotuneGuard(std::shared_ptr tracer, } LayoutAutotuneGuard::~LayoutAutotuneGuard() { - if (pre_layout_autotune_) { + if (pre_layout_autotune_) { // NOLINT tracer_->EnableLayoutAutoTune(); } else { tracer_->DisableLayoutAutoTune(); diff --git a/paddle/fluid/imperative/nccl_context.cc b/paddle/fluid/imperative/nccl_context.cc index d70d40808f915..3ed9b97bfc362 100644 --- a/paddle/fluid/imperative/nccl_context.cc +++ b/paddle/fluid/imperative/nccl_context.cc @@ -67,7 +67,7 @@ void NCCLParallelContext::Init() { std::vector nccl_ids; nccl_ids.resize(strategy_.nrings_); - if (strategy_.local_rank_ == 0) { + if (strategy_.local_rank_ == 0) { // NOLINT // generate the unique ncclid on the root worker for (auto &nccl_id : nccl_ids) { platform::dynload::ncclGetUniqueId(&nccl_id); diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 0a5d44a1e1e57..5ae9e43752491 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -33,8 +33,8 @@ #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/profiler/event_tracing.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/utils/string/string_helper.h" COMMON_DECLARE_bool(sort_sum_gradient); @@ -366,7 +366,7 @@ class GradientAccumulationInfo { if (!grad_var_) { grad_var_ = std::make_shared(true, mapped_grad_var_->Name()); grad_var_->SetOverriddenStopGradient(false); - if (sort_gradient_) { + if (sort_gradient_) { // NOLINT accumulator_ = std::make_unique( grad_var_->SharedVar().get()); } else { diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 8129ea244f489..a60c81a4c22d9 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -660,7 +660,7 @@ void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { - if (run_phi_kernel_) { + if (run_phi_kernel_) { // NOLINT PreparedOpRunPtImpl(op_, kernel_key_, arg_map_fn_, @@ -692,7 +692,7 @@ void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { - if (run_phi_kernel_) { + if (run_phi_kernel_) { // NOLINT PreparedOpRunPtImpl(op_, kernel_key_, arg_map_fn_, @@ -724,7 +724,7 @@ void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, const framework::AttributeMap& default_attrs) { - if (run_phi_kernel_) { + if (run_phi_kernel_) { // NOLINT PreparedOpRunPtImpl(op_, kernel_key_, arg_map_fn_, diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 70c36b27d31c0..4a0d417595b8f 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -405,31 +405,31 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, switch (AttrTypeID(attr)) { case framework::proto::AttrType::FLOAT: kernel_ctx->EmplaceBackAttr( - std::move(phi::Scalar(PADDLE_GET_CONST(float, attr)))); + phi::Scalar(PADDLE_GET_CONST(float, attr))); break; case framework::proto::AttrType::FLOAT64: kernel_ctx->EmplaceBackAttr( - 
std::move(phi::Scalar(PADDLE_GET_CONST(double, attr)))); + phi::Scalar(PADDLE_GET_CONST(double, attr))); break; case framework::proto::AttrType::INT: kernel_ctx->EmplaceBackAttr( - std::move(phi::Scalar(PADDLE_GET_CONST(int, attr)))); + phi::Scalar(PADDLE_GET_CONST(int, attr))); break; case framework::proto::AttrType::LONG: kernel_ctx->EmplaceBackAttr( - std::move(phi::Scalar(PADDLE_GET_CONST(int64_t, attr)))); + phi::Scalar(PADDLE_GET_CONST(int64_t, attr))); break; case framework::proto::AttrType::STRING: kernel_ctx->EmplaceBackAttr( - std::move(phi::Scalar(PADDLE_GET_CONST(std::string, attr)))); + phi::Scalar(PADDLE_GET_CONST(std::string, attr))); break; case framework::proto::AttrType::BOOLEAN: kernel_ctx->EmplaceBackAttr( - std::move(phi::Scalar(PADDLE_GET_CONST(bool, attr)))); + phi::Scalar(PADDLE_GET_CONST(bool, attr))); break; case framework::proto::AttrType::SCALAR: - kernel_ctx->EmplaceBackAttr(std::move(phi::Scalar( - PADDLE_GET_CONST(paddle::experimental::Scalar, attr)))); + kernel_ctx->EmplaceBackAttr(phi::Scalar( + PADDLE_GET_CONST(paddle::experimental::Scalar, attr))); break; default: PADDLE_THROW(platform::errors::Unimplemented( @@ -448,20 +448,20 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, auto& attr = *attr_ptr; switch (AttrTypeID(attr)) { case framework::proto::AttrType::INTS: - kernel_ctx->EmplaceBackAttr(std::move( - phi::IntArray(PADDLE_GET_CONST(std::vector, attr)))); + kernel_ctx->EmplaceBackAttr( + phi::IntArray(PADDLE_GET_CONST(std::vector, attr))); break; case framework::proto::AttrType::LONGS: - kernel_ctx->EmplaceBackAttr(std::move( - phi::IntArray(PADDLE_GET_CONST(std::vector, attr)))); + kernel_ctx->EmplaceBackAttr( + phi::IntArray(PADDLE_GET_CONST(std::vector, attr))); break; case framework::proto::AttrType::INT: - kernel_ctx->EmplaceBackAttr(std::move( - phi::IntArray(&PADDLE_GET_CONST(int32_t, attr), 1))); + kernel_ctx->EmplaceBackAttr( + phi::IntArray(&PADDLE_GET_CONST(int32_t, attr), 1)); break; case framework::proto::AttrType::LONG: - kernel_ctx->EmplaceBackAttr(std::move( - phi::IntArray(&PADDLE_GET_CONST(int64_t, attr), 1))); + kernel_ctx->EmplaceBackAttr( + phi::IntArray(&PADDLE_GET_CONST(int64_t, attr), 1)); break; default: PADDLE_THROW(platform::errors::Unimplemented( @@ -481,7 +481,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, variables.push_back(var_base->MutableVar()); } kernel_ctx->EmplaceBackAttr( - std::move(framework::MakePhiIntArrayFromVarList(variables))); + framework::MakePhiIntArrayFromVarList(variables)); } } break; @@ -559,7 +559,7 @@ void BuildDygraphPhiKernelContext(const phi::KernelSignature& kernel_signature, PADDLE_ENFORCE_NOT_NULL( attr_ptr, platform::errors::NotFound("(%s) is not found in AttributeMap when " - "buildind dygraph KernelContext.", + "building dygraph KernelContext.", attr_names[i])); auto& attr = *attr_ptr; switch (attr_defs[i].type_index) { diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc index 461c2d3ff4bb8..526935a5182be 100644 --- a/paddle/fluid/imperative/reducer.cc +++ b/paddle/fluid/imperative/reducer.cc @@ -24,8 +24,8 @@ #ifdef PADDLE_WITH_XPU #include "paddle/fluid/platform/device/xpu/enforce_xpu.h" #endif -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace imperative { @@ -227,7 +227,7 @@ void SplitTensorsWithType( void Group::ConcatTensors(const platform::DeviceContext &context) { 
auto place = context.GetPlace(); - if (platform::is_gpu_place(place)) { + if (platform::is_gpu_place(place)) { // NOLINT #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) ConcatTensorsWithType(static_cast(context), dense_tensors_, @@ -263,7 +263,7 @@ void Group::ConcatTensors(const platform::DeviceContext &context) { void Group::SplitTensors(const platform::DeviceContext &context) { auto place = context.GetPlace(); - if (platform::is_gpu_place(place)) { + if (platform::is_gpu_place(place)) { // NOLINT #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) SplitTensorsWithType(static_cast(context), &dense_contents_, @@ -493,8 +493,10 @@ void Reducer::PrepareDeps(const std::unordered_set &init_nodes) { "using PyLayer in a DataParallel model, you can skip gradient " "synchronization among multiple cards by 'no_sync', and " "manually implement 'all_reduce' before model optimization. " - "There is an example showing specific implemetation processing " - "in offical docs: https://www.paddlepaddle.org.cn/documentation" + "There is an example showing specific implementation " + "processing " + "in official docs: " + "https://www.paddlepaddle.org.cn/documentation" "/docs/api/paddle/DataParallel_cn.html")); } ++node_deps_[grad_pending_node.get()]; diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 48b51265421c5..3eff589fee703 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -30,10 +30,10 @@ #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler/event_tracing.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/utils/string/string_helper.h" COMMON_DECLARE_bool(use_mkldnn); COMMON_DECLARE_string(tracer_mkldnn_ops_on); @@ -44,8 +44,6 @@ namespace paddle { namespace imperative { thread_local std::string Tracer::python_stack_ = ""; -thread_local bool Tracer::enable_program_desc_tracing_ = false; - thread_local bool Tracer::has_grad_ = true; thread_local bool Tracer::use_layout_autotune_ = false; @@ -367,11 +365,6 @@ void Tracer::TraceOpImpl(const std::string& type, "Operator %s raises an unknown exception.", type)); } - if (enable_program_desc_tracing_) { - VLOG(5) << "Trace op " << type << " into ProgramDesc"; - program_desc_tracer_->InsertOp(type, new_ins, outs, attrs); - } - { platform::RecordEvent node_creation_record_event( "grad_node_creation", platform::TracerEventType::OperatorInner, 1); @@ -594,14 +587,6 @@ bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins, return false; } -void Tracer::SetEnableProgramDescTracing(bool enabled) { - enable_program_desc_tracing_ = enabled; -} - -bool Tracer::IsProgramDescTracingEnabled() const { - return enable_program_desc_tracing_; -} - void Tracer::SetAmpDtype(std::string amp_dtype) { VLOG(4) << "set amp_dtype to " << amp_dtype; g_current_amp_attrs->SetAmpDtype(amp_dtype); @@ -660,8 +645,8 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature( if (phi::KernelFactory::Instance().HasStructuredKernel(type)) { return phi::KernelSignature(op->Type().c_str()); } else { - return phi::KernelSignature(std::move( - opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx))); + return phi::KernelSignature( + opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx)); } } diff --git a/paddle/fluid/imperative/tracer.h 
b/paddle/fluid/imperative/tracer.h index b6f61c36f670b..ed82b5e52a737 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -26,7 +26,6 @@ #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/imperative/amp_auto_cast.h" #include "paddle/fluid/imperative/basic_engine.h" -#include "paddle/fluid/imperative/jit/program_desc_tracer.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/imperative/layout_autotune.h" #include "paddle/fluid/platform/macros.h" @@ -63,7 +62,6 @@ class Tracer { public: Tracer() : basic_engine_(new BasicEngine()), - program_desc_tracer_(new jit::ProgramDescTracer()), generator_(new UniqueNameGenerator()) { expected_place_ = platform::CPUPlace(); } @@ -126,14 +124,6 @@ class Tracer { const NameTensorMap& outs, bool trace_backward); - void SetEnableProgramDescTracing(bool enabled); - - bool IsProgramDescTracingEnabled() const; - - jit::ProgramDescTracer* GetProgramDescTracer() { - return program_desc_tracer_.get(); - } - // Note(Aurelius84): The `tmp` is used as prefix key while naming a temporary // intermediate var both in imperative and static graph mode. But the // `UniqueNameGenerator` in C++ and `unique_name.py` in Python doesn't share @@ -187,7 +177,6 @@ class Tracer { private: std::unique_ptr basic_engine_; - std::unique_ptr program_desc_tracer_; std::unique_ptr generator_; platform::Place expected_place_; GarbageCollectorMap gcs_; diff --git a/paddle/fluid/imperative/type_defs.h b/paddle/fluid/imperative/type_defs.h index 08f3c8d4a0fc2..5913ea7aad07f 100644 --- a/paddle/fluid/imperative/type_defs.h +++ b/paddle/fluid/imperative/type_defs.h @@ -32,9 +32,6 @@ class OpBase; class GradOpNode; class Tracer; -using WeakNameVarBaseMap = - std::map>>; - namespace details { template struct NameVarMapTrait {}; diff --git a/paddle/fluid/imperative/var_helper.cc b/paddle/fluid/imperative/var_helper.cc index bafea5a720d3a..9561962935ffe 100644 --- a/paddle/fluid/imperative/var_helper.cc +++ b/paddle/fluid/imperative/var_helper.cc @@ -50,7 +50,8 @@ void InitializeVariable(paddle::framework::Variable *var, var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::FEED_MINIBATCH) { var->GetMutable(); - } else if (var_type == paddle::framework::proto::VarType::FETCH_LIST) { + } else if (var_type == + paddle::framework::proto::VarType::FETCH_LIST) { // NOLINT var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::STEP_SCOPES) { var->GetMutable>(); diff --git a/paddle/fluid/imperative/var_helper.h b/paddle/fluid/imperative/var_helper.h index ebf3e49c51870..1a74d987e7e2b 100644 --- a/paddle/fluid/imperative/var_helper.h +++ b/paddle/fluid/imperative/var_helper.h @@ -40,7 +40,7 @@ void InitializeVariable(paddle::framework::Variable* var, template const paddle::platform::Place& GetPlace(const std::shared_ptr& var); template -const std::string& GetNameFromVar(std::shared_ptr var); +TEST_API const std::string& GetNameFromVar(std::shared_ptr var); template bool CheckCachedKey(std::shared_ptr tensor, const phi::KernelKey& key); diff --git a/paddle/fluid/imperative/xccl_context.cc b/paddle/fluid/imperative/xccl_context.cc index 1ed821d09c346..1eca9f9361419 100644 --- a/paddle/fluid/imperative/xccl_context.cc +++ b/paddle/fluid/imperative/xccl_context.cc @@ -50,13 +50,12 @@ static void XcclAllReduce(const phi::DenseTensor &src, auto *dst_ptr = phi::DeviceContextPool::Instance() .Get(src.place()) ->Alloc(dst, src.dtype()); - auto xccl_dtype = 
phi::ccl::ToCCLDataType(src.dtype()); phi::DeviceManager::CCLAllReduce(place.GetDeviceType(), src_ptr, dst_ptr, src.numel(), - xccl_dtype, + src.dtype(), phi::ccl::CCLReduceOp::SUM, comm, stream); @@ -201,12 +200,11 @@ void XCCLParallelContext::Broadcast(framework::Variable *src, int ring_id) { auto stream = comm->stream(); void *src_ptr = src_tensor->data(); - auto xccl_dtype = phi::ccl::ToCCLDataType(src_tensor->dtype()); phi::DeviceManager::CCLBroadcast(place_.GetDeviceType(), src_ptr, src_tensor->numel(), - xccl_dtype, + src_tensor->dtype(), 0, comm->comm(), *stream); diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 88003c6db6ba6..bed777851641a 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -93,7 +93,7 @@ set(SHARED_INFERENCE_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/api/details/zero_copy_tensor.cc ${CMAKE_CURRENT_SOURCE_DIR}/utils/io_utils.cc) -# NOTE(Aurelius84): For inference library, some DEPS is usless +# NOTE(Aurelius84): For inference library, some DEPS is useless # such as non-infer operator related targets et.al. list(REMOVE_ITEM fluid_modules cinn_op_dialect) # NOTE(Aurelisu84): Remove pir dialect related target DEPS for inference diff --git a/paddle/fluid/inference/analysis/analyzer.cc b/paddle/fluid/inference/analysis/analyzer.cc index 5e4c17fef1e65..9c6b7be94b906 100644 --- a/paddle/fluid/inference/analysis/analyzer.cc +++ b/paddle/fluid/inference/analysis/analyzer.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/inference/analysis/passes/passes.h" -#include "paddle/fluid/string/pretty_log.h" +#include "paddle/utils/string/pretty_log.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h index a87c919bbe2c1..aeaa305191974 100644 --- a/paddle/fluid/inference/analysis/argument.h +++ b/paddle/fluid/inference/analysis/argument.h @@ -227,6 +227,7 @@ struct Argument { DECL_ARGUMENT_FIELD(use_cutlass, UseCutlass, bool); DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool); DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int); + DECL_ARGUMENT_FIELD(use_pir, UsePIR, bool); // Usually use for trt dynamic shape. // TRT will select the best kernel according to opt shape @@ -250,9 +251,20 @@ struct Argument { DECL_ARGUMENT_FIELD(trt_exclude_var_names, TRTExcludeVarNames, std::vector); + DECL_ARGUMENT_FIELD(trt_forbid_dynamic_op, TRTForbidDynamicOp, bool); + DECL_ARGUMENT_FIELD(tensorrt_disabled_ops, TensorRtDisabledOPs, std::vector); + DECL_ARGUMENT_FIELD(trt_parameter_run_fp16, + TRTParameterRunFp16, + std::vector); + DECL_ARGUMENT_FIELD(trt_parameter_run_int8, + TRTParameterRunInt8, + std::vector); + DECL_ARGUMENT_FIELD(trt_parameter_run_bfp16, + TRTParameterRunBfp16, + std::vector); DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode, int); DECL_ARGUMENT_FIELD(tensorrt_use_static_engine, TensorRtUseStaticEngine, diff --git a/paddle/fluid/inference/analysis/helper.h b/paddle/fluid/inference/analysis/helper.h index e891da8e6d19f..949f3a03f9c41 100644 --- a/paddle/fluid/inference/analysis/helper.h +++ b/paddle/fluid/inference/analysis/helper.h @@ -29,7 +29,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" #ifdef _WIN32 #include diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index eca0c8fedd0a2..77052155efaa6 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -27,8 +27,8 @@ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/inference/analysis/argument.h" -#include "paddle/fluid/string/pretty_log.h" #include "paddle/phi/common/data_type.h" +#include "paddle/utils/string/pretty_log.h" namespace paddle { namespace inference { @@ -173,6 +173,18 @@ void IRPassManager::CreatePasses(Argument *argument, pass->Set( "trt_exclude_var_names", new std::vector(argument->trt_exclude_var_names())); + pass->Set( + "trt_parameter_run_fp16", + new std::vector(argument->trt_parameter_run_fp16())); + pass->Set( + "trt_parameter_run_int8", + new std::vector(argument->trt_parameter_run_int8())); + pass->Set( + "trt_parameter_run_bfp16", + new std::vector(argument->trt_parameter_run_bfp16())); + pass->Set("forbid_dynamic_op", + new bool(argument->trt_forbid_dynamic_op())); + pass->Set("program", new framework::ProgramDesc *(&argument->main_program())); pass->Set("predictor_id", new int(argument->predictor_id())); diff --git a/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc index 5e132cc4b6303..77d4e4d045aed 100644 --- a/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.cc @@ -24,7 +24,7 @@ #include "paddle/fluid/framework/type_defs.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/ir_passes/dlnne_subgraph_pass.h" -#include "paddle/fluid/string/pretty_log.h" +#include "paddle/utils/string/pretty_log.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.cc index 2d484a943cf20..619625cf5794a 100644 --- a/paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/lite_subgraph_pass.cc @@ -31,7 +31,7 @@ #include "paddle/fluid/inference/lite/engine.h" #include "paddle/fluid/inference/lite/op_teller.h" #include "paddle/fluid/inference/utils/singleton.h" -#include "paddle/fluid/string/pretty_log.h" +#include "paddle/utils/string/pretty_log.h" namespace paddle { namespace inference { @@ -71,7 +71,7 @@ std::vector IOVarsFilter(const std::vector& nodes) { void StrToBinaryFile(const std::string& path, const std::string& str) { std::ofstream file(path.c_str(), std::ios::binary); - file.write(str.c_str(), str.size()); + file.write(str.c_str(), str.size()); // NOLINT file.close(); } @@ -271,7 +271,7 @@ void LiteSubgraphPass::SetUpEngine( Get>("nnadapter_model_cache_token"); lite_api::TargetType target_type = TARGET(kX86); - if (use_gpu) { + if (use_gpu) { // NOLINT target_type = TARGET(kCUDA); } else if (use_xpu) { target_type = TARGET(kXPU); @@ -417,13 +417,11 @@ void LiteSubgraphPass::ApplyImpl(framework::ir::Graph* graph) const { auto& lite_ops_filter = Get>("lite_ops_filter"); auto teller = [&lite_ops_filter](const Node* 
node) { - if (!node->IsOp() || !node->Op()) - return false; - else if (node->Op()->Type() == "feed" || node->Op()->Type() == "fetch") - return false; - else if (std::find(lite_ops_filter.begin(), - lite_ops_filter.end(), - node->Op()->Type()) != lite_ops_filter.end()) + if (!node->IsOp() || !node->Op() || node->Op()->Type() == "feed" || + node->Op()->Type() == "fetch" || + std::find(lite_ops_filter.begin(), + lite_ops_filter.end(), + node->Op()->Type()) != lite_ops_filter.end()) return false; return inference::lite::OpTeller::Global().Tell(node->Op()->Type(), *node->Op()); diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 69b27b1214839..db185b15c03d9 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -14,7 +14,6 @@ // limitations under the License. #include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" - #include #include #include @@ -153,12 +152,14 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( auto trt_disabled_ops = Get>("trt_disabled_ops"); auto with_dynamic_shape = Get("with_dynamic_shape"); auto use_explicit_quantization = Get("use_explicit_quantization"); + auto forbid_dynamic_op = Get("forbid_dynamic_op"); auto teller = [&](const framework::ir::Node *node) { if (!node->IsOp() || !node->Op()) return false; if (find(trt_disabled_ops.begin(), trt_disabled_ops.end(), node->Op()->Type()) != trt_disabled_ops.end()) { VLOG(3) << node->Op()->Type().c_str() + << " is diabled by config in TensorRT"; return false; } @@ -172,8 +173,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl( } } } - bool is_ok = tensorrt::OpTeller::Global().Tell( - node, no_calib_int8, with_dynamic_shape, use_explicit_quantization); + bool is_ok = tensorrt::OpTeller::Global().Tell(node, + no_calib_int8, + with_dynamic_shape, + forbid_dynamic_op, + use_explicit_quantization); if (!is_ok) VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT"; return is_ok; @@ -471,9 +475,47 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( } auto precision_mode = static_cast(Get("trt_precision_mode")); + auto trt_params_run_fp16 = + Get>("trt_parameter_run_fp16"); + auto trt_params_run_int8 = + Get>("trt_parameter_run_int8"); + auto trt_params_run_bfp16 = + Get>("trt_parameter_run_bfp16"); + + for (const auto ¶ : parameters) { + if (std::find(trt_params_run_fp16.begin(), + trt_params_run_fp16.end(), + para) != trt_params_run_fp16.end()) { + precision_mode = phi::DataType::FLOAT16; + break; + } + } + bool enable_fp16 = false; if (precision_mode == phi::DataType::FLOAT16) enable_fp16 = true; auto enable_int8 = Get("enable_int8"); + + for (const auto ¶ : parameters) { + if (std::find(trt_params_run_int8.begin(), + trt_params_run_int8.end(), + para) != trt_params_run_int8.end()) { + enable_int8 = true; + precision_mode = phi::DataType::INT8; + break; + } + } + + for (const auto ¶ : parameters) { + if (std::find(trt_params_run_bfp16.begin(), + trt_params_run_bfp16.end(), + para) != trt_params_run_bfp16.end()) { + precision_mode = phi::DataType::BFLOAT16; + break; + } + } + bool enable_bfp16 = false; + if (precision_mode == phi::DataType::BFLOAT16) enable_bfp16 = true; + auto use_calib_mode = Get("use_calib_mode"); auto &subgraph_nodes = *framework::ir::Agent(node).subgraph(); auto min_input_shape = @@ -506,8 +548,8 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( &max_shape_tensor, 
&optim_shape_tensor); } else { - shape_range_info_path = - Get("model_opt_cache_dir") + "shape_range_info.pbtxt"; + shape_range_info_path = Get("model_opt_cache_dir") + "/" + + "shape_range_info.pbtxt"; if (open(shape_range_info_path.c_str(), O_RDONLY) != -1) { VLOG(1) << "trt dynamic_shape deserialize from " << shape_range_info_path; @@ -719,6 +761,7 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( op_desc->SetAttr("calibration_data", calibration_data); op_desc->SetAttr("enable_int8", enable_int8); op_desc->SetAttr("enable_fp16", enable_fp16); + op_desc->SetAttr("enbale_bfp16", enable_bfp16); op_desc->SetAttr("use_calib_mode", use_calib_mode); op_desc->SetAttr("engine_key", engine_key); op_desc->SetAttr("calibration_engine_key", calibration_engine_key); @@ -754,7 +797,7 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp( bool calibration_mode = (enable_int8 && calibration_data.empty() && use_calib_mode); if (calibration_mode) { - // calibraion mode means generate int8 calibration table data process. + // calibration mode means generate int8 calibration table data process. return calibration_engine_key; } diff --git a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc index b422dea840af5..993ab2e8618f4 100644 --- a/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc +++ b/paddle/fluid/inference/analysis/passes/inference_op_replace_pass.cc @@ -16,14 +16,12 @@ #include "paddle/fluid/inference/analysis/argument.h" -COMMON_DECLARE_bool(enable_pir_in_executor); - namespace paddle { namespace inference { namespace analysis { void InferenceOpReplacePass::RunImpl(Argument* argument) { - if (FLAGS_enable_pir_in_executor) { + if (argument->use_pir()) { return; } diff --git a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc index 8106dfbb9e6aa..ea97be8f90a60 100644 --- a/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_graph_build_pass.cc @@ -121,7 +121,7 @@ std::unique_ptr IrGraphBuildPass::LoadModel( bool model_from_memory, bool skip_load_params) { framework::Executor exe(place); - if (!model_from_memory) { + if (!model_from_memory) { // NOLINT return Load(&exe, scope, program_path, params_path, !skip_load_params); } else { return LoadFromMemory(&exe, scope, program_path, params_path); diff --git a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc index 2961d5c66f9f4..2e722f9a7e6e9 100644 --- a/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc +++ b/paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc @@ -32,8 +32,6 @@ PD_DEFINE_bool( // NOLINT false, "Keep old mode for developers, the model is saved on cpu not device."); -COMMON_DECLARE_bool(enable_pir_in_executor); - namespace paddle { namespace inference { namespace analysis { @@ -208,9 +206,10 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToXpu(Argument *argument) { #endif void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) { - if (FLAGS_enable_pir_in_executor) { + if (argument->use_pir()) { return; } + PADDLE_ENFORCE_EQ( argument->scope_valid(), true, diff --git a/paddle/fluid/inference/analysis/passes/save_optimized_model_pass.cc b/paddle/fluid/inference/analysis/passes/save_optimized_model_pass.cc index cc463ce45f105..aaf9439d2b9ed 
100644 --- a/paddle/fluid/inference/analysis/passes/save_optimized_model_pass.cc +++ b/paddle/fluid/inference/analysis/passes/save_optimized_model_pass.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/scope.h" @@ -37,10 +38,18 @@ void SaveOptimizedModelPass::SaveOptimizedModel(Argument* argument) { framework::ir::GraphToProgram(*graph, &optimized_program_desc); - // Some vars may be deleted by pass, so we need to remove them in block + // TODO(minghaipeng): Move the following code to a separate clean pass. + // Remove the scale and zero point parameters from optimized program. + auto scale_and_zero_point_param = graph->GetOrInit>( + framework::ir::kScaleAndZeroPointParamAttr); framework::BlockDesc* block = optimized_program_desc.MutableBlock(0); for (auto& var_desc : block->AllVars()) { - if (var_desc->Persistable() && !scope.FindVar(var_desc->Name())) { + auto var_name = var_desc->Name(); + if (var_desc->Persistable() && scope.FindVar(var_name) && + std::count(scale_and_zero_point_param.begin(), + scale_and_zero_point_param.end(), + var_name) > 0) { + scope.EraseVars({var_name}); block->RemoveVar(var_desc->Name()); } } @@ -74,7 +83,7 @@ void SaveOptimizedModelPass::SaveOptimizedModel(Argument* argument) { } } - std::string save_params_path = path + ".pdiparams"; + std::string save_params_path = path + "/" + "_optimized.pdiparams"; std::vector save_var_list(save_var_set.begin(), save_var_set.end()); std::sort(save_var_list.begin(), save_var_list.end()); @@ -105,7 +114,7 @@ void SaveOptimizedModelPass::SaveOptimizedModel(Argument* argument) { } } } - std::string save_model_path = path + ".pdmodel"; + std::string save_model_path = path + "/" + "_optimized.pdmodel"; auto str = optimized_program_desc.Proto()->SerializeAsString(); std::ofstream file(save_model_path.c_str(), std::ios::binary); file.write(str.c_str(), str.size()); // NOLINT diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt index eda204189c8a6..65a4bea5b1240 100755 --- a/paddle/fluid/inference/api/CMakeLists.txt +++ b/paddle/fluid/inference/api/CMakeLists.txt @@ -33,7 +33,8 @@ set(paddle_inference_api_deps trainer_desc_proto custom_operator lod_tensor - scope) + scope + drr) if(WITH_CRYPTO) list(APPEND paddle_inference_api_deps framework_io) diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index 0ec5151a92bc5..efe7b83f7df16 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -15,6 +15,7 @@ #include #include #include +#include #include "glog/logging.h" #include "paddle/common/flags.h" @@ -181,6 +182,11 @@ void AnalysisConfig::EnableXpu(int l3_size, bool transformer_encoder_adaptive_seqlen, bool enable_multi_stream) { #if defined(PADDLE_WITH_XPU) || defined(LITE_SUBGRAPH_WITH_XPU) + LOG_FIRST_N(WARNING, 1) + << "Parameters in EnableXpu/enable_xpu is deprecated since version " + "2.6.1, and will be removed in version 3.0! 
Please use " + "EnableXpu/enable_xpu without parameters, and use " + "SetXpuConfig/set_xpu_config to set options."; use_xpu_ = true; xpu_config_.l3_size = l3_size; xpu_config_.conv_autotune_level = conv_autotune; @@ -462,6 +468,10 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(tensorrt_min_subgraph_size_); CP_MEMBER(tensorrt_precision_mode_); CP_MEMBER(trt_mark_output_); + CP_MEMBER(trt_parameters_run_fp16_); + CP_MEMBER(trt_parameters_run_int8_); + CP_MEMBER(trt_parameters_run_bfp16_); + CP_MEMBER(trt_forbid_dynamic_op_) CP_MEMBER(trt_output_tensor_names_); CP_MEMBER(trt_disabled_ops_); CP_MEMBER(trt_use_dla_); @@ -581,6 +591,11 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { CP_MEMBER(skip_load_params_); CP_MEMBER(use_new_executor_); + CP_MEMBER(use_pir_); + CP_MEMBER(custom_passes_); + CP_MEMBER(custom_pass_only_); + CP_MEMBER(pm_opt_level_); + CP_MEMBER(ir_debug_passes_); if (use_gpu_) { PADDLE_ENFORCE_EQ(use_xpu_, @@ -780,6 +795,11 @@ void AnalysisConfig::MarkTrtEngineOutputs( trt_output_tensor_names_ = output_tensor_names; } +void AnalysisConfig::Exp_DisableTensorRTDynamicShapeOPs( + bool trt_forbid_dynamic_op) { + trt_forbid_dynamic_op_ = trt_forbid_dynamic_op; +} + void AnalysisConfig::EnableTensorRTMemoryOptim(bool engine_memory_sharing, int sharing_identifier) { PADDLE_ENFORCE_EQ( @@ -873,6 +893,21 @@ void AnalysisConfig::Exp_DisableTensorRtSubgraph( var_name_not_trt.end()); } +void AnalysisConfig::Exp_SpecifyTensorRTSubgraphPrecision( + const std::vector &trt_parameters_run_fp16, + const std::vector &trt_parameters_run_int8, + const std::vector &trt_parameters_run_bfp16) { + trt_parameters_run_fp16_.insert(trt_parameters_run_fp16_.end(), + trt_parameters_run_fp16.begin(), + trt_parameters_run_fp16.end()); + trt_parameters_run_int8_.insert(trt_parameters_run_int8_.end(), + trt_parameters_run_int8.begin(), + trt_parameters_run_int8.end()); + trt_parameters_run_bfp16_.insert(trt_parameters_run_bfp16_.end(), + trt_parameters_run_bfp16.begin(), + trt_parameters_run_bfp16.end()); +} + void AnalysisConfig::EnableVarseqlen() { trt_use_varseqlen_ = true; } void AnalysisConfig::SetTensorRtOptimizationLevel(int level) { @@ -891,6 +926,11 @@ void AnalysisConfig::Update() { auto &&info = SerializeInfoCache(); if (info == serialized_info_cache_) return; + std::unordered_set deleted_passes; + if (pass_builder_) { + deleted_passes = pass_builder_->GetAllDeletedPasses(); + } + // Transfer pass_builder and copy the existing compatible passes. 
if (!pass_builder_ || ((use_gpu() ^ pass_builder_->use_gpu())) || ((use_xpu() ^ pass_builder_->use_xpu())) || @@ -1103,7 +1143,7 @@ void AnalysisConfig::Update() { "but did not have the option -DWITH_CUSTOM_DEVICE compiled.")); #endif } - for (auto &delete_pass : pass_builder()->GetAllDeletedPasses()) { + for (const auto &delete_pass : deleted_passes) { pass_builder_->DeletePass(delete_pass); } } @@ -1128,6 +1168,13 @@ std::string AnalysisConfig::SerializeInfoCache() { ss << tensorrt_max_batchsize_; ss << tensorrt_min_subgraph_size_; ss << trt_mark_output_; + for (auto &name : trt_parameters_run_fp16_) ss << name.c_str(); + ss << ";"; + for (auto &name : trt_parameters_run_int8_) ss << name.c_str(); + ss << ";"; + for (auto &name : trt_parameters_run_bfp16_) ss << name.c_str(); + ss << ";"; + ss << trt_forbid_dynamic_op_; ss << use_dlnne_; ss << dlnne_min_subgraph_size_; @@ -1232,11 +1279,13 @@ float AnalysisConfig::fraction_of_gpu_memory_for_pool() const { size_t gpu_total, gpu_available; platform::SetDeviceId(gpu_device_id_); platform::GpuMemoryUsage(&gpu_available, &gpu_total); - double total_gpu_memory = gpu_total / 1024. / 1024.; + double total_gpu_memory = static_cast(gpu_total) / 1024. / 1024.; float fraction_of_gpu_memory = - static_cast(memory_pool_init_size_mb()) / total_gpu_memory; + static_cast(memory_pool_init_size_mb()) / + static_cast(total_gpu_memory); VLOG(3) << "total_gpu_memory is " << total_gpu_memory - << "M, gpu_available is " << gpu_available / 1024. / 1024. + << "M, gpu_available is " + << static_cast(gpu_available) / 1024. / 1024. << "M, memory_pool_init_size is " << memory_pool_init_size_mb() << "M."; return fraction_of_gpu_memory; @@ -1279,8 +1328,10 @@ NativeConfig AnalysisConfig::ToNativeConfig() const { return config; } -void AnalysisConfig::SwitchIrDebug(int x) { +void AnalysisConfig::SwitchIrDebug(int x, + const std::vector &passes) { ir_debug_ = x; + ir_debug_passes_ = passes; Update(); } @@ -1415,6 +1466,8 @@ std::string AnalysisConfig::Summary() { os.InsertRow({"trt_engine_memory_sharing", trt_engine_memory_sharing_ ? "true" : "false"}); os.InsertRow({"trt_mark_output", trt_mark_output_ ? "true" : "false"}); + os.InsertRow( + {"trt_forbid_dynamic_op", trt_forbid_dynamic_op_ ? 
"true" : "false"}); #endif } } @@ -1616,4 +1669,13 @@ void AnalysisConfig::EnableCINN() { bool AnalysisConfig::cinn_enabled() const { return use_cinn_; } +void AnalysisConfig::EnableCustomPasses(const std::vector &passes, + bool custom_pass_only) { + custom_passes_ = passes; + custom_pass_only_ = custom_pass_only; +} + +void AnalysisConfig::SetOptimizationLevel(int opt_level) { + pm_opt_level_ = opt_level; +} } // namespace paddle diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index b61e8eaa0577d..a0a61c034d831 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -31,6 +31,7 @@ #include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/naive_executor.h" +#include "paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" @@ -80,7 +81,6 @@ #ifdef PADDLE_WITH_DNNL #include "paddle/fluid/inference/api/mkldnn_quantizer.h" -#include "paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.h" #endif #ifdef PADDLE_WITH_ONNXRUNTIME @@ -113,27 +113,17 @@ #include "paddle/common/flags.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" -#include "paddle/fluid/pir/transforms/constant_folding_pass.h" -#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" -#include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/embedding_eltwise_layernorm_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/silu_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/transpose_flatten_concat_fuse_pass.h" -#include "paddle/fluid/pir/transforms/identity_op_clean_pass.h" -#include "paddle/fluid/pir/transforms/inplace_pass.h" -#include "paddle/fluid/pir/transforms/map_op_to_another_pass.h" -#include "paddle/fluid/pir/transforms/params_sync_among_devices_pass.h" +#include "paddle/fluid/pir/transforms/general/constant_folding_pass.h" +#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h" +#include "paddle/fluid/pir/transforms/general/inplace_pass.h" +#include "paddle/fluid/pir/transforms/general/params_sync_among_devices_pass.h" +#include "paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.h" +#include "paddle/fluid/pir/transforms/passes.h" #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" -#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h" +#include "paddle/fluid/pir/transforms/shape_optimization_pass.h" #include "paddle/pir/include/pass/pass_manager.h" +#include "paddle/pir/include/pass/pass_registry.h" -COMMON_DECLARE_bool(enable_pir_in_executor); COMMON_DECLARE_bool(pir_apply_inplace_pass); namespace paddle { @@ -375,7 +365,7 @@ AnalysisPredictor::AnalysisPredictor(const AnalysisConfig &config) } if (config_.new_executor_enabled()) { config_.EnableMemoryOptim(false); - if 
(FLAGS_enable_pir_in_executor) { + if (config_.new_ir_enabled()) { config_.SwitchIrOptim(false); } } @@ -424,8 +414,10 @@ bool AnalysisPredictor::Init( // Use Optimized model to inference if (config_.use_optimized_model_) { std::string optimized_model_path = GetOptimizedModelPath(); - std::string optimized_model = optimized_model_path + ".pdmodel"; - std::string optimized_params = optimized_model_path + ".pdiparams"; + std::string optimized_model = + optimized_model_path + "/" + "_optimized.pdmodel"; + std::string optimized_params = + optimized_model_path + "/" + "_optimized.pdiparams"; if (FileExists(optimized_model) && FileExists(optimized_params)) { config_.SetModel(optimized_model, optimized_params); LOG(INFO) << "Load Optimized model from " << optimized_model_path; @@ -596,7 +588,7 @@ std::string AnalysisPredictor::GetOptimizedModelPath() { ? config_.model_dir() : inference::analysis::GetDirRoot(config_.prog_file()); } - return model_opt_cache_dir + "/" + "_optimized"; + return model_opt_cache_dir; } void AnalysisPredictor::ClearExtraParams() { @@ -608,6 +600,25 @@ void AnalysisPredictor::ClearExtraParams() { op_desc->GetAttr("parameters")); trt_repetitive_params.insert( trt_repetitive_params.end(), trt_params.begin(), trt_params.end()); + // NOTE(ming1753): This is a trick solution to the problem of possible + // absolute paths in the model_opt_cache_dir and shape_range_info_path + // attributes in tensorrt_engine op. + auto model_opt_cache_dir_from_model = PADDLE_GET_CONST( + std::string, op_desc->GetAttr("model_opt_cache_dir")); + auto model_opt_cache_dir = GetOptimizedModelPath(); + if (op_desc->HasAttr("model_opt_cache_dir")) { + op_desc->SetAttr("model_opt_cache_dir", model_opt_cache_dir); + } + if (op_desc->HasAttr("shape_range_info_path")) { + if (config_.shape_range_info_path_.empty()) { + op_desc->SetAttr( + "shape_range_info_path", + model_opt_cache_dir + "/" + "shape_range_info.pbtxt"); + } else { + op_desc->SetAttr("shape_range_info_path", + config_.shape_range_info_path_); + } + } } } @@ -871,16 +882,33 @@ bool AnalysisPredictor::PrepareExecutor() { auto output_names = GetOutputNames(); execution_config.skip_gc_vars.insert(output_names.begin(), output_names.end()); - if (FLAGS_enable_pir_in_executor) { - pir_program_ = std::move( - paddle::TranslateLegacyProgramToProgram(*inference_program_)); + if (config_.new_ir_enabled()) { + pir_program_ = + paddle::TranslateLegacyProgramToProgram(*inference_program_); + + auto ir_printing_conditions = [this](::pir::Pass *pass, + ::pir::Operation *op) { + if (this->config_.ir_debug_passes_.empty()) { + return true; + } + return std::find(this->config_.ir_debug_passes_.begin(), + this->config_.ir_debug_passes_.end(), + pass->name()) != this->config_.ir_debug_passes_.end(); + }; +#ifdef PADDLE_WITH_CINN if (paddle::prim::PrimCommonUtils::IsFwdPrimEnabled()) { VLOG(4) << "[Prim] Decomp program in predictor begin."; DecompProgram decomp_object(pir_program_.get()); decomp_object.decomp_program(); + + auto shape_pm = std::make_shared<::pir::PassManager>( + ::pir::IrContext::Instance(), 2); + ::pir::shape::AddShapeOptimizationPass(shape_pm, *pir_program_.get()); + VLOG(4) << "[ShapeDialect] Run AddShapeOptimizationPass"; + shape_pm->Run(pir_program_.get()); } -#ifdef PADDLE_WITH_CINN + if (config_.cinn_enabled()) { VLOG(4) << "[CINN] Begin ApplyCinnPass"; cinn::dialect::ir::ApplyCinnPass(pir_program_.get(), [&] { @@ -893,103 +921,101 @@ bool AnalysisPredictor::PrepareExecutor() { pass_manager->EnablePrintStatistics(); } if 
(config_.ir_debug_) { - pass_manager->EnableIRPrinting(); + pass_manager->EnableIRPrinting( + std::make_unique( + ir_printing_conditions, ir_printing_conditions)); } return pass_manager; }); } #endif + // Apply some optimization passes required by the inference + ::pir::PassManager pass_pm(::pir::IrContext::Instance(), + config_.pm_opt_level_); + if (!config_.custom_passes_.empty()) { + for (const auto &custom_pass : config_.custom_passes_) { + pass_pm.AddPass(pir::PassRegistry::Instance().Get(custom_pass)); + } + } if (config_.use_gpu()) { - ::pir::PassManager gpu_pm(::pir::IrContext::Instance(), 2); - //----------------------------------------------------------------------------------------------// - // Functional pass - gpu_pm.AddPass(::pir::CreateMapOpToAnotherPass()); - gpu_pm.AddPass(::pir::CreateIdentityOpCleanPass()); - //----------------------------------------------------------------------------------------------// - - //----------------------------------------------------------------------------------------------// - // Operator fusion pass - gpu_pm.AddPass(::pir::CreateSiluFusePass()); - gpu_pm.AddPass(::pir::CreateConv2dBnFusePass()); - gpu_pm.AddPass(::pir::CreateConv2dAddActFusePass()); - gpu_pm.AddPass(::pir::CreateConv2dAddFusePass()); - gpu_pm.AddPass(::pir::CreateFusedEmbeddingEltwiseLayerNormPass()); - gpu_pm.AddPass(::pir::CreateMultiHeadMatmulFusePass()); - gpu_pm.AddPass(::pir::CreateFcFusePass()); - gpu_pm.AddPass(::pir::CreateFcElementwiseLayerNormFusePass()); - gpu_pm.AddPass(::pir::CreateMatmulScaleFusePass()); - gpu_pm.AddPass(::pir::CreateTransposeFlattenConcatFusePass()); - //----------------------------------------------------------------------------------------------// - - //----------------------------------------------------------------------------------------------// - // Basic pass required by the framework - auto params_sync_among_devices_pass = - ::pir::CreateParamsSyncAmongDevicesPass(); - params_sync_among_devices_pass->SetNotOwned(pir::kPlaceAttr, &place_); - params_sync_among_devices_pass->SetNotOwned(pir::kParamScopeAttr, - sub_scope_); - gpu_pm.AddPass(std::move(params_sync_among_devices_pass)); - - auto constant_folding_pass = ::pir::CreateConstantFoldingPass(); - constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place_); - constant_folding_pass->SetNotOwned(pir::kParamScopeAttr, sub_scope_); - gpu_pm.AddPass(std::move(constant_folding_pass)); - - gpu_pm.AddPass(::pir::CreateDeadCodeEliminationPass()); - gpu_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); - //----------------------------------------------------------------------------------------------// - if (!config_.glog_info_disabled()) { - gpu_pm.EnablePrintStatistics(); + // gpu + if (!config_.custom_pass_only_) { + for (const auto &gpu_pass : kPirGpuPasses) { + pass_pm.AddPass(pir::PassRegistry::Instance().Get(gpu_pass)); + } } - if (config_.ir_debug_) { - gpu_pm.EnableIRPrinting(); + +#ifdef PADDLE_WITH_XPU + } else if (config_.use_xpu()) { + // xpu + if (!config_.custom_pass_only_) { + for (const auto &xpu_pass : kPirXpuPasses) { + pass_pm.AddPass( + std::move(pir::PassRegistry::Instance().Get(xpu_pass))); + } } - gpu_pm.Run(pir_program_.get()); +#endif + #ifdef PADDLE_WITH_DNNL } else if (config_.mkldnn_enabled()) { - ::pir::PassManager mkldnn_pm(::pir::IrContext::Instance(), 2); - - mkldnn_pm.AddPass(::pir::CreateConv2dBiasFusePass()); - - auto constant_folding_pass = ::pir::CreateConstantFoldingPass(); - constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place_); - 
constant_folding_pass->SetNotOwned(pir::kParamScopeAttr, sub_scope_); - - mkldnn_pm.AddPass(std::move(constant_folding_pass)); - mkldnn_pm.AddPass(::pir::CreateDeadCodeEliminationPass()); - mkldnn_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); - //----------------------------------------------------------------------------------------------// - if (!config_.glog_info_disabled()) { - mkldnn_pm.EnablePrintStatistics(); - } - if (config_.ir_debug_) { - mkldnn_pm.EnableIRPrinting(); + // mkldnn + if (!config_.custom_pass_only_) { + for (const auto &mkldnn_pass : kPirMkldnnPasses) { + pass_pm.AddPass(pir::PassRegistry::Instance().Get(mkldnn_pass)); + } } - mkldnn_pm.Run(pir_program_.get()); #endif } else { - ::pir::PassManager cpu_pm(::pir::IrContext::Instance(), 2); - - auto constant_folding_pass = ::pir::CreateConstantFoldingPass(); - constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place_); - constant_folding_pass->SetNotOwned(pir::kParamScopeAttr, sub_scope_); - - cpu_pm.AddPass(std::move(constant_folding_pass)); - cpu_pm.AddPass(::pir::CreateDeadCodeEliminationPass()); - cpu_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); - //----------------------------------------------------------------------------------------------// - if (!config_.glog_info_disabled()) { - cpu_pm.EnablePrintStatistics(); - } - if (config_.ir_debug_) { - cpu_pm.EnableIRPrinting(); + // cpu + if (!config_.custom_pass_only_) { + for (const auto &cpu_pass : kPirCpuPasses) { + pass_pm.AddPass(pir::PassRegistry::Instance().Get(cpu_pass)); + } } - cpu_pm.Run(pir_program_.get()); } - pir_program_ = std::move( - paddle::dialect::PdOpLowerToKernelPass(pir_program_.get(), place_)); + if (!config_.glog_info_disabled()) { + pass_pm.EnablePrintStatistics(); + } + if (config_.ir_debug_) { + pass_pm.EnableIRPrinting( + std::make_unique( + ir_printing_conditions, ir_printing_conditions)); + } + pass_pm.Run(pir_program_.get()); + + // Apply some basic passes required by the framework + ::pir::PassManager basic_pass_pm(::pir::IrContext::Instance(), + config_.pm_opt_level_); + + auto params_sync_among_devices_pass = + ::pir::CreateParamsSyncAmongDevicesPass(); + params_sync_among_devices_pass->SetNotOwned(pir::Pass::kPlaceAttr, + &place_); + params_sync_among_devices_pass->SetNotOwned(pir::Pass::kParamScopeAttr, + sub_scope_); + basic_pass_pm.AddPass(std::move(params_sync_among_devices_pass)); + auto constant_folding_pass = ::pir::CreateConstantFoldingPass(); + constant_folding_pass->SetNotOwned(pir::Pass::kPlaceAttr, &place_); + constant_folding_pass->SetNotOwned(pir::Pass::kParamScopeAttr, + sub_scope_); + basic_pass_pm.AddPass(std::move(constant_folding_pass)); + basic_pass_pm.AddPass(::pir::CreateDeadCodeEliminationPass()); + basic_pass_pm.AddPass(::pir::CreateReplaceFetchWithShadowOutputPass()); + if (!config_.glog_info_disabled()) { + basic_pass_pm.EnablePrintStatistics(); + } + if (config_.ir_debug_) { + basic_pass_pm.EnableIRPrinting( + std::make_unique( + ir_printing_conditions, ir_printing_conditions)); + } + basic_pass_pm.Run(pir_program_.get()); + //----------------------------------------------------------------------------------------------// + + pir_program_ = + paddle::dialect::PdOpLowerToKernelPass(pir_program_.get(), place_); ::pir::PassManager lowered_pm(::pir::IrContext::Instance(), 3); if (FLAGS_pir_apply_inplace_pass) { @@ -999,7 +1025,9 @@ bool AnalysisPredictor::PrepareExecutor() { lowered_pm.EnablePrintStatistics(); } if (config_.ir_debug_) { - lowered_pm.EnableIRPrinting(); + 
lowered_pm.EnableIRPrinting( + std::make_unique( + ir_printing_conditions, ir_printing_conditions)); } lowered_pm.Run(pir_program_.get()); @@ -1013,7 +1041,7 @@ bool AnalysisPredictor::PrepareExecutor() { } } - if (config_.enable_memory_optim_) { + if (config_.enable_memory_optim_ && !config_.use_optimized_model_) { auto *pass_res_info = inference::analysis::PassResultInfoForRuntime::Instance(); auto reuse_table = @@ -1272,7 +1300,7 @@ bool AnalysisPredictor::LoadConverterConfig( int64_t key = std::stoll(one_line[0]); for (size_t i = 1; i < one_line.size(); ++i) { int64_t val = std::stoll(one_line[i]); - if (ring_to_rank) { + if (ring_to_rank) { // NOLINT if (ring_id_to_ranks->find(key) == ring_id_to_ranks->end()) { ring_id_to_ranks->insert({key, std::vector()}); } @@ -1412,7 +1440,7 @@ bool AnalysisPredictor::Run(const std::vector &inputs, HookCollectShapeRangeInfo(); } - if (config_.new_executor_enabled()) { + if (config_.new_executor_enabled()) { // NOLINT executor_->RunInterpreterCore(); } else { // Run the inference program @@ -1485,7 +1513,7 @@ bool AnalysisPredictor::Run(const std::vector &inputs, HookCollectShapeRangeInfo(); } - if (config_.new_executor_enabled()) { + if (config_.new_executor_enabled()) { // NOLINT executor_->RunInterpreterCore(); } else { // Run the inference program @@ -1686,6 +1714,7 @@ void AnalysisPredictor::PrepareArgument() { argument_->SetEnableIrOptim(config_.enable_ir_optim_); argument_->SetEnableMemoryOptim(config_.enable_memory_optim()); argument_->SetModelFromMemory(config_.model_from_memory_); + argument_->SetUsePIR(config_.new_ir_enabled()); // Analyze inference_program argument_->SetPredictorID(predictor_id_); argument_->SetRootPredictorID(root_predictor_id_); @@ -1726,8 +1755,13 @@ void AnalysisPredictor::PrepareArgument() { argument_->SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_); argument_->SetTRTMarkOutput(config_.trt_mark_output_); argument_->SetTRTOutputTensorNames(config_.trt_output_tensor_names_); + argument_->SetTRTParameterRunFp16(config_.trt_parameters_run_fp16_); + argument_->SetTRTParameterRunInt8(config_.trt_parameters_run_int8_); + argument_->SetTRTParameterRunBfp16(config_.trt_parameters_run_bfp16_); argument_->SetTensorRtDisabledOPs(config_.trt_disabled_ops_); argument_->SetTRTExcludeVarNames(config_.trt_exclude_var_names_); + argument_->SetTRTForbidDynamicOp(config_.trt_forbid_dynamic_op_); + argument_->SetTensorRtUseDLA(config_.trt_use_dla_); argument_->SetTensorRtDLACore(config_.trt_dla_core_); argument_->SetTensorRtUseStaticEngine(config_.trt_use_static_engine_); @@ -1908,7 +1942,7 @@ void AnalysisPredictor::PrepareArgument() { if (deleted_passes.count(pass)) continue; pass_builder->AppendPass(pass); } - } else if (config_.use_xpu()) { + } else if (config_.use_xpu()) { // NOLINT // All passes support fp16. Not reset pass_builder. } else if (config_.use_custom_device()) { // All passes support fp16. Not reset pass_builder. 
@@ -1924,14 +1958,14 @@ void AnalysisPredictor::PrepareArgument() { model_precision_ == phi::DataType::FLOAT32) { argument_->SetEnableIrOptim(true); pass_builder->ClearPasses(); - if (!FLAGS_enable_pir_in_executor) { + if (!config_.new_ir_enabled()) { pass_builder->AppendPass("map_op_to_another_pass"); pass_builder->AppendPass("simplify_with_basic_ops_pass"); pass_builder->AppendPass("is_test_pass"); pass_builder->AppendPass("constant_folding_pass"); } pass_builder->AppendPass("auto_mixed_precision_pass"); - if (!FLAGS_enable_pir_in_executor) { + if (!config_.new_ir_enabled()) { pass_builder->AppendPass("inplace_op_var_pass"); } LOG(INFO) << "This model run in GPU mixed precision mode with no ir " @@ -2031,7 +2065,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() { #else if (config_.mkldnn_enabled() || (config_.tensorrt_engine_enabled() && - config_.tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8)) { + config_.tensorrt_precision_mode_ == + AnalysisConfig::Precision::kInt8)) { // NOLINT argument_->PartiallyRelease(); } else { argument_.reset(nullptr); @@ -2053,8 +2088,9 @@ CreatePaddlePredictor( // Register custom operators compiled by the user. // This function can only be executed once per process. static std::once_flag custom_operators_registered; - std::call_once(custom_operators_registered, - []() { inference::RegisterAllCustomOperator(); }); + std::call_once(custom_operators_registered, [config]() { + inference::RegisterAllCustomOperator(config.new_ir_enabled()); + }); auto SetGflags = [](const AnalysisConfig &config) { auto SetGflag = [](const char *name, const char *value) { @@ -2325,7 +2361,7 @@ std::unique_ptr AnalysisPredictor::GetInputTensor( const std::string &name) { framework::Scope *scope = nullptr; #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) - if (config_.dist_config().use_dist_model()) { + if (config_.dist_config().use_dist_model()) { // NOLINT scope = scope_.get(); } else { scope = executor_->GetScope(); @@ -2376,7 +2412,7 @@ std::unique_ptr AnalysisPredictor::GetOutputTensor( const std::string &name) { framework::Scope *scope; // NOLINT #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) - if (config_.dist_config().use_dist_model()) { + if (config_.dist_config().use_dist_model()) { // NOLINT scope = scope_.get(); } else { scope = executor_->GetScope(); @@ -2426,7 +2462,7 @@ std::unique_ptr AnalysisPredictor::GetOutputTensor( bool AnalysisPredictor::ZeroCopyRun(bool switch_stream) { inference::DisplayMemoryInfo(place_, "before run"); #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) - if (config_.dist_config().use_dist_model()) { + if (config_.dist_config().use_dist_model()) { // NOLINT VLOG(3) << "ZeroCopyRun will use the fleet executor."; fleet_exe_->Run(config_.dist_config().carrier_id()); return true; @@ -2485,7 +2521,7 @@ bool AnalysisPredictor::ZeroCopyRun(bool switch_stream) { } #endif - if (config_.new_executor_enabled()) { + if (config_.new_executor_enabled()) { // NOLINT executor_->RunInterpreterCore({}, false, switch_stream); } else { executor_->Run(); @@ -2633,7 +2669,7 @@ void AnalysisPredictor::HookCollectShapeRangeInfo() { int32_tensor.data(), int32_tensor.numel() * sizeof(int)); } else if (platform::is_gpu_place(tensor->place())) { -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto *dev_ctx = pool.Get(tensor->place()); auto &int32_tensor = *tensor; if (tensor->dtype() == phi::DataType::INT64) { @@ -2751,7 +2787,7 @@ void 
AnalysisPredictor::StatisticShapeRangeInfo() { bool AnalysisPredictor::LoadProgramDesc() { // Initialize the inference program std::string filename; - if (!config_.model_dir().empty()) { + if (!config_.model_dir().empty()) { // NOLINT filename = config_.model_dir() + "/__model__"; } else if (!config_.prog_file().empty()) { // All parameters are saved in a single file. @@ -2856,7 +2892,7 @@ bool AnalysisPredictor::LoadParameters() { } uint64_t AnalysisPredictor::TryShrinkMemory() { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (config_.use_gpu()) { paddle::platform::EmptyCache(); } @@ -3069,49 +3105,99 @@ void AnalysisPredictor::SaveOptimModel(const std::string &dir) { exe.Run(save_program, scope(), 0, true, true); } -void AnalysisPredictor::RegisterInputHook(const InputTensorHookFunc &hookfunc) { - std::call_once(register_input_hook_flag_, [this] { - executor_->RegisterInputHook( - [this](framework::OperatorBase *op, framework::Scope *scope) { - for (auto &input : op->Inputs()) { - for (auto &var_name : input.second) { +void AnalysisPredictor::RegisterOutputHook( + const OutputTensorHookFunc &hookfunc) { + if (config_.new_ir_enabled()) { + std::call_once(register_output_hook_flag_, [this] { + executor_->RegisterOutputHook( + [this](framework::InstructionBase *instr, + framework::ValueExecutionInfo *value_exe_info, + framework::Scope *scope) { + for (auto &output : instr->Outputs()) { + auto var_name = value_exe_info->GetVarName(output.first); auto *var = scope->FindVar(var_name); if (!var || !var->IsType()) continue; auto dense_tensor = var->Get(); if (!dense_tensor.initialized()) continue; auto tensor = paddle::Tensor( std::make_shared(dense_tensor), var_name); - for (auto &hookfunc : this->input_hookfuncs_) { - hookfunc(op->Type(), var_name, tensor); + for (auto &hookfunc : this->output_hookfuncs_) { + hookfunc(instr->Name() + ":" + std::to_string(instr->Id()), + var_name, + tensor); } } - } - }); - }); - input_hookfuncs_.push_back(hookfunc); + }); + }); + output_hookfuncs_.push_back(hookfunc); + } else { + std::call_once(register_output_hook_flag_, [this] { + executor_->RegisterOutputHook( + [this](framework::OperatorBase *op, framework::Scope *scope) { + for (auto &output : op->Outputs()) { + for (auto &var_name : output.second) { + auto *var = scope->FindVar(var_name); + if (!var || !var->IsType()) continue; + auto dense_tensor = var->Get(); + if (!dense_tensor.initialized()) continue; + auto tensor = paddle::Tensor( + std::make_shared(dense_tensor), var_name); + for (auto &hookfunc : this->output_hookfuncs_) { + hookfunc(op->Type(), var_name, tensor); + } + } + } + }); + }); + output_hookfuncs_.push_back(hookfunc); + } } -void AnalysisPredictor::RegisterOutputHook( - const OutputTensorHookFunc &hookfunc) { - std::call_once(register_output_hook_flag_, [this] { - executor_->RegisterOutputHook( - [this](framework::OperatorBase *op, framework::Scope *scope) { - for (auto &output : op->Outputs()) { - for (auto &var_name : output.second) { +void AnalysisPredictor::RegisterInputHook(const InputTensorHookFunc &hookfunc) { + if (config_.new_ir_enabled()) { + std::call_once(register_input_hook_flag_, [this] { + executor_->RegisterInputHook( + [this](framework::InstructionBase *instr, + framework::ValueExecutionInfo *value_exe_info, + framework::Scope *scope) { + for (auto &input : instr->Inputs()) { + auto var_name = value_exe_info->GetVarName(input.first); auto *var = scope->FindVar(var_name); if (!var || !var->IsType()) continue; auto dense_tensor = 
var->Get(); if (!dense_tensor.initialized()) continue; auto tensor = paddle::Tensor( std::make_shared(dense_tensor), var_name); - for (auto &hookfunc : this->output_hookfuncs_) { - hookfunc(op->Type(), var_name, tensor); + for (auto &hookfunc : this->input_hookfuncs_) { + hookfunc(instr->Name() + ":" + std::to_string(instr->Id()), + var_name, + tensor); } } - } - }); - }); - output_hookfuncs_.push_back(hookfunc); + }); + }); + input_hookfuncs_.push_back(hookfunc); + } else { + std::call_once(register_input_hook_flag_, [this] { + executor_->RegisterInputHook( + [this](framework::OperatorBase *op, framework::Scope *scope) { + for (auto &input : op->Inputs()) { + for (auto &var_name : input.second) { + auto *var = scope->FindVar(var_name); + if (!var || !var->IsType()) continue; + auto dense_tensor = var->Get(); + if (!dense_tensor.initialized()) continue; + auto tensor = paddle::Tensor( + std::make_shared(dense_tensor), var_name); + for (auto &hookfunc : this->input_hookfuncs_) { + hookfunc(op->Type(), var_name, tensor); + } + } + } + }); + }); + input_hookfuncs_.push_back(hookfunc); + } } template <> @@ -3412,7 +3498,7 @@ uint64_t Predictor::TryShrinkMemory() { return predictor_->TryShrinkMemory(); } void Predictor::RegisterOutputHook(const OutputTensorHookFunc &hookfunc) { predictor_->RegisterOutputHook(hookfunc); } -void Predictor::RegisterInputHook(const OutputTensorHookFunc &hookfunc) { +void Predictor::RegisterInputHook(const InputTensorHookFunc &hookfunc) { predictor_->RegisterInputHook(hookfunc); } @@ -3549,39 +3635,39 @@ bool InternalUtils::RunWithRuntimeConfig(paddle_infer::Predictor *p, void InternalUtils::UpdateConfigInterleaved(paddle_infer::Config *c, bool with_interleaved) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) c->trt_with_interleaved_ = with_interleaved; #endif } void InternalUtils::SetTransformerPosid( paddle_infer::Config *c, const std::string &tensorrt_transformer_posid) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) c->tensorrt_transformer_posid_ = tensorrt_transformer_posid; #endif } void InternalUtils::SetTransformerMaskid( paddle_infer::Config *c, const std::string &tensorrt_transformer_maskid) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) c->tensorrt_transformer_maskid_ = tensorrt_transformer_maskid; #endif } void InternalUtils::DisableTensorRtHalfOps( paddle_infer::Config *c, const std::unordered_set &ops) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) c->trt_ops_run_float_ = ops; #endif } void InternalUtils::SyncStream(paddle_infer::Predictor *p) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto *pred = dynamic_cast(p->predictor_.get()); paddle::platform::DeviceContextPool &pool = paddle::platform::DeviceContextPool::Instance(); auto *dev_ctx = reinterpret_cast(pool.Get(pred->place_)); - cudaStreamSynchronize(dev_ctx->stream()); + paddle::gpuStreamSynchronize(dev_ctx->stream()); #endif } void InternalUtils::SyncStream(cudaStream_t stream) { @@ -3590,5 +3676,11 @@ void InternalUtils::SyncStream(cudaStream_t stream) { #endif } +void InternalUtils::SyncStream(hipStream_t stream) { +#ifdef PADDLE_WITH_HIP + hipStreamSynchronize(stream); +#endif +} + } // namespace experimental } // namespace paddle_infer diff --git a/paddle/fluid/inference/api/analysis_predictor.h b/paddle/fluid/inference/api/analysis_predictor.h index 1c107e936d69a..fe494cab93a90 100644 --- 
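The hook refactor above keeps the user-facing callback shape identical on both paths: under PIR the first argument becomes the instruction name plus ":" and its id, under the legacy executor it stays the op type, and in both cases the callback also receives the variable name and a paddle::Tensor view of the tensor. A minimal registration sketch follows; the exact OutputTensorHookFunc typedef is not part of this hunk, so the lambda signature is inferred from the call sites above and should be treated as illustrative.

// Sketch only: the callback parameter types are assumed from how the hooks
// are invoked in RegisterOutputHook/RegisterInputHook above.
#include <iostream>
#include <string>
#include "paddle_inference_api.h"  // NOLINT

void AttachOutputLogger(paddle_infer::Predictor *predictor) {
  predictor->RegisterOutputHook(
      [](const std::string &op_name,      // instruction name + ":" + id under PIR, op type otherwise
         const std::string &var_name,     // name of the output variable
         const paddle::Tensor &tensor) {  // initialized DenseTensor wrapped as paddle::Tensor
        std::cout << op_name << " -> " << var_name
                  << " (numel=" << tensor.numel() << ")" << std::endl;
      });
}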
a/paddle/fluid/inference/api/analysis_predictor.h +++ b/paddle/fluid/inference/api/analysis_predictor.h @@ -29,7 +29,7 @@ #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/inference/api/resource_manager.h" #include "paddle/fluid/platform/device/gpu/gpu_types.h" -#include "paddle/fluid/string/printf.h" +#include "paddle/utils/string/printf.h" #if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE) #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" diff --git a/paddle/fluid/inference/api/api_impl.cc b/paddle/fluid/inference/api/api_impl.cc index c8eaa1c3ebd1e..1ae582feb4acf 100644 --- a/paddle/fluid/inference/api/api_impl.cc +++ b/paddle/fluid/inference/api/api_impl.cc @@ -101,7 +101,7 @@ bool NativePaddlePredictor::Init( executor_ = std::make_unique(place_); // Initialize the inference program - if (!config_.model_dir.empty()) { + if (!config_.model_dir.empty()) { // NOLINT // Parameters are saved in separate files sited in // the specified `dirname`. inference_program_ = paddle::inference::Load( @@ -286,7 +286,7 @@ bool NativePaddlePredictor::SetFeed(const std::vector &inputs, } input.set_lod(lod); int idx = -1; - if (config_.specify_input_name) { + if (config_.specify_input_name) { // NOLINT idx = static_cast(feed_names_[inputs[i].name]); } else { idx = PADDLE_GET_CONST(int, feeds_[i]->GetAttr("col")); diff --git a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt index 727af4e00605e..833fc98d36dba 100644 --- a/paddle/fluid/inference/api/demo_ci/CMakeLists.txt +++ b/paddle/fluid/inference/api/demo_ci/CMakeLists.txt @@ -8,6 +8,7 @@ option(USE_TENSORRT "Compile demo with TensorRT." OFF) option(WITH_ONNXRUNTIME "Compile demo with ONNXRuntime" OFF) option(WITH_SHARED_PHI "Compile demo with phi shared lib" ON) option(CUSTOM_OPERATOR_FILES "List of file names for custom operators" "") +option(CUSTOM_PASS_FILES "List of file names for custom passes" "") if(NOT WITH_STATIC_LIB) add_definitions("-DPADDLE_WITH_SHARED_LIB") @@ -85,7 +86,7 @@ else() if(WITH_MKL) set(FLAG_OPENMP "-fopenmp") endif() - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 ${FLAG_OPENMP}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 ${FLAG_OPENMP}") endif() if(WITH_GPU) @@ -262,10 +263,14 @@ if(CUSTOM_OPERATOR_FILES) include_directories("${CUDA_INCLUDE_DIRS}") endif() add_library(pd_infer_custom_op SHARED ${CUSTOM_OPERATOR_FILES}) - target_link_libraries(pd_infer_custom_op ${DEPS}) set(DEPS ${DEPS} pd_infer_custom_op) endif() +if(CUSTOM_PASS_FILES) + add_library(pd_infer_custom_pass SHARED ${CUSTOM_PASS_FILES}) + set(DEPS ${DEPS} pd_infer_custom_pass) +endif() + add_executable(${DEMO_NAME} ${DEMO_NAME}.cc) target_link_libraries(${DEMO_NAME} ${DEPS}) if(WIN32) diff --git a/paddle/fluid/inference/api/demo_ci/custom_op_demo.cc b/paddle/fluid/inference/api/demo_ci/custom_op_demo.cc index b4c8cccb8e790..f9c777f983704 100644 --- a/paddle/fluid/inference/api/demo_ci/custom_op_demo.cc +++ b/paddle/fluid/inference/api/demo_ci/custom_op_demo.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -47,12 +47,13 @@ void run(Predictor *predictor, int main(int argc, char **argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - paddle::AnalysisConfig config; + Config config; config.EnableUseGpu(100, 0); config.SetModel(FLAGS_modeldir + "/custom_relu.pdmodel", FLAGS_modeldir + "/custom_relu.pdiparams"); config.EnableNewExecutor(true); - auto predictor{paddle_infer::CreatePredictor(config)}; + config.EnableNewIR(true); + auto predictor = CreatePredictor(config); std::vector input_shape = {1, 1, 28, 28}; std::vector input_data(1 * 1 * 28 * 28, 1); std::vector out_data; diff --git a/paddle/fluid/inference/api/demo_ci/custom_pass_demo.cc b/paddle/fluid/inference/api/demo_ci/custom_pass_demo.cc new file mode 100644 index 0000000000000..bd335401e736f --- /dev/null +++ b/paddle/fluid/inference/api/demo_ci/custom_pass_demo.cc @@ -0,0 +1,98 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include +#include + +#include "paddle/extension.h" +#include "paddle_inference_api.h" //NOLINT + +DEFINE_string(modeldir, "", "Directory of the inference model."); + +using paddle_infer::Config; +using paddle_infer::CreatePredictor; +using paddle_infer::Predictor; + +std::shared_ptr InitPredictor(bool use_custom_pass) { + Config config; + config.EnableUseGpu(100, 0); + config.SetModel(FLAGS_modeldir + "/inference.pdmodel", + FLAGS_modeldir + "/inference.pdiparams"); + config.EnableNewExecutor(true); + config.EnableNewIR(true); + // config.SwitchIrDebug(true); + if (use_custom_pass) { + config.EnableCustomPasses({"relu_replace_pass"}); + } + + return CreatePredictor(config); +} + +std::vector GetOutputData(const std::shared_ptr &predictor) { + auto input_names = predictor->GetInputNames(); + auto input_shapes = predictor->GetInputTensorShape(); + + for (const auto &input_name : input_names) { + // update input shape's batch size + input_shapes[input_name][0] = 1; + } + + std::vector inputs, outputs; + for (const auto &input_name : input_names) { + auto input_tensor = paddle::full(input_shapes[input_name], + 0.5, + paddle::DataType::FLOAT32, + paddle::GPUPlace{}); + input_tensor.set_name(input_name); + inputs.emplace_back(std::move(input_tensor)); + } + CHECK(predictor->Run(inputs, &outputs)); + + CHECK(outputs[0].place() == paddle::GPUPlace{}); + CHECK(outputs[0].dtype() == paddle::DataType::FLOAT32); + auto output = outputs[0].copy_to(paddle::CPUPlace{}, true); + + std::vector output_data; + for (int64_t i = 0; i < output.numel(); i++) { + output_data.push_back(output.data()[i]); + } + return output_data; +} + +bool AreEqual(const std::vector &vec1, + const std::vector &vec2, + float epsilon) { + if (vec1.size() != vec2.size()) { + return false; + } + for (size_t i = 0; i < vec1.size(); ++i) { + if (std::fabs(vec1[i] - vec2[i]) > epsilon) { + return false; + } + } + return true; +} + +int main(int argc, char **argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + auto base_data = 
GetOutputData(InitPredictor(false)); + auto custom_data = GetOutputData(InitPredictor(true)); + + CHECK(AreEqual(base_data, custom_data, 1e-3)); + + return 0; +} diff --git a/paddle/fluid/inference/api/demo_ci/custom_relu_pass.cc b/paddle/fluid/inference/api/demo_ci/custom_relu_pass.cc new file mode 100644 index 0000000000000..15164aa3962b7 --- /dev/null +++ b/paddle/fluid/inference/api/demo_ci/custom_relu_pass.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/extension.h" + +namespace { + +class ReluReplacePattern : public paddle::drr::DrrPatternBase { + public: + std::string name() const override { return "ReluReplacePattern"; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &relu = pat.Op("pd_op.relu"); + relu({&pat.Tensor("in")}, {&pat.Tensor("out")}); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + const auto &custom_relu = res.Op("custom_op.custom_relu"); + custom_relu({&res.Tensor("in")}, {&res.Tensor("out")}); + } +}; + +class ReluReplacePass : public pir::PatternRewritePass { + public: + ReluReplacePass() : pir::PatternRewritePass("relu_replace_pass", 2) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(paddle::drr::Create(context)); + return ps; + } +}; + +} // namespace + +REGISTER_IR_PASS(relu_replace_pass, ReluReplacePass); diff --git a/paddle/fluid/inference/api/demo_ci/run.sh b/paddle/fluid/inference/api/demo_ci/run.sh index 795b414258b56..e1369ca51c5d0 100755 --- a/paddle/fluid/inference/api/demo_ci/run.sh +++ b/paddle/fluid/inference/api/demo_ci/run.sh @@ -113,6 +113,15 @@ else wget -q https://paddle-inference-dist.bj.bcebos.com/inference_demo/custom_operator/custom_relu_infer_model.tgz tar xzf *.tgz fi +cd .. + +#download custom_pass_demo data +mkdir -p custom_pass +cd custom_pass +if [ ! -d resnet50 ]; then + wget https://paddle-inference-dist.bj.bcebos.com/Paddle-Inference-Demo/resnet50.tgz + tar xzf resnet50.tgz +fi # compile and test the demo cd $current_dir @@ -301,13 +310,37 @@ for WITH_STATIC_LIB in ON OFF; do -DCUSTOM_OPERATOR_FILES=$CUSTOM_OPERATOR_FILES \ -DWITH_ONNXRUNTIME=$WITH_ONNXRUNTIME make -j$(nproc) - FLAGS_enable_pir_in_executor=1 ./custom_op_demo \ + ./custom_op_demo \ --modeldir=$DATA_DIR/custom_op/custom_relu_infer_model if [ $? -ne 0 ]; then echo "custom_op_demo runs failed " >> ${current_dir}/test_summary.txt EXIT_CODE=1 fi - fi + fi + + # --------custom pass demo on linux/mac------ + if [ $TEST_GPU_CPU == ON -a $WITH_STATIC_LIB == OFF ]; then + rm -rf * + CUSTOM_OPERATOR_FILES="custom_relu_op.cc;custom_relu_op.cu" + CUSTOM_PASS_FILES="custom_relu_pass.cc" + cmake .. 
-DPADDLE_LIB=${inference_install_dir} \ + -DWITH_MKL=$TURN_ON_MKL \ + -DDEMO_NAME=custom_pass_demo \ + -DWITH_GPU=$TEST_GPU_CPU \ + -DWITH_STATIC_LIB=OFF \ + -DUSE_TENSORRT=$USE_TENSORRT \ + -DTENSORRT_ROOT=$TENSORRT_ROOT_DIR \ + -DCUSTOM_OPERATOR_FILES=$CUSTOM_OPERATOR_FILES \ + -DCUSTOM_PASS_FILES=${CUSTOM_PASS_FILES} \ + -DWITH_ONNXRUNTIME=$WITH_ONNXRUNTIME + make -j$(nproc) + ./custom_pass_demo \ + --modeldir=$DATA_DIR/custom_pass/resnet50 + if [ $? -ne 0 ]; then + echo "custom_pass_demo runs failed " >> ${current_dir}/test_summary.txt + EXIT_CODE=1 + fi + fi fi done diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor_test.cc b/paddle/fluid/inference/api/details/zero_copy_tensor_test.cc index c3589f4251791..fda408b15df5f 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor_test.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor_test.cc @@ -57,9 +57,10 @@ std::unique_ptr CreateTensor(paddle_infer::PlaceType place, template struct RandomGenerator { - RandomGenerator(double min = (std::numeric_limits::min)(), - double max = (std::numeric_limits::max)()) - : dist_{static_cast(min), static_cast(max)} {} + RandomGenerator( + double min = static_cast((std::numeric_limits::min)()), + double max = static_cast((std::numeric_limits::max)())) + : dist_{min, max} {} T operator()() { return static_cast(dist_(random_engine_)); } private: diff --git a/paddle/fluid/inference/api/helper.cc b/paddle/fluid/inference/api/helper.cc index e9eb090a771d2..416a62e980fe5 100644 --- a/paddle/fluid/inference/api/helper.cc +++ b/paddle/fluid/inference/api/helper.cc @@ -13,16 +13,26 @@ // limitations under the License. #include "paddle/fluid/inference/api/helper.h" +#include +#include "paddle/common/enforce.h" +#include "paddle/common/errors.h" #include "paddle/common/flags.h" #include "paddle/fluid/framework/custom_operator.h" +#include "paddle/fluid/framework/custom_operator_utils.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/pir/drr/src/ir_operation_factory.h" #include "paddle/fluid/platform/init.h" #include "paddle/phi/api/ext/op_meta_info.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/ir_context.h" - -COMMON_DECLARE_bool(enable_pir_in_executor); +#include "paddle/pir/include/core/operation.h" +#include "paddle/pir/include/core/value.h" namespace paddle { namespace inference { @@ -50,14 +60,13 @@ std::string to_string>>( return ss.str(); } -void RegisterAllCustomOperator() { - auto &op_meta_info_map = OpMetaInfoMap::Instance(); - const auto &meta_info_map = op_meta_info_map.GetMap(); +void RegisterAllCustomOperator(bool use_pir) { + const auto &meta_info_map = OpMetaInfoMap::Instance().GetMap(); for (auto &pair : meta_info_map) { - if (FLAGS_enable_pir_in_executor) { - ::pir::IrContext *ctx = ::pir::IrContext::Instance(); + if (use_pir) { auto *custom_dialect = - ctx->GetOrRegisterDialect(); + ::pir::IrContext::Instance() + ->GetOrRegisterDialect(); if (custom_dialect->HasRegistered(pair.first)) { LOG(INFO) << "The operator `" << pair.first << "` has been registered. 
" @@ -65,9 +74,349 @@ void RegisterAllCustomOperator() { continue; } for (const auto &meta_info : pair.second) { - LOG(INFO) << "register pir custom op :" << pair.first; + LOG(INFO) << "register pir custom op: " << pair.first; custom_dialect->RegisterCustomOp(meta_info); } + + std::string pir_op_name = + paddle::framework::kCustomDialectPrefix + pair.first; + paddle::drr::OperationFactory::Instance().RegisterOperationCreator( + pir_op_name, + [pair, pir_op_name]( + const std::vector<::pir::Value> &inputs, + const ::pir::AttributeMap &attrs, + ::pir::PatternRewriter &rewriter) mutable -> ::pir::Operation * { + const auto &meta_inputs = + paddle::OpMetaInfoHelper::GetInputs(pair.second[0]); + const auto &meta_attrs = + paddle::OpMetaInfoHelper::GetAttrs(pair.second[0]); + const auto &meta_outputs = + paddle::OpMetaInfoHelper::GetOutputs(pair.second[0]); + const auto &inplace_map = + paddle::OpMetaInfoHelper::GetInplaceMap(pair.second[0]); + const auto &inplace_reverse_map = + paddle::OpMetaInfoHelper::GetInplaceReverseMap(pair.second[0]); + auto infershape_func = + OpMetaInfoHelper::GetInferShapeFn(pair.second[0]); + auto inferdtype_func = + OpMetaInfoHelper::GetInferDtypeFn(pair.second[0]); + + PADDLE_ENFORCE_EQ( + meta_inputs.size(), + inputs.size(), + paddle::platform::errors::InvalidArgument( + "The number of inputs for the custom operator [%s] given " + "in the Pattern needs to be consistent with the number at " + "implementation time.", + pir_op_name)); + PADDLE_ENFORCE_EQ( + meta_attrs.size(), + attrs.size(), + paddle::platform::errors::InvalidArgument( + "The number of attrs for the custom operator [%s] given " + "in the Pattern needs to be consistent with the number at " + "implementation time.", + pir_op_name)); + + if (!inplace_map.empty()) { + pir_op_name += "_"; + } + ::pir::OperationArgument argument( + rewriter.ir_context()->GetRegisteredOpInfo(pir_op_name)); + argument.attributes = attrs; + argument.inputs = inputs; + + std::vector argument_outputs; + std::vector> input_shapes; + std::vector input_dtypes; + std::unordered_map input_name2id_map; + std::vector>> vec_input_shapes; + std::vector> vec_input_dtypes; + std::unordered_map vec_input_name2id_map; + std::vector custom_attrs; + int input_index = 0; + int vec_input_index = 0; + + for (size_t i = 0; i < meta_inputs.size(); ++i) { + const auto &meta_input = meta_inputs.at(i); + if (!inputs[i]) { + VLOG(6) << "Add un-initialized tensor " + "because the optional input is None"; + if (paddle::framework::detail::IsDuplicableVar(meta_input)) { + std::vector> vec_input_shape; + std::vector vec_input_dtype; + vec_input_shapes.emplace_back(vec_input_shape); + vec_input_dtypes.emplace_back(vec_input_dtype); + vec_input_name2id_map[meta_inputs[i]] = vec_input_index; + vec_input_index++; + } else { + std::vector input_shape; + DataType input_dtype = DataType::UNDEFINED; + input_shapes.emplace_back(input_shape); + input_dtypes.emplace_back(input_dtype); + input_name2id_map[meta_inputs[i]] = input_index; + input_index++; + } + continue; + } + if (paddle::framework::detail::IsDuplicableVar(meta_input)) { + PADDLE_ENFORCE_EQ( + inputs[i].type().isa<::pir::VectorType>(), + true, + paddle::platform::errors::InvalidArgument( + "The [%d] input of the custom operator [%s] " + "should be a pir::VectorType.", + i, + pir_op_name)); + std::vector> tmp_input_shapes; + std::vector tmp_input_dtypes; + vec_input_name2id_map[meta_inputs[i]] = vec_input_index; + vec_input_index++; + auto input_value_types = + 
inputs[i].type().dyn_cast<::pir::VectorType>().data(); + for (auto &input_value_type : input_value_types) { + auto input_tensor = + input_value_type + .dyn_cast(); + tmp_input_shapes.push_back( + phi::vectorize(input_tensor.dims())); + tmp_input_dtypes.push_back( + paddle::dialect::TransToPhiDataType( + input_tensor.dtype())); + } + vec_input_shapes.push_back(tmp_input_shapes); + vec_input_dtypes.push_back(tmp_input_dtypes); + } else { + input_name2id_map[meta_inputs[i]] = input_index; + input_index++; + auto input_tensor = + inputs[i] + .type() + .dyn_cast(); + input_shapes.push_back(phi::vectorize(input_tensor.dims())); + input_dtypes.push_back( + paddle::dialect::TransToPhiDataType(input_tensor.dtype())); + } + } + + for (const auto &meta_attr : meta_attrs) { + auto attr_name_and_type = paddle::ParseAttrStr(meta_attr); + auto attr_name = attr_name_and_type[0]; + auto attr_type = attr_name_and_type[1]; + PADDLE_ENFORCE_EQ(attrs.count(attr_name), + true, + paddle::platform::errors::InvalidArgument( + "The attr [%s] in the custom operator [%s] " + "specified in the Pattern needs to be " + "consistent with the implementation", + attr_name, + pir_op_name)); + VLOG(6) << "Custom operator add attrs " << attr_name + << " to CustomOpKernelContext. Attribute type = " + << attr_type; + if (attr_type == "bool") { + auto bool_attr = + attrs.at(attr_name).dyn_cast<::pir::BoolAttribute>().data(); + custom_attrs.emplace_back(bool_attr); + } else if (attr_type == "int") { + int int_attr = attrs.at(attr_name) + .dyn_cast<::pir::Int32Attribute>() + .data(); + custom_attrs.emplace_back(int_attr); + } else if (attr_type == "float") { + float float_attr = attrs.at(attr_name) + .dyn_cast<::pir::FloatAttribute>() + .data(); + custom_attrs.emplace_back(float_attr); + } else if (attr_type == "int64_t") { + int64_t long_attr = attrs.at(attr_name) + .dyn_cast<::pir::Int64Attribute>() + .data(); + custom_attrs.emplace_back(long_attr); + } else if (attr_type == "std::string") { + std::string str_attr = attrs.at(attr_name) + .dyn_cast<::pir::StrAttribute>() + .AsString(); + custom_attrs.emplace_back(str_attr); + } else if (attr_type == "std::vector") { + auto vec_attr = attrs.at(attr_name) + .dyn_cast<::pir::ArrayAttribute>() + .AsVector(); + std::vector vec_int_attr; + for (const auto &int_attr : vec_attr) { + vec_int_attr.push_back( + int_attr.dyn_cast<::pir::Int32Attribute>().data()); + } + custom_attrs.emplace_back(vec_int_attr); + } else if (attr_type == "std::vector") { + auto vec_attr = attrs.at(attr_name) + .dyn_cast<::pir::ArrayAttribute>() + .AsVector(); + std::vector vec_float_attr; + for (const auto &float_attr : vec_attr) { + vec_float_attr.push_back( + float_attr.dyn_cast<::pir::FloatAttribute>().data()); + } + custom_attrs.emplace_back(vec_float_attr); + } else if (attr_type == "std::vector") { + auto vec_attr = attrs.at(attr_name) + .dyn_cast<::pir::ArrayAttribute>() + .AsVector(); + std::vector vec_long_attr; + for (const auto &long_attr : vec_attr) { + vec_long_attr.push_back( + long_attr.dyn_cast<::pir::Int64Attribute>().data()); + } + custom_attrs.emplace_back(vec_long_attr); + } else if (attr_type == "std::vector") { + auto vec_attr = attrs.at(attr_name) + .dyn_cast<::pir::ArrayAttribute>() + .AsVector(); + std::vector vec_string_attr; + for (const auto &string_attr : vec_attr) { + vec_string_attr.push_back( + string_attr.dyn_cast<::pir::StrAttribute>().AsString()); + } + custom_attrs.emplace_back(vec_string_attr); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupported `%s` 
type value as custom attribute now. " + "Supported data types include `bool`, `int`, `float`, " + "`int64_t`, `std::string`, `std::vector`, " + "`std::vector`, `std::vector`, " + "`std::vector`, Please check whether " + "the attribute data type and data type string are matched.", + attr_type)); + } + } + + paddle::framework::CheckDefaultInferShapeDtype( + infershape_func, inferdtype_func, pair.second[0]); + std::vector> output_shapes = + paddle::framework::RunInferShape(infershape_func, + pair.second[0], + input_shapes, + input_name2id_map, + vec_input_shapes, + vec_input_name2id_map, + custom_attrs); + std::vector output_dtypes = + paddle::framework::RunInferDtype(inferdtype_func, + pair.second[0], + input_dtypes, + input_name2id_map, + vec_input_dtypes, + vec_input_name2id_map, + custom_attrs); + + size_t all_values_num = 0; + // output name -> value num (that output should hold) + std::unordered_map output_name2value_num; + for (const auto &output : meta_outputs) { + if (paddle::framework::detail::IsDuplicableVar(output)) { + PADDLE_ENFORCE_NE(inplace_reverse_map.find(output), + inplace_reverse_map.end(), + phi::errors::InvalidArgument( + "Only support vector output that is set " + "for inplace, Please use " + "`SetInplaceMap` in your output when " + "registry custom operator.")); + const auto &input = inplace_reverse_map.at(output); + auto index = vec_input_name2id_map[input]; + auto &vec_input_shape = vec_input_shapes[index]; + output_name2value_num[output] = vec_input_shape.size(); + } else { + if (inplace_reverse_map.find(output) != + inplace_reverse_map.end()) { + const auto &input = inplace_reverse_map.at(output); + auto index = input_name2id_map[input]; + // input_shapes[index] is dim of tensor, if the dim doesn't + // have element, it must be a optional tensor that is None in + // custom operator + output_name2value_num[output] = + input_shapes[index].empty() ? 
0 : 1; + } else { + output_name2value_num[output]++; + } + } + all_values_num += output_name2value_num[output]; + } + + PADDLE_ENFORCE_EQ( + output_shapes.size(), + all_values_num, + phi::errors::InvalidArgument("The number of output shapes " + "after running custom operator's " + "InferShapeFunc is wrong, " + "expected contains %d Tensors' " + "shape, but actually contains %d " + "Tensors' shape", + all_values_num, + output_shapes.size())); + + PADDLE_ENFORCE_EQ( + output_dtypes.size(), + all_values_num, + phi::errors::InvalidArgument("The number of output dtypes " + "after running custom operator's " + "InferDtypeFunc is wrong, " + "expected contains %d Tensors' " + "dtype, but actually contains %d " + "Tensors' dtype", + all_values_num, + output_dtypes.size())); + + size_t value_index = 0; + for (const auto &output : meta_outputs) { + auto value_num = output_name2value_num[output]; + if (value_num == 0) { + // Optional value condition + pir::Type out_type; + argument_outputs.push_back(out_type); + continue; + } + if (paddle::framework::detail::IsDuplicableVar(output)) { + auto value_num = output_name2value_num[output]; + std::vector out_types; + for (size_t j = 0; j < value_num; ++j) { + auto ddims = phi::make_ddim(output_shapes[value_index]); + auto dtype = output_dtypes[value_index]; + phi::DataLayout layout{DataLayout::NCHW}; + phi::LoD lod; + out_types.push_back(paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dtype), + ddims, + layout, + lod, + 0)); + value_index++; + } + pir::Type out_vector_type = + pir::VectorType::get(pir::IrContext::Instance(), out_types); + argument_outputs.push_back(out_vector_type); + } else { + auto ddims = phi::make_ddim(output_shapes[value_index]); + auto dtype = output_dtypes[value_index]; + phi::DataLayout layout{DataLayout::NCHW}; + phi::LoD lod; + auto out_type = paddle::dialect::DenseTensorType::get( + pir::IrContext::Instance(), + paddle::dialect::TransToIrDataType(dtype), + ddims, + layout, + lod, + 0); + argument_outputs.push_back(out_type); + value_index++; + } + } + + argument.AddOutputs(argument_outputs.begin(), + argument_outputs.end()); + ::pir::PassStopGradientsDefaultly(argument); + return rewriter.Build(std::move(argument)); + }); } const auto &all_op_kernels{framework::OperatorWithKernel::AllOpKernels()}; if (all_op_kernels.find(pair.first) == all_op_kernels.end()) { diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index 22a5319bb0dbc..28f126f4fd344 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -35,8 +35,8 @@ #include "paddle/fluid/memory/stats.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/printf.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" +#include "paddle/utils/string/printf.h" extern std::string paddle::framework::DataTypeToString( const framework::proto::VarType::Type type); @@ -431,7 +431,7 @@ static bool IsFileExists(const std::string &path) { return exists; } -void RegisterAllCustomOperator(); +void RegisterAllCustomOperator(bool use_pir); void InitGflagsFromEnv(); diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index 46ae4624ea9e8..76222b84d4624 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -78,7 +78,7 @@ void 
AnalysisPredictor::MkldnnQuantizer::CalculateScalesForRNNWeights( check_var(wh_var, wh_name); phi::DenseTensor* wx_tensor = wx_var->GetMutable(); phi::DenseTensor* wh_tensor = wh_var->GetMutable(); - if (gru) { + if (gru) { // NOLINT scales_[wx_name] = GetMaxChGRUScalingFactor(*wx_tensor, *wh_tensor); } else { scales_[wx_name] = GetMaxChLSTMScalingFactor(*wx_tensor, *wh_tensor); @@ -215,6 +215,7 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateSingleScale( switch (rule) { case ScaleAlgo::MAX: + case ScaleAlgo::KL: scales_[var_name] = GetMaxScalingFactor(var_tensor, is_unsigned); break; case ScaleAlgo::MAX_CH: @@ -227,9 +228,6 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateSingleScale( is_unsigned, /*is_transposed*/ true); break; - case ScaleAlgo::KL: - scales_[var_name] = GetKLScalingFactor(var_tensor, is_unsigned); - break; default: throw std::runtime_error( "MkldnnQuantizer: Unexpected ScaleAlgo specified."); diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.h b/paddle/fluid/inference/api/mkldnn_quantizer.h index 17fe7fff3aa21..7b6549abe5afd 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.h +++ b/paddle/fluid/inference/api/mkldnn_quantizer.h @@ -27,7 +27,7 @@ #include "paddle/fluid/inference/api/details/reset_tensor_array.h" #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" -#include "paddle/fluid/string/printf.h" +#include "paddle/utils/string/printf.h" #include "paddle/utils/test_macros.h" #ifdef PADDLE_WITH_TESTING #include diff --git a/paddle/fluid/inference/api/onnxruntime_predictor.h b/paddle/fluid/inference/api/onnxruntime_predictor.h index 33c37042aac43..463bf76df1f22 100644 --- a/paddle/fluid/inference/api/onnxruntime_predictor.h +++ b/paddle/fluid/inference/api/onnxruntime_predictor.h @@ -27,7 +27,7 @@ #include "paddle/fluid/inference/api/helper.h" #include "paddle/fluid/inference/api/paddle_inference_api.h" #include "paddle/fluid/platform/device/gpu/gpu_types.h" -#include "paddle/fluid/string/printf.h" +#include "paddle/utils/string/printf.h" #include "paddle2onnx/converter.h" #ifdef PADDLE_WITH_TESTING diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h index cae544ff2c234..dcf17dc4399c2 100644 --- a/paddle/fluid/inference/api/paddle_analysis_config.h +++ b/paddle/fluid/inference/api/paddle_analysis_config.h @@ -111,6 +111,7 @@ struct PD_INFER_DECL XpuConfig { bool conv_autotune_file_writeback{false}; // Fc autotune level. The Optional values are 0-9. Default 0 means no + // autotune. int fc_autotune_level{0}; // Base fc autotune info is read from fc_autotune_file. std::string fc_autotune_file; @@ -253,7 +254,7 @@ struct PD_INFER_DECL AnalysisConfig { void SetModel(const std::string& model_dir) { model_dir_ = model_dir; } /// - /// \brief Set the combined model with two specific pathes for program and + /// \brief Set the combined model with two specific paths for program and /// parameters. /// /// \param prog_file_path model file path of the combined model. @@ -367,7 +368,7 @@ struct PD_INFER_DECL AnalysisConfig { /// void EnableXpu(int l3_size = 0xfffc00, bool l3_locked = false, - bool conv_autotune = true, + bool conv_autotune = false, const std::string& conv_autotune_file = "", const std::string& transformer_encoder_precision = "int16", bool transformer_encoder_adaptive_seqlen = false, @@ -596,12 +597,12 @@ struct PD_INFER_DECL AnalysisConfig { /// \brief Control whether to perform IR graph optimization. 
/// If turned off, the AnalysisConfig will act just like a NativeConfig. /// - /// \param x Whether the ir graph optimization is actived. + /// \param x Whether the ir graph optimization is activated. /// void SwitchIrOptim(int x = true) { enable_ir_optim_ = x; } /// /// \brief A boolean state telling whether the ir graph optimization is - /// actived. + /// activated. /// /// \return bool Whether to use ir graph optimization. /// @@ -810,9 +811,29 @@ struct PD_INFER_DECL AnalysisConfig { /// void Exp_DisableTensorRtOPs(const std::vector& ops); + /// + /// \brief Prevent TensorRtSubgraph running in Paddle-TRT + /// NOTE: just experimental, not an official stable API, easy to be broken. + /// void Exp_DisableTensorRtSubgraph( const std::vector& var_name_not_trt); + /// + /// \brief Specify TensorRT subgraph precision,fp16, int8 or bfp16(TensorRT + /// Version>=9.0) NOTE: just experimental, not an official stable API, easy to + /// be broken. + /// + void Exp_SpecifyTensorRTSubgraphPrecision( + const std::vector& trt_parameters_fp16, + const std::vector& trt_parameters_int8, + const std::vector& trt_parameters_bfp16); + + /// + /// \brief Prevent DynamicShape OPs running in Paddle-TRT + /// NOTE: just experimental, not an official stable API, easy to be broken. + /// + void Exp_DisableTensorRTDynamicShapeOPs(bool trt_forbid_dynamic_op); + /// /// \brief Replace some TensorRT plugins to TensorRT OSS( /// https://github.com/NVIDIA/TensorRT), with which some models's inference @@ -879,10 +900,22 @@ struct PD_INFER_DECL AnalysisConfig { /// int tensorrt_optimization_level() { return trt_optimization_level_; } + /// \brief A boolean state telling whether to use new executor. + /// + /// \return bool whether to use new executor. + /// void EnableNewExecutor(bool x = true) { use_new_executor_ = x; } bool new_executor_enabled() const { return use_new_executor_; } + /// \brief A boolean state telling whether to use new IR. + /// + /// \return bool whether to use new IR. + /// + void EnableNewIR(bool x = true) { use_pir_ = x; } + + bool new_ir_enabled() const { return use_pir_; } + /// /// \brief Control whether to use optimized model to inference. /// @@ -934,7 +967,7 @@ struct PD_INFER_DECL AnalysisConfig { /// /// \param x whether to debug IR graph analysis phase. /// - void SwitchIrDebug(int x = true); + void SwitchIrDebug(int x = true, const std::vector& passes = {}); /// /// \brief Turn on MKLDNN. @@ -1206,6 +1239,30 @@ struct PD_INFER_DECL AnalysisConfig { /// bool cinn_enabled() const; + /// + /// \brief Set the custom passes list . + /// + /// \param passes The custom passes list. + /// \param custom_pass_only Custom pass run mode. The default is false, + /// which means that paddle pass will run after custom pass. + /// + void EnableCustomPasses(const std::vector& passes, + bool custom_pass_only = false); + + /// + /// \brief Set pir Optimization level. + /// \param opt_level The optimization level + /// The optimization Level in range [0,4], Default 2. + /// Higher optimization level allows the predictor to apply more passes. + /// If 0, Only basic pass support. + /// If 1, Additional support for functional pass. + /// If 2, Additional support the fusion logical pass,maybe affect precision + /// and speed. + /// If 3, support layout pass, etc. + /// If 4, add the radicaloptimization, maybe affect precision, etc. + /// + void SetOptimizationLevel(int opt_level); + protected: // Update the config. 
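Read together, the declarations added above (EnableNewIR, EnableCustomPasses, SetOptimizationLevel, and the SwitchIrDebug overload that now accepts a pass list) describe how a PIR-based predictor is configured; the custom-pass demo earlier in this diff exercises most of them. The sketch below simply combines those calls in one place; the element type of the pass-name vectors is taken to be std::string (the flattened diff drops template arguments), and the model paths are placeholders.

// Sketch only: combines the AnalysisConfig entries introduced in this header,
// following the usage shown in custom_pass_demo.cc.
#include <memory>
#include <string>
#include "paddle_inference_api.h"  // NOLINT

std::shared_ptr<paddle_infer::Predictor> BuildPirPredictor(
    const std::string &model_dir) {
  paddle_infer::Config config;
  config.SetModel(model_dir + "/inference.pdmodel",
                  model_dir + "/inference.pdiparams");
  config.EnableUseGpu(100, 0);
  config.EnableNewExecutor(true);  // the demos in this diff enable it together with PIR
  config.EnableNewIR(true);        // switch to the PIR pipeline (use_pir_ = true)
  // Append a user pass; with custom_pass_only = false (the default),
  // the built-in PIR passes (e.g. kPirGpuPasses) still run afterwards.
  config.EnableCustomPasses({"relu_replace_pass"});
  config.SetOptimizationLevel(2);  // default level: functional + fusion passes
  // Optional: IR debug dumps; the new `passes` argument presumably scopes
  // which passes are traced.
  // config.SwitchIrDebug(true, {"conv2d_add_act_fuse_pass"});
  return paddle_infer::CreatePredictor(config);
}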
void Update(); @@ -1213,7 +1270,7 @@ struct PD_INFER_DECL AnalysisConfig { std::string SerializeInfoCache(); protected: - // Model pathes. + // Model paths. std::string model_dir_; mutable std::string prog_file_; mutable std::string params_file_; @@ -1271,8 +1328,14 @@ struct PD_INFER_DECL AnalysisConfig { bool trt_use_varseqlen_{false}; bool trt_with_interleaved_{false}; bool trt_mark_output_{false}; + bool trt_forbid_dynamic_op_{false}; + std::vector trt_output_tensor_names_{}; std::vector trt_exclude_var_names_{}; + std::vector trt_parameters_run_fp16_{}; + std::vector trt_parameters_run_int8_{}; + std::vector trt_parameters_run_bfp16_{}; + std::string tensorrt_transformer_posid_{""}; std::string tensorrt_transformer_maskid_{""}; bool trt_use_dla_{false}; @@ -1425,6 +1488,12 @@ struct PD_INFER_DECL AnalysisConfig { // PrepareProgram(). So we add this flag to control the process. bool apply_optim_{false}; bool skip_load_params_{false}; + + bool use_pir_{false}; + std::vector custom_passes_; + bool custom_pass_only_{false}; + int pm_opt_level_{2}; + std::vector ir_debug_passes_; }; } // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_api.h b/paddle/fluid/inference/api/paddle_api.h index 8c66b66363603..b6931814ab9e7 100644 --- a/paddle/fluid/inference/api/paddle_api.h +++ b/paddle/fluid/inference/api/paddle_api.h @@ -523,6 +523,7 @@ class PD_INFER_DECL InternalUtils { static void SyncStream(paddle_infer::Predictor* pred); static void SyncStream(cudaStream_t stream); + static void SyncStream(hipStream_t stream); template static void CopyFromCpuWithIoStream(paddle_infer::Tensor* t, const T* data, diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 0684064df81e8..f55fab3e71b08 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -528,6 +528,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "delete_dropout_op_pass", "delete_concat_op_pass", "gather_squeeze_pass", + "roformer_relative_pos_fuse_pass", "delete_repeated_ops_pass", "identity_op_clean_pass", "fused_continuous_same_ops_pass", @@ -595,4 +596,39 @@ IpuPassStrategy::IpuPassStrategy() : PassStrategy({}) { passes_.assign({"inference_process_pass"}); } +const std::vector kPirGpuPasses{ + // Functional pass + "map_op_to_another_pass", + "identity_op_clean_pass", + // Operator fusion pass + "silu_fuse_pass", + "conv2d_bn_fuse_pass", + "conv2d_add_act_fuse_pass", + "conv2d_add_fuse_pass", + "embedding_eltwise_layernorm_fuse_pass", + "multihead_matmul_fuse_pass", + "fc_fuse_pass", + "fc_elementwise_layernorm_fuse_pass", + "matmul_scale_fuse_pass", + "matmul_transpose_fuse_pass", + "transpose_flatten_concat_fuse_pass"}; + +const std::vector kPirXpuPasses{// Functional pass + "map_op_to_another_pass", + "identity_op_clean_pass", + // Operator fusion pass + "add_layernorm_xpu_fuse_pass"}; + +const std::vector kPirMkldnnPasses{ + "conv2d_bias_fuse_pass", + "conv2d_transpose_bias_fuse_pass", + "conv3d_bias_fuse_pass", + "batch_norm_act_fuse_pass", + "reshape_transpose_matmul_fuse_pass", + "matmul_elementwise_add_fuse_pass", + "matmul_activation_fuse_pass", + "conv_elementwise_add_mkldnn_fuse_pass"}; + +const std::vector kPirCpuPasses{}; + } // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 2318c88741f28..79ef68c853cfb 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ 
b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -353,4 +353,9 @@ PD_INFER_DECL extern const std::vector kCINNCompilerPasses; PD_INFER_DECL extern const std::vector kGpuLowerPrecisionPasses; PD_INFER_DECL extern const std::vector kTrtLowerPrecisionPasses; +PD_INFER_DECL extern const std::vector kPirGpuPasses; +PD_INFER_DECL extern const std::vector kPirCpuPasses; +PD_INFER_DECL extern const std::vector kPirXpuPasses; +PD_INFER_DECL extern const std::vector kPirMkldnnPasses; + } // namespace paddle diff --git a/paddle/fluid/inference/api/resource_manager.cc b/paddle/fluid/inference/api/resource_manager.cc index b18ca6e1c2a55..c2b26658498bd 100644 --- a/paddle/fluid/inference/api/resource_manager.cc +++ b/paddle/fluid/inference/api/resource_manager.cc @@ -191,7 +191,7 @@ void GPUContextResource::InitGpuEigenDevice() { gpu_eigen_device_ = std::make_unique(eigen_stream_.get()); } -void GPUContextResource::InitDnnHanlde() { +void GPUContextResource::InitDnnHandle() { phi::InitDnnHandle(&dnn_handle_, stream_, place_); } @@ -237,7 +237,7 @@ dnnHandle_t GPUContextResource::GetDnnHandle() const { return dnn_handle_; } std::function GPUContextResource::GetDnnHandleCreator() { return [&]() -> phi::dnnHandle_t { - InitDnnHanlde(); + InitDnnHandle(); return dnn_handle_; }; } @@ -355,7 +355,7 @@ int GPUContextResource::GetGpuMaxThreadsPerBlock() const { return max_threads_per_block_; } -std::array GPUContextResource::GetGpuMaxGridDimSize() const { +std::array GPUContextResource::GetGpuMaxGridDimSize() const { return max_grid_dim_size_; } @@ -367,7 +367,7 @@ ResourceManager& ResourceManager::Instance() { } void ResourceManager::InitCPUResource() { - std::lock_guard lock_gurad(cpu_mutex_); + std::lock_guard lock_guard(cpu_mutex_); if (cpu_resource_ == nullptr) { cpu_resource_ = std::make_unique(); } @@ -382,7 +382,7 @@ CPUContextResource* ResourceManager::GetCPUResource() const { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void* ResourceManager::InitGPUResource(const phi::Place& place, void* stream) { - std::lock_guard lock_gurad(gpu_mutex_); + std::lock_guard lock_guard(gpu_mutex_); if (gpu_resources_.count(stream)) { Increase(stream); return stream; @@ -427,7 +427,7 @@ GPUContextResource* ResourceManager::GetGPUResource(void* stream) const { void ResourceManager::GpuResourceSwitchStream(void* old_stream, void* new_stream) { // NOTE: add lock to support stream rebind in multi-thread - std::lock_guard lock_gurad(gpu_mutex_); + std::lock_guard lock_guard(gpu_mutex_); if (old_stream == new_stream) return; PADDLE_ENFORCE_EQ( gpu_resources_.count(old_stream), diff --git a/paddle/fluid/inference/api/resource_manager.h b/paddle/fluid/inference/api/resource_manager.h index 1f4d4ea420e1b..0ee40644ee5c5 100644 --- a/paddle/fluid/inference/api/resource_manager.h +++ b/paddle/fluid/inference/api/resource_manager.h @@ -81,14 +81,14 @@ class GPUContextResource { int GetGPUMultiProcessors() const; int GetGpuMaxThreadsPerMp() const; int GetGpuMaxThreadsPerBlock() const; - std::array GetGpuMaxGridDimSize() const; + std::array GetGpuMaxGridDimSize() const; private: void InitGPUResource(void* stream); void DestroyGPUResource(); void InitGpuProperties(); void InitGpuEigenDevice(); - void InitDnnHanlde(); + void InitDnnHandle(); void DestroyDnnHandle(); void DestroyBlasHandle(); void InitBlasLtHandle(); @@ -107,7 +107,7 @@ class GPUContextResource { int multi_process_; int max_threads_per_mp_; int max_threads_per_block_; - std::array max_grid_dim_size_; + std::array max_grid_dim_size_; bool 
owned_stream_{true}; gpuStream_t stream_; diff --git a/paddle/fluid/inference/capi/pd_config.cc b/paddle/fluid/inference/capi/pd_config.cc index 5197b8dede192..c2c8036ece7a8 100644 --- a/paddle/fluid/inference/capi/pd_config.cc +++ b/paddle/fluid/inference/capi/pd_config.cc @@ -275,7 +275,7 @@ void PD_EnableDlnne( int max_batch_size, bool use_static_batch, std::string weight_share_mode, - std::unordered_set disable_nodes_by_ouputs, + std::unordered_set disable_nodes_by_outputs, std::map> dlnne_input_shape_dict, bool use_calib_mode, PD_ACPrecision precision_mode) { @@ -287,7 +287,7 @@ void PD_EnableDlnne( max_batch_size, use_static_batch, weight_share_mode, - disable_nodes_by_ouputs, + disable_nodes_by_outputs, dlnne_input_shape_dict, use_calib_mode, precision_mode); diff --git a/paddle/fluid/inference/capi/pd_predictor.cc b/paddle/fluid/inference/capi/pd_predictor.cc index 39575a196e4f9..72f1b6c277153 100644 --- a/paddle/fluid/inference/capi/pd_predictor.cc +++ b/paddle/fluid/inference/capi/pd_predictor.cc @@ -92,7 +92,7 @@ bool PD_PredictorRun(const PD_AnalysisConfig* config, config, paddle::platform::errors::InvalidArgument( "The pointer of analysis configuration shouldn't be nullptr")); - VLOG(3) << "Predoctor: PD_PredictorRun. "; + VLOG(3) << "Predictor: PD_PredictorRun. "; static std::map> predictors; if (!predictors.count(config->config.model_dir())) { diff --git a/paddle/fluid/inference/paddle_inference.map b/paddle/fluid/inference/paddle_inference.map index 01a989cc568bc..ff95870771374 100644 --- a/paddle/fluid/inference/paddle_inference.map +++ b/paddle/fluid/inference/paddle_inference.map @@ -82,6 +82,8 @@ *Pass*; *profile*; *phi*; + *pir*; + *drr*; PD_*; *cinn*; local: diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index f09e5091ae9b1..f9057ab7b0a21 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -181,9 +181,9 @@ class STanhOpConverter : public ActivationOpConverter { STanhOpConverter() { op_type_ = "stanh"; } }; -class ThreasholdedReluOpConverter : public ActivationOpConverter { +class ThresholdedReluOpConverter : public ActivationOpConverter { public: - ThreasholdedReluOpConverter() { op_type_ = "thresholded_relu"; } + ThresholdedReluOpConverter() { op_type_ = "thresholded_relu"; } }; #endif @@ -201,5 +201,5 @@ REGISTER_TRT_OP_CONVERTER(selu, SeluOpConverter); REGISTER_TRT_OP_CONVERTER(softsign, SoftsignOpConverter); REGISTER_TRT_OP_CONVERTER(softplus, SoftplusOpConverter); REGISTER_TRT_OP_CONVERTER(stanh, STanhOpConverter); -REGISTER_TRT_OP_CONVERTER(thresholded_relu, ThreasholdedReluOpConverter); +REGISTER_TRT_OP_CONVERTER(thresholded_relu, ThresholdedReluOpConverter); #endif diff --git a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc index d7699c7c1003c..9f19b0b41096f 100644 --- a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc @@ -36,7 +36,7 @@ class AffineChannelOpConverter : public OpConverter { std::string output_name = op_desc.Output("Out").front(); auto input_tensor = engine_->GetITensor(input_name); - auto idim = input_tensor->getDimensions(); + auto input_dim = input_tensor->getDimensions(); auto* scale_v = scope.FindVar(scale_name); auto* scale_t = scale_v->GetMutable(); @@ -49,17 +49,17 @@ class AffineChannelOpConverter : public OpConverter { 
engine_->GetFp32TrtWeight(bias_name, *bias_t).get().values)); // tensorrt scalend layer only support spatial dims >= 2, - // so nhwc is not availabe (spatial dims == 0) + // so nhwc is not available (spatial dims == 0) const int channel_axis = engine_->with_dynamic_shape(); TensorRTEngine::Weight scale_weights{ nvinfer1::DataType::kFLOAT, static_cast(scale_ptr), - static_cast(idim.d[channel_axis])}; + static_cast(input_dim.d[channel_axis])}; TensorRTEngine::Weight bias_weights{ nvinfer1::DataType::kFLOAT, static_cast(bias_ptr), - static_cast(idim.d[channel_axis])}; + static_cast(input_dim.d[channel_axis])}; TensorRTEngine::Weight power_weights{ nvinfer1::DataType::kFLOAT, nullptr, 0}; diff --git a/paddle/fluid/inference/tensorrt/convert/bitwise_not_op.cc b/paddle/fluid/inference/tensorrt/convert/bitwise_not_op.cc index a944527313a02..63a02d4e393e8 100644 --- a/paddle/fluid/inference/tensorrt/convert/bitwise_not_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/bitwise_not_op.cc @@ -42,7 +42,7 @@ class BitwiseNotConverter : public OpConverter { nvinfer1::Dims input_dims = input_tensor->getDimensions(); // set up a elementwise -1 tensor, can not get the dims info for - // dynamic_shape so just let it broadcaste + // dynamic_shape so just let it broadcast nvinfer1::Dims neg_one_tensor_dims; neg_one_tensor_dims.nbDims = input_dims.nbDims; for (int i = 0; i < input_dims.nbDims; ++i) { diff --git a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc index 1df92f0641040..37a53d31f47b5 100644 --- a/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/conv3d_op.cc @@ -35,7 +35,7 @@ void ConvertConv3d(TensorRTEngine* engine, auto* Y_v = scope.FindVar(filter_var_name); PADDLE_ENFORCE_NOT_NULL( Y_v, - platform::errors::NotFound("Can not find %s presistale var in scope.", + platform::errors::NotFound("Can not find %s presistable var in scope.", filter_var_name)); auto* Y_t = Y_v->GetMutable(); bool enable_int8 = op_desc.HasAttr("enable_int8"); diff --git a/paddle/fluid/inference/tensorrt/convert/cross_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/cross_multihead_matmul_op.cc index 6a1cf1951f9a6..df5665b75b34e 100644 --- a/paddle/fluid/inference/tensorrt/convert/cross_multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/cross_multihead_matmul_op.cc @@ -24,8 +24,9 @@ class CrossMultiheadMatMulOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(3) << "convert a cross_multihead_mamul op to a corresponding tensorrt " - "network structure"; + VLOG(3) + << "convert a cross_multihead_matmul op to a corresponding tensorrt " + "network structure"; bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); if (engine_->precision() == phi::DataType::INT8) { with_fp16 = true; @@ -109,7 +110,7 @@ class CrossMultiheadMatMulOpConverter : public OpConverter { weight_q, bias_q); fc_q_layer->setName( - ("multihead_mamul_fc_q(Output: " + output_name + ")").c_str()); + ("multihead_matmul_fc_q(Output: " + output_name + ")").c_str()); // add shuffle for fc layer auto* reshape_after_fc_q_layer = @@ -211,7 +212,7 @@ class CrossMultiheadMatMulOpConverter : public OpConverter { weight_kv, bias_kv); fc_layer->setName( - ("multihead_mamul_fc(Output: " + output_name + ")").c_str()); + ("multihead_matmul_fc(Output: " + output_name + ")").c_str()); // add shuffle for fc layer auto* 
reshape_after_fc_layer = diff --git a/paddle/fluid/inference/tensorrt/convert/dequantize_linear_op.cc b/paddle/fluid/inference/tensorrt/convert/dequantize_linear_op.cc index 9b88e14fc9efe..662769e7f24ec 100644 --- a/paddle/fluid/inference/tensorrt/convert/dequantize_linear_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/dequantize_linear_op.cc @@ -32,7 +32,7 @@ class DequantizeLinearOpConverter : public OpConverter { // Create constant layer for scale PADDLE_ENFORCE_NOT_NULL( scale_var, - platform::errors::NotFound("Can not find %s presistale var in scope.", + platform::errors::NotFound("Can not find %s presistable var in scope.", op_desc.Input("Scale")[0])); auto* scale_t = scale_var->GetMutable(); int n_scale = scale_t->numel(); diff --git a/paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc index 8b49127cb93db..e5904a1cf7543 100644 --- a/paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/flash_multihead_matmul_op.cc @@ -24,11 +24,12 @@ namespace tensorrt { class FlashMultiheadMatMulOpConverter : public OpConverter { public: - void flash_multihead_mamul_trt(const framework::proto::OpDesc& op, - const framework::Scope& scope, - bool test_mode) { - VLOG(3) << "convert a flash_multihead_mamul op to a corresponding tensorrt " - "network structure\n"; + void flash_multihead_matmul_trt(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) { + VLOG(3) + << "convert a flash_multihead_matmul op to a corresponding tensorrt " + "network structure\n"; bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); if (engine_->precision() == phi::DataType::INT8) { @@ -138,7 +139,7 @@ class FlashMultiheadMatMulOpConverter : public OpConverter { weight, bias); fc_layer->setName( - ("multihead_mamul_fc(Output: " + output_name + ")").c_str()); + ("multihead_matmul_fc(Output: " + output_name + ")").c_str()); // add shuffle for fc layer reshape_before_mha_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_layer->getOutput(0)); @@ -243,10 +244,10 @@ class FlashMultiheadMatMulOpConverter : public OpConverter { layer, "flash_multihead_matmul", {output_name}, test_mode); } - void flash_multihead_mamul(const framework::proto::OpDesc& op, - const framework::Scope& scope, - bool test_mode) { - VLOG(3) << "convert a flash_multihead_mamul op to a " + void flash_multihead_matmul(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) { + VLOG(3) << "convert a flash_multihead_matmul op to a " "MemoryEfficientAttention OP " "network structure\n"; framework::OpDesc op_desc(op, nullptr); @@ -310,7 +311,7 @@ class FlashMultiheadMatMulOpConverter : public OpConverter { hidden_out, weight, bias); - qkv_fc_layers[i]->setName(("multihead_mamul_fc_" + std::to_string(i) + + qkv_fc_layers[i]->setName(("multihead_matmul_fc_" + std::to_string(i) + "_(Output: " + output_name + ")") .c_str()); } else { @@ -334,7 +335,7 @@ class FlashMultiheadMatMulOpConverter : public OpConverter { matrix_operation_x, *weight_reshape_before_mm[i]->getOutput(0), matrix_operation_y); - qkv_fc_layers[i]->setName(("multihead_mamul_matmul_" + + qkv_fc_layers[i]->setName(("multihead_matmul_matmul_" + std::to_string(i) + "_(Output: " + output_name + ")") .c_str()); @@ -499,9 +500,9 @@ class FlashMultiheadMatMulOpConverter : public OpConverter { framework::OpDesc op_desc(op, nullptr); bool use_trt_fma = PADDLE_GET_CONST(bool, 
op_desc.GetAttr("use_trt_fma")); if (use_trt_fma) { - flash_multihead_mamul_trt(op, scope, test_mode); + flash_multihead_matmul_trt(op, scope, test_mode); } else { - flash_multihead_mamul(op, scope, test_mode); + flash_multihead_matmul(op, scope, test_mode); } } }; diff --git a/paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc b/paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc index 5e4dfca1417f8..6ebc1278c277f 100644 --- a/paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc +++ b/paddle/fluid/inference/tensorrt/convert/generic_and_custom_plugin_creater.cc @@ -31,7 +31,7 @@ class CustomPluginCreater : public OpConverter { const framework::Scope &scope, bool test_mode) override { framework::OpDesc op_desc(op, nullptr); - VLOG(3) << "convert " << op_desc.Type() << " op to custom pluign layer"; + VLOG(3) << "convert " << op_desc.Type() << " op to custom plugin layer"; std::string plugin_name; @@ -60,7 +60,7 @@ class CustomPluginCreater : public OpConverter { CHECK(creator); // set attrs - std::vector plugindatas; + std::vector plugin_datas; auto &op_attrs_names = OpMetaInfoHelper::GetAttrs(op_info); auto &attrs = op_desc.GetAttrMap(); @@ -74,7 +74,7 @@ class CustomPluginCreater : public OpConverter { for (auto &attr_name_and_type : op_attrs_names) { auto attr_name = attr_name_and_type.substr(0, attr_name_and_type.find_first_of(":")); - nvinfer1::PluginField plugindata; + nvinfer1::PluginField plugin_data; // NOTE: to avoid string rewrite by iterator, deep copy here std::vector plugin_attr_name(attr_name.length() + 1, 0); @@ -82,47 +82,47 @@ class CustomPluginCreater : public OpConverter { attr_name.length() + 1, "%s", attr_name.c_str()); - plugindata.name = plugin_attr_name.data(); + plugin_data.name = plugin_attr_name.data(); if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::INT) { int_attrs.push_back(PADDLE_GET_CONST(int, attrs.at(attr_name))); - plugindata.data = &int_attrs.back(); - plugindata.type = nvinfer1::PluginFieldType::kINT32; - plugindata.length = 1; + plugin_data.data = &int_attrs.back(); + plugin_data.type = nvinfer1::PluginFieldType::kINT32; + plugin_data.length = 1; } else if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::FLOAT) { float_attrs.push_back(PADDLE_GET_CONST(float, attrs.at(attr_name))); - plugindata.data = &float_attrs.back(); - plugindata.type = nvinfer1::PluginFieldType::kFLOAT32; - plugindata.length = 1; + plugin_data.data = &float_attrs.back(); + plugin_data.type = nvinfer1::PluginFieldType::kFLOAT32; + plugin_data.length = 1; } else if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::BOOLEAN) { int_attrs.push_back(PADDLE_GET_CONST(bool, attrs.at(attr_name))); - plugindata.data = &int_attrs.back(); - plugindata.type = nvinfer1::PluginFieldType::kINT32; - plugindata.length = 1; + plugin_data.data = &int_attrs.back(); + plugin_data.type = nvinfer1::PluginFieldType::kINT32; + plugin_data.length = 1; } else if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::STRING) { string_attrs.push_back( PADDLE_GET_CONST(std::string, attrs.at(attr_name))); - plugindata.data = string_attrs.back().data(); - plugindata.type = nvinfer1::PluginFieldType::kCHAR; - plugindata.length = + plugin_data.data = string_attrs.back().data(); + plugin_data.type = nvinfer1::PluginFieldType::kCHAR; + plugin_data.length = string_attrs.back().size() + 1; // string ends with ‘\0’ } else if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::INTS) { 
ints_attrs.push_back( PADDLE_GET_CONST(std::vector, attrs.at(attr_name))); - plugindata.data = ints_attrs.back().data(); - plugindata.type = nvinfer1::PluginFieldType::kINT32; - plugindata.length = ints_attrs.back().size(); + plugin_data.data = ints_attrs.back().data(); + plugin_data.type = nvinfer1::PluginFieldType::kINT32; + plugin_data.length = ints_attrs.back().size(); } else if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::FLOATS) { floats_attrs.push_back( PADDLE_GET_CONST(std::vector, attrs.at(attr_name))); - plugindata.data = floats_attrs.back().data(); - plugindata.type = nvinfer1::PluginFieldType::kFLOAT32; - plugindata.length = floats_attrs.back().size(); + plugin_data.data = floats_attrs.back().data(); + plugin_data.type = nvinfer1::PluginFieldType::kFLOAT32; + plugin_data.length = floats_attrs.back().size(); } else if (op_desc.GetAttrType(attr_name) == framework::proto::AttrType::BOOLEANS) { auto bools_attr = @@ -130,17 +130,17 @@ class CustomPluginCreater : public OpConverter { std::vector convert_to_ints_attr; for (bool i : bools_attr) convert_to_ints_attr.push_back(i); ints_attrs.push_back(convert_to_ints_attr); - plugindata.data = ints_attrs.back().data(); - plugindata.type = nvinfer1::PluginFieldType::kINT32; - plugindata.length = ints_attrs.back().size(); + plugin_data.data = ints_attrs.back().data(); + plugin_data.type = nvinfer1::PluginFieldType::kINT32; + plugin_data.length = ints_attrs.back().size(); } else { CHECK(false) << "UNKNOWN PluginFieldType."; } - plugindatas.push_back(plugindata); + plugin_datas.push_back(plugin_data); } - nvinfer1::PluginFieldCollection plugin_fc{(int32_t)plugindatas.size(), - plugindatas.data()}; + nvinfer1::PluginFieldCollection plugin_fc{(int32_t)plugin_datas.size(), + plugin_datas.data()}; auto *plugin = creator->createPlugin(op_desc.Type().c_str(), &plugin_fc); CHECK(plugin); @@ -175,7 +175,7 @@ class GenericPluginCreater : public OpConverter { const framework::Scope &scope, bool test_mode) override { framework::OpDesc op_desc(op, nullptr); - VLOG(3) << "convert " << op_desc.Type() << " op to generic pluign layer"; + VLOG(3) << "convert " << op_desc.Type() << " op to generic plugin layer"; CHECK(block_); const framework::BlockDesc block_desc( @@ -259,7 +259,7 @@ class CustomGenericPluginCreater : public OpConverter { bool test_mode) override { framework::OpDesc op_desc(op, nullptr); VLOG(3) << "convert " << op_desc.Type() - << " op to custom generic pluign layer"; + << " op to custom generic plugin layer"; nvinfer1::ILayer *layer = nullptr; std::vector inputs; diff --git a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc index 50fa54bcf90c2..c9335f2270621 100644 --- a/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc @@ -69,12 +69,13 @@ class LayerNormOpConverter : public OpConverter { ("layer_norm Scale: reshape: (Output(" + output_name + ")").c_str()); auto layer = TRT_ENGINE_ADD_LAYER( engine_, Normalization, *X, *Scale_reshape, *Bias_reshape, axisMask); + SupportFP32MixPrecision(output_name, op_desc.Type(), layer); layer->setEpsilon(eps); ReplenishLayerAndOutput(layer, "layer_norm", {output_name}, test_mode); #endif #if IS_TRT_VERSION_LT(8600) // For dynamic shape & trt<8.6, - // the shape of mean and variance will be determine in configuPlugin. + // the shape of mean and variance will be determine in configurePlugin. 
auto* X = engine_->GetITensor(op_desc.Input("X").front()); auto* Bias_v = scope.FindVar(op_desc.Input("Bias").front()); auto* Scale_v = scope.FindVar(op_desc.Input("Scale").front()); diff --git a/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc b/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc index 7cf5dea57d5d4..4f4b09b6173a2 100644 --- a/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/layernorm_shift_partition_op.cc @@ -73,7 +73,7 @@ class LayerNormShiftPartitionOpConverter : public OpConverter { PADDLE_ENFORCE_EQ(bias_weight.get().count, scale_weight.get().count, platform::errors::InvalidArgument( - "The num between bias_weight and cale_weight should " + "The num between bias_weight and scale_weight should " "be equal. (%d vs %d)", bias_weight.get().count, scale_weight.get().count)); diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index 4e6cab4ff907e..73c43d39357c0 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -25,7 +25,7 @@ class MultiheadMatMulOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(3) << "convert a multihead_mamul op to a corresponding tensorrt " + VLOG(3) << "convert a multihead_matmul op to a corresponding tensorrt " "network structure"; framework::OpDesc op_desc(op, nullptr); // Declare inputs @@ -377,7 +377,7 @@ class MultiheadMatMulOpConverter : public OpConverter { reshape_before_multihead_layer->setInput(1, *Concat(reshape_tensor)); reshape_before_multihead_layer->setName( - ("reshape_before_multihead_mamul(Output: " + output_name + ")") + ("reshape_before_multihead_matmul(Output: " + output_name + ")") .c_str()); if (op_desc.HasAttr("fc_out_threshold")) { @@ -625,7 +625,7 @@ class MultiheadMatMulOpConverter : public OpConverter { bias); } fc_layer->setName( - ("multihead_mamul_fc(Output: " + output_name + ")").c_str()); + ("multihead_matmul_fc(Output: " + output_name + ")").c_str()); // add shuffle for CustomQKVToContextPluginDynamic layer auto* reshape_after_fc_layer = @@ -798,7 +798,7 @@ class MultiheadMatMulOpConverter : public OpConverter { reshape_before_fc_layer->setInput( 1, *Concat(reshape_before_fc_shape_tensor)); reshape_before_fc_layer->setName( - ("shuffle_before_multihead_mamul(Output: " + output_name + ")") + ("shuffle_before_multihead_matmul(Output: " + output_name + ")") .c_str()); // add layer fc @@ -834,7 +834,7 @@ class MultiheadMatMulOpConverter : public OpConverter { engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); } fc_layer->setName( - ("multihead_mamul_fc(Output: " + output_name + ")").c_str()); + ("multihead_matmul_fc(Output: " + output_name + ")").c_str()); // no need to add shuffle after fc, just change it in // QkvToContextPluginDynamic diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc index 517f5f1e7efc0..f849fff7ab1f2 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_roformer_op.cc @@ -24,7 +24,7 @@ class MultiheadMatMulRoformerOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& 
op, const framework::Scope& scope, bool test_mode) override { - VLOG(3) << "convert a multihead_mamul_roformer op to a corresponding " + VLOG(3) << "convert a multihead_matmul_roformer op to a corresponding " "tensorrt " "network structure"; framework::OpDesc op_desc(op, nullptr); diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index 3b75a79d9b563..af9b53c4b29e0 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -70,7 +70,7 @@ class OpConverter { 1UL, platform::errors::InvalidArgument( "The input op's Input(\"Y\")." - "size() should equal to 1, but reveceid " + "size() should equal to 1, but received " "Input(\"Y\").size() = %u.", op_desc.Input("Y").size())); int op_type_len = op_desc.Type().size(); @@ -173,13 +173,33 @@ class OpConverter { platform::errors::Unimplemented("no OpConverter for optype [%s]", op_desc.Type())); + std::string all_outpus_name = "(Outputs:"; + std::string all_inpus_name = "(Inputs:"; + for (auto it1 : op_desc.OutputNames()) { + for (auto it2 : op_desc.Output(it1)) { + all_outpus_name += it2; + all_outpus_name += ","; + } + } + all_outpus_name += ")"; + for (auto it1 : op_desc.InputNames()) { + for (auto it2 : op_desc.Input(it1)) { + all_inpus_name += it2; + all_inpus_name += ","; + } + } + + all_inpus_name += ")"; + VLOG(1) << op_desc.Type() << all_inpus_name << all_outpus_name + << "are to be converted to TensorRT layer"; + it->SetEngine(engine); engine->SetScope(&scope); it->SetBlockDesc(block); (*it)(op, scope, test_mode); size_t output_num = op_desc.OutputNames().size(); - // only one out settensordynamicRange + // only one out SetTensorDynamicRange if (op_desc.HasAttr("out_threshold")) { float out_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold")); @@ -197,12 +217,13 @@ class OpConverter { "\"Out\" or \"Y\".", op_desc.Type())); } + auto* output_itensor = engine->GetITensor(output_name); engine->SetTensorDynamicRange(output_itensor, out_scale); VLOG(1) << "Set out scale = " << out_scale << " for tensor " << output_name << "."; } - // outs settensordynamicRange + // outs SetTensorDynamicRange for (size_t i = 0; i < output_num; ++i) { if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) { float out_scale = PADDLE_GET_CONST( @@ -245,12 +266,14 @@ class OpConverter { } } - // Convert a fluid block to tensorrt network, NOTE it just convert operators, - // the INetwork's inputs and outputs should specified in some other modules. + // Convert a fluid block to tensorrt network, NOTE it just convert + // operators, the INetwork's inputs and outputs should specified in some + // other modules. void ConvertBlock(const framework::proto::BlockDesc& block, const std::unordered_set& parameters, const framework::Scope& scope, TensorRTEngine* engine) { + VLOG(1) << "Convert a fluid block to tensorrt network"; std::unique_lock lk(mut_); for (int i = 0; i < block.ops_size(); i++) { const auto& op = block.ops(i); @@ -787,6 +810,9 @@ class OpConverter { VLOG(3) << output_tensor_names[i] << "'s dimension :[" << string::join_strings(tmp_vec, ',') << "]"; + VLOG(1) << "Paddle-TRT inferred " << output_tensor_names[i] + << "'s dimension is :[" << string::join_strings(tmp_vec, ',') + << "]"; // The following check may cause errors in CI, but is necessary in the // latest version. 
// PADDLE_ENFORCE_GE( diff --git a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc index 529175c7de81a..0ec1336f0e2d1 100644 --- a/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/preln_emb_eltwise_layernorm.cc @@ -103,7 +103,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { slice_stride_dims); // unuseful slice_start_dims slice_layer->setInput(1, *start_tensor); slice_layer->setInput(2, *size_tensor); - slice_layer->setName(("Embeltwise_slice_layer (Output: slice_max_seqlen " + + slice_layer->setName(("EmbEltwise_slice_layer (Output: slice_max_seqlen " + op_desc.Output("Out")[0] + ")") .c_str()); engine_->SetTensorDynamicRange(slice_layer->getOutput(0), 1.0f); @@ -114,7 +114,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { shape_dim.nbDims = 1; shape_dim.d[0] = -1; reshape_layer->setReshapeDimensions(shape_dim); - reshape_layer->setName(("Embeltwise_reshape_layer (Output: max_seqlen " + + reshape_layer->setName(("EmbEltwise_reshape_layer (Output: max_seqlen " + op_desc.Output("Out")[0] + ")") .c_str()); engine_->SetTensorDynamicRange(reshape_layer->getOutput(0), 1.0f); diff --git a/paddle/fluid/inference/tensorrt/convert/qk_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/qk_multihead_matmul_op.cc index 4a24e7425068f..e8ed4af9cddf7 100644 --- a/paddle/fluid/inference/tensorrt/convert/qk_multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/qk_multihead_matmul_op.cc @@ -23,7 +23,7 @@ class QkMultiheadMatMulOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - VLOG(3) << "convert a qk_multihead_mamul op to a corresponding tensorrt " + VLOG(3) << "convert a qk_multihead_matmul op to a corresponding tensorrt " "network structure"; framework::OpDesc op_desc(op, nullptr); @@ -142,7 +142,7 @@ class QkMultiheadMatMulOpConverter : public OpConverter { *bias_qk_tensor, elementwise_operation); merge_qk_element_layer->setName( - ("multihead_mamul_fc_qk(Output: " + output_name + ")").c_str()); + ("multihead_matmul_fc_qk(Output: " + output_name + ")").c_str()); auto* reshape_after_fc_qk_layer = TRT_ENGINE_ADD_LAYER( engine_, Shuffle, *merge_qk_element_layer->getOutput(0)); @@ -232,7 +232,7 @@ class QkMultiheadMatMulOpConverter : public OpConverter { *bias_v_tensor, elementwise_operation); merge_v_element_layer->setName( - ("multihead_mamul_fc_v(Output: " + output_name + ")").c_str()); + ("multihead_matmul_fc_v(Output: " + output_name + ")").c_str()); // add shuffle for fc layer auto* reshape_after_fc_v_layer = TRT_ENGINE_ADD_LAYER( diff --git a/paddle/fluid/inference/tensorrt/convert/quantize_linear_op.cc b/paddle/fluid/inference/tensorrt/convert/quantize_linear_op.cc index b37a8f327e154..74a8f56ea6c20 100644 --- a/paddle/fluid/inference/tensorrt/convert/quantize_linear_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/quantize_linear_op.cc @@ -33,7 +33,7 @@ class QuantizeLinearOpConverter : public OpConverter { // Create constant layer for scale PADDLE_ENFORCE_NOT_NULL( scale_var, - platform::errors::NotFound("Can not find %s presistale var in scope.", + platform::errors::NotFound("Can not find %s presistable var in scope.", op_desc.Input("Scale")[0])); auto* scale_t = scale_var->GetMutable(); int n_scale = scale_t->numel(); diff --git 
a/paddle/fluid/inference/tensorrt/convert/range_op.cc b/paddle/fluid/inference/tensorrt/convert/range_op.cc index b44d9d588744a..073b51b8c0734 100644 --- a/paddle/fluid/inference/tensorrt/convert/range_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/range_op.cc @@ -35,15 +35,15 @@ class RangeOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; auto zero_tensor = Add1DConstantLayer(0, output_name + "_zero_tensor_"); - auto fquotient_tensor = FloorDiv(Sub(start, end), step); + auto f_quotient_tensor = FloorDiv(Sub(start, end), step); if (start->getType() == nvinfer1::DataType::kFLOAT) { auto* cast_int32_layer = - TRT_ENGINE_ADD_LAYER(engine_, Identity, *fquotient_tensor); + TRT_ENGINE_ADD_LAYER(engine_, Identity, *f_quotient_tensor); cast_int32_layer->setOutputType(0, nvinfer1::DataType::kINT32); cast_int32_layer->getOutput(0)->setType(nvinfer1::DataType::kINT32); quotient_tensor = cast_int32_layer->getOutput(0); } else { - quotient_tensor = fquotient_tensor; + quotient_tensor = f_quotient_tensor; } auto number_tensor = Max(Sub(zero_tensor, quotient_tensor), zero_tensor); auto* start1 = engine_->GetITensor(op_desc.Input("Start")[0]); diff --git a/paddle/fluid/inference/tensorrt/convert/reshape_op.cc b/paddle/fluid/inference/tensorrt/convert/reshape_op.cc index c31cf1b012a49..c1f226626742f 100644 --- a/paddle/fluid/inference/tensorrt/convert/reshape_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/reshape_op.cc @@ -67,7 +67,7 @@ class ReshapeOpConverter : public OpConverter { layer->getOutput(0)->getDimensions().nbDims, 0, platform::errors::InvalidArgument( - "Errors occures in Paddle-TRT reshape2 op, try to use C++ Api " + "Errors occurs in Paddle-TRT reshape2 op, try to use C++ Api " "config.Exp_DisableTensorRtOPs({\"reshape2\"})\n; or Python Api " "config.exp_disable_tensorrt_ops([\"reshape2\"]) to forbid " "reshape2 op into " diff --git a/paddle/fluid/inference/tensorrt/convert/set_value_op.cc b/paddle/fluid/inference/tensorrt/convert/set_value_op.cc index 1c734d791cdde..29f95a3554fc4 100644 --- a/paddle/fluid/inference/tensorrt/convert/set_value_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/set_value_op.cc @@ -25,7 +25,7 @@ limitations under the License. 
*/ PADDLE_ENFORCE_EQ(vec_##attr_name__.size(), \ 1UL, \ platform::errors::InvalidArgument( \ - "attr axes/starst/ends/steps 's size in " \ + "attr axes/starts/ends/steps 's size in " \ "set_value must be one, but got %d", \ vec_##attr_name__.size())); \ } \ @@ -151,7 +151,7 @@ class SetValueConverter : public OpConverter { platform::errors::InvalidArgument( "ValueTensor‘s rank not equal to Input's rank, " "you should try use C++ API " - "config.exp_disable_tensorrt_ops({\"%s\"}) to forbind this op " + "config.exp_disable_tensorrt_ops({\"%s\"}) to forbid this op " "enter into TRT, " "please find the %s's real name from .pdmodel or shape.txt", output_name, diff --git a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc index 15ef380253949..ab70ebb6ccd81 100644 --- a/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/skip_layernorm.cc @@ -67,17 +67,19 @@ class SkipLayerNormOpConverter : public OpConverter { if ((x_rank == 2 && y_rank == 4) || (y_rank == 2 && x_rank == 4)) { if (x_rank == 2 && y_rank == 4) { - auto* reshape_before_skiplayn = + auto* reshape_before_skip_layer_n = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input1); std::vector reshape_before_tensor; reshape_before_tensor.push_back(GetEleTensorOfShape(Shape(input1), 0)); reshape_before_tensor.push_back(GetEleTensorOfShape(Shape(input1), 1)); reshape_before_tensor.push_back(Add1DConstantLayer(1)); reshape_before_tensor.push_back(Add1DConstantLayer(1)); - reshape_before_skiplayn->setInput(1, *Concat(reshape_before_tensor)); - reshape_before_skiplayn->setName( - ("reshape_before_skiplayn(Output: " + output_name + ")").c_str()); - input1 = reshape_before_skiplayn->getOutput(0); + reshape_before_skip_layer_n->setInput(1, + *Concat(reshape_before_tensor)); + reshape_before_skip_layer_n->setName( + ("reshape_before_skip_layer_n(Output: " + output_name + ")") + .c_str()); + input1 = reshape_before_skip_layer_n->getOutput(0); if (enable_int8) { if (op_desc.HasAttr("X")) { @@ -85,17 +87,19 @@ class SkipLayerNormOpConverter : public OpConverter { } } } else { - auto* reshape_before_skiplayn = + auto* reshape_before_skip_layer_n = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input2); std::vector reshape_before_tensor; reshape_before_tensor.push_back(GetEleTensorOfShape(Shape(input2), 0)); reshape_before_tensor.push_back(GetEleTensorOfShape(Shape(input2), 1)); reshape_before_tensor.push_back(Add1DConstantLayer(1)); reshape_before_tensor.push_back(Add1DConstantLayer(1)); - reshape_before_skiplayn->setInput(1, *Concat(reshape_before_tensor)); - reshape_before_skiplayn->setName( - ("reshape_before_skiplayn(Output: " + output_name + ")").c_str()); - input2 = reshape_before_skiplayn->getOutput(0); + reshape_before_skip_layer_n->setInput(1, + *Concat(reshape_before_tensor)); + reshape_before_skip_layer_n->setName( + ("reshape_before_skip_layer_n(Output: " + output_name + ")") + .c_str()); + input2 = reshape_before_skip_layer_n->getOutput(0); if (enable_int8) { if (op_desc.HasAttr("Y")) { diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index 4a2d38d5e0736..0e2382a2d3fa6 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -20,7 +20,7 @@ class SliceOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { - // 
This OP is implemented by trt dynamic shpae plugin. + // This OP is implemented by trt dynamic shape plugin. // Dynamic shape plugin requires TRT version greater than 6.0. VLOG(4) << "convert slice op to tensorrt layer"; framework::OpDesc op_desc(op, nullptr); diff --git a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc index 921402a9be5d2..483cd0711ffc6 100644 --- a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc @@ -58,7 +58,7 @@ class SoftMaxOpConverter : public OpConverter { uint32_t axes = std::max(0, input_dims - 3); // TODO(cryoco): Poor workaround. Fix padded dims problem when TRT layers // support Nd. - // Tips: Dynammic shape alreay fixes. + // Tips: Dynamic shape already fixes. int padded_dims = 0; int explicit_batch = 0; if (engine_->with_dynamic_shape()) explicit_batch = 1; diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc index bae9cccde6fa7..c143eb00d2797 100644 --- a/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/sparse_fc_op.cc @@ -116,7 +116,7 @@ class SparseFcOpConverter : public OpConverter { PADDLE_ENFORCE_NOT_NULL( Y_v, platform::errors::NotFound( - "Can not find %s presistale var of sparse_fc in scope.", w_name)); + "Can not find %s presistable var of sparse_fc in scope.", w_name)); auto* Y_t = Y_v->GetMutable(); int x_num_col_dims = op_desc.HasAttr("x_num_col_dims") diff --git a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc index 74198b3066a88..a0736522e5b14 100644 --- a/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/sparse_multihead_matmul_op.cc @@ -366,7 +366,7 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { } reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); reshape_before_fc_layer->setName( - ("shuffle_before_sparse_multihead_mamul(Output: " + output_name + + ("shuffle_before_sparse_multihead_matmul(Output: " + output_name + ")") .c_str()); @@ -403,7 +403,8 @@ class SparseMultiheadMatMulOpConverter : public OpConverter { engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); } fc_layer->setName( - ("sparse_multihead_mamul_fc(Output: " + output_name + ")").c_str()); + ("sparse_multihead_matmul_fc(Output: " + output_name + ")") + .c_str()); // no need to add shuffle after fc, just change it in // QkvToContextPluginDynamic diff --git a/paddle/fluid/inference/tensorrt/convert/tile_op.cc b/paddle/fluid/inference/tensorrt/convert/tile_op.cc index ffdc71e3af675..c02fe619aa30d 100644 --- a/paddle/fluid/inference/tensorrt/convert/tile_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/tile_op.cc @@ -35,12 +35,6 @@ class TileOpConverter : public OpConverter { auto output_name = op_desc.Output("Out")[0]; if (engine_->with_dynamic_shape()) { - std::vector start(rank, 0); - std::vector stride(rank, 1); - auto start_tensor = - Add1DConstantLayer(start, output_name + "start_tensor"); - auto stride_tensor = - Add1DConstantLayer(stride, output_name + "stride_tensor"); auto input_shape_tensor = Shape(input); nvinfer1::ITensor* repeat_tensor = nullptr; @@ -76,9 +70,26 @@ class TileOpConverter : public OpConverter { itensors.push_back(one_rank_tensor); itensors.push_back(repeat_tensor); repeat_expand_tensor = 
Concat(itensors); + } + if (rank < repeat_rank) { + auto* one_rank_tensor = + Add1DConstantLayer(std::vector(repeat_rank - rank, 1)); + std::vector itensors; + itensors.push_back(one_rank_tensor); + itensors.push_back(input_shape_tensor); + input_shape_tensor = Concat(itensors); + // need reshape input to more dims. + input = Reshape(input, input_shape_tensor, "reshape_input_befor_slice"); + repeat_expand_tensor = repeat_tensor; } else { repeat_expand_tensor = repeat_tensor; } + std::vector start(std::max(rank, repeat_rank), 0); + std::vector stride(std::max(rank, repeat_rank), 1); + auto start_tensor = + Add1DConstantLayer(start, output_name + "start_tensor"); + auto stride_tensor = + Add1DConstantLayer(stride, output_name + "stride_tensor"); auto output_shape_tensor = Prod(input_shape_tensor, repeat_expand_tensor); auto layer = TRT_ENGINE_ADD_LAYER(engine_, Slice, diff --git a/paddle/fluid/inference/tensorrt/convert/trans_layernorm_op.cc b/paddle/fluid/inference/tensorrt/convert/trans_layernorm_op.cc index dc257beb14683..a5db8ed88c4c0 100644 --- a/paddle/fluid/inference/tensorrt/convert/trans_layernorm_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/trans_layernorm_op.cc @@ -53,7 +53,7 @@ class TransLayerNormOpConverter : public OpConverter { nvinfer1::ILayer* layernorm_layer = nullptr; if (engine_->with_dynamic_shape()) { // For dynamic shape, - // the shape of mean and variance will be determine in configuPlugin. + // the shape of mean and variance will be determine in configurePlugin. std::vector mean_shape{1}; std::vector variance_shape{1}; bool with_fp16 = diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index 8901d0a43fd41..347f6f500c7c8 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -247,7 +247,7 @@ class TRTConvertValidation { std::unique_ptr op_desc_; const std::unordered_set& parameters_; framework::Scope& scope_; - // The ITensor of trt does not cotain the batch size, + // The ITensor of trt does not contain the batch size, // bug, in most cases, we need to set batch size for // fluid's tensor shape. This variable indicates // whether to add batch size to tensor shape of fluid. 
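The tile_op converter change above first aligns ranks before doing any shape math: whichever of the input shape and the repeat vector is shorter has 1s prepended (and the input is reshaped to the padded shape), the start/stride vectors are sized to max(rank, repeat_rank), and the output shape fed to the Slice layer is the elementwise product of the two aligned shape tensors. A minimal standalone sketch of that shape arithmetic on plain vectors, not Paddle or TensorRT code; the helper names are illustrative only:

#include <algorithm>
#include <cstdint>
#include <vector>

// Prepend 1s until `shape` reaches `target_rank`, mirroring the Concat of a
// ones tensor with input_shape_tensor / repeat_tensor in the converter.
std::vector<int64_t> PadWithLeadingOnes(std::vector<int64_t> shape,
                                        size_t target_rank) {
  shape.insert(shape.begin(), target_rank - shape.size(), 1);
  return shape;
}

// Output extent i is input_extent[i] * repeat[i] after both are rank-aligned,
// i.e. the Prod(input_shape_tensor, repeat_expand_tensor) step of the converter.
std::vector<int64_t> TiledOutputShape(std::vector<int64_t> input_shape,
                                      std::vector<int64_t> repeat_times) {
  const size_t rank = std::max(input_shape.size(), repeat_times.size());
  input_shape = PadWithLeadingOnes(std::move(input_shape), rank);
  repeat_times = PadWithLeadingOnes(std::move(repeat_times), rank);
  std::vector<int64_t> out(rank);
  for (size_t i = 0; i < rank; ++i) {
    out[i] = input_shape[i] * repeat_times[i];
  }
  return out;
}

For example, an input of shape [2, 3] tiled with repeat_times [4, 1, 1] is first viewed as [1, 2, 3], giving an output shape of [4, 2, 3].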
diff --git a/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc index ed5f57165d710..942eecc6e0fe6 100644 --- a/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc +++ b/paddle/fluid/inference/tensorrt/dynamic_shape_infermeta.cc @@ -72,7 +72,7 @@ class ExprWrapper { } friend ExprWrapper operator+(int a_value, const ExprWrapper& b) { - return a_value + b; + return b + a_value; } friend ExprWrapper operator-(const ExprWrapper& a, const ExprWrapper& b) { @@ -259,7 +259,7 @@ inline const nvinfer1::IDimensionExpr* CalcOutputSize( return output_size; } -nvinfer1::DimsExprs UnflodInferMeta( +nvinfer1::DimsExprs UnfoldInferMeta( int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, @@ -879,7 +879,7 @@ nvinfer1::DimsExprs SolveInferMeta( PD_REGISTER_DYNAMIC_INFER_META_FN(gather_nd, GatherNdInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(yolo_box, YoloBoxInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(instance_norm, InstanceNormInferMeta); -PD_REGISTER_DYNAMIC_INFER_META_FN(unfold, UnflodInferMeta); +PD_REGISTER_DYNAMIC_INFER_META_FN(unfold, UnfoldInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(scatter_nd_add, ScatterNdAddInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(inverse, UnchangedInferMeta); PD_REGISTER_DYNAMIC_INFER_META_FN(moe, MoeInferMeta); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 6bc369de6c89c..2a14702b59d81 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -52,7 +52,7 @@ void TensorRTEngine::Weight::SetDataType(phi::DataType type) { #endif default: paddle::platform::errors::InvalidArgument( - "Paddle-TRT loads weighths failed, found not supported data type %s.", + "Paddle-TRT loads weights failed, found not supported data type %s.", type); break; } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index bb56dfe4d6f9b..e870c5b43a800 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -34,6 +34,43 @@ namespace paddle { namespace inference { namespace tensorrt { +// Check if it is a dynamic shape. If it is a dynamic shape, return true; +// otherwise, return false +bool IsDynamicShapeOp(const framework::OpDesc& desc) { + VLOG(3) << "forbid_dynamic_op_enter_into_trt is open"; + auto* block = desc.Block(); + auto inputs = desc.Inputs(); + for (auto iter : inputs) { + for (auto var_name : iter.second) { + if (block) { + auto* var_desc = block->FindVar(var_name); + const auto shape = var_desc->GetShape(); + for (auto ele : shape) { + if (ele < 0) { + return true; + } + } + } + } + } + + auto outputs = desc.Outputs(); + for (auto iter : outputs) { + for (auto var_name : iter.second) { + if (block) { + auto* var_desc = block->FindVar(var_name); + const auto shape = var_desc->GetShape(); + for (auto ele : shape) { + if (ele < 0) { + return true; + } + } + } + } + } + return false; +} + // Just tell by the op_types. 
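The IsDynamicShapeOp helper added to op_teller.cc above decides whether an op touches dynamic shapes purely from the shapes recorded in the surrounding block: if any input or output variable has a negative extent (the placeholder for a dimension unknown at build time), the op is treated as dynamic, and with forbid_dynamic_op_enter_into_trt set the tellers reject it. The core test, stripped of the OpDesc/BlockDesc plumbing, as a self-contained sketch (the function name below is illustrative, not part of the patch):

#include <cstdint>
#include <vector>

// True if any recorded extent is negative, i.e. unknown until runtime.
bool HasDynamicDim(const std::vector<std::vector<int64_t>>& var_shapes) {
  for (const auto& shape : var_shapes) {
    for (int64_t extent : shape) {
      if (extent < 0) return true;
    }
  }
  return false;
}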
struct SimpleOpTypeSetTeller : public Teller { SimpleOpTypeSetTeller() { // NOLINT @@ -89,6 +126,7 @@ struct SimpleOpTypeSetTeller : public Teller { bool operator()(const framework::OpDesc& desc, bool use_no_calib_int8 = false, bool with_dynamic_shape = false, + bool forbid_dynamic_op_enter_into_trt = false, bool use_explicit_quantization = false) override { const std::string op_type = desc.Type(); @@ -102,6 +140,9 @@ struct SimpleOpTypeSetTeller : public Teller { if (feed_fetch_set.find(op_type) != feed_fetch_set.end()) { return false; } + if (forbid_dynamic_op_enter_into_trt && IsDynamicShapeOp(desc)) { + return false; + } // do not support the op which is labeled the `skip_quant` if ((desc.HasAttr("namescope") && @@ -1460,7 +1501,7 @@ struct SimpleOpTypeSetTeller : public Teller { } if (desc.Output("Out").size() != 1) { VLOG(3) << "The input op's Output(\"Out\").size() " - "should equal to 1, but reveceid Output(\"Out\").size() = " + "should equal to 1, but received Output(\"Out\").size() = " << desc.Output("Out").size() << "."; return false; } @@ -2080,20 +2121,21 @@ struct SimpleOpTypeSetTeller : public Teller { auto inputs = desc.Inputs(); bool has_bias_qk = (inputs.find("BiasQK") == inputs.end()) ? false : true; if (has_bias_qk) { - auto* biasqk_desc = + auto* bias_qk_desc = block->FindVarRecursive(desc.Input("BiasQK").front()); - const auto biasqk_shape = biasqk_desc->GetShape(); + const auto bias_qk_shape = bias_qk_desc->GetShape(); // The BiasQK's shape requires to be // [batch, 1, 1, length] or [batch, head, length, length]. - bool has_same_shape = head_number == biasqk_shape[1] && - input_shape[1] == biasqk_shape[2] && - input_shape[1] == biasqk_shape[3]; - bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 && - input_shape[1] == biasqk_shape[3]; - is_broadcastable = - is_broadcastable || (biasqk_shape[0] == 1 && biasqk_shape[1] == 1 && - input_shape[1] == biasqk_shape[2] && - input_shape[1] == biasqk_shape[3]); + bool has_same_shape = head_number == bias_qk_shape[1] && + input_shape[1] == bias_qk_shape[2] && + input_shape[1] == bias_qk_shape[3]; + bool is_broadcastable = bias_qk_shape[1] == 1 && + bias_qk_shape[2] == 1 && + input_shape[1] == bias_qk_shape[3]; + is_broadcastable = is_broadcastable || + (bias_qk_shape[0] == 1 && bias_qk_shape[1] == 1 && + input_shape[1] == bias_qk_shape[2] && + input_shape[1] == bias_qk_shape[3]); if (!(has_same_shape || is_broadcastable)) { VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0] << ", 1, 1, " << input_shape[1] << "] " @@ -2101,8 +2143,9 @@ struct SimpleOpTypeSetTeller : public Teller { << input_shape[1] << ", " << input_shape[1] << "] " << "or [" << input_shape[0] << "/1, " << 1 << ", " << input_shape[1] << ", " << input_shape[1] << "] " - << "but got [" << biasqk_shape[0] << ", " << biasqk_shape[1] - << ", " << biasqk_shape[2] << ", " << biasqk_shape[3] << "]."; + << "but got [" << bias_qk_shape[0] << ", " << bias_qk_shape[1] + << ", " << bias_qk_shape[2] << ", " << bias_qk_shape[3] + << "]."; return false; } } else { @@ -2140,23 +2183,24 @@ struct SimpleOpTypeSetTeller : public Teller { auto inputs = desc.Inputs(); bool has_bias_qk = (inputs.find("BiasQK") == inputs.end()) ? 
false : true; if (has_bias_qk) { - auto* biasqk_desc = + auto* bias_qk_desc = block->FindVarRecursive(desc.Input("BiasQK").front()); - const auto biasqk_shape = biasqk_desc->GetShape(); + const auto bias_qk_shape = bias_qk_desc->GetShape(); // The BiasQK's shape requires to be // [batch, 1, 1, length] or [batch, head, length, length]. - bool has_same_shape = head_number == biasqk_shape[1] && - input_shape[1] == biasqk_shape[2] && - input_shape[1] == biasqk_shape[3]; - bool is_broadcastable = biasqk_shape[1] == 1 && biasqk_shape[2] == 1 && - input_shape[1] == biasqk_shape[3]; + bool has_same_shape = head_number == bias_qk_shape[1] && + input_shape[1] == bias_qk_shape[2] && + input_shape[1] == bias_qk_shape[3]; + bool is_broadcastable = bias_qk_shape[1] == 1 && + bias_qk_shape[2] == 1 && + input_shape[1] == bias_qk_shape[3]; if (!(has_same_shape || is_broadcastable)) { VLOG(3) << "The BiasQK's shape is invalid, expect [" << input_shape[0] << ", 1, 1, " << input_shape[1] << "] or [" << input_shape[0] << ", " << head_number << ", " << input_shape[1] << ", " - << input_shape[1] << "] but [" << biasqk_shape[0] << ", " - << biasqk_shape[1] << ", " << biasqk_shape[2] << ", " - << biasqk_shape[3] << "]."; + << input_shape[1] << "] but [" << bias_qk_shape[0] << ", " + << bias_qk_shape[1] << ", " << bias_qk_shape[2] << ", " + << bias_qk_shape[3] << "]."; return false; } } else { @@ -2237,6 +2281,11 @@ struct SimpleOpTypeSetTeller : public Teller { auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVarRecursive(x_var_name); const auto x_shape = x_var_desc->GetShape(); + + auto dtype = x_var_desc->GetDataType(); + if (dtype != framework::proto::VarType::FP32) { + return false; + } if (!with_dynamic_shape && (x_shape.size() == 1 || x_shape.empty())) { VLOG(3) << op_type << " op does not support input's dim is 1 or 0 in tensorrt " @@ -3197,8 +3246,10 @@ struct GenericPluginTeller : public Teller { bool operator()(const framework::OpDesc& desc, bool use_no_calib_int8 = false, bool with_dynamic_shape = false, + bool forbid_dynamic_op_enter_into_trt = false, bool use_explicit_quantization = false) override { const std::string op_type = desc.Type(); + // only consider dynamic_shape mode if (!with_dynamic_shape) { return false; @@ -3256,6 +3307,9 @@ struct GenericPluginTeller : public Teller { VLOG(3) << op_type << " has no DynamicMetaFn."; return false; } + if (forbid_dynamic_op_enter_into_trt && IsDynamicShapeOp(desc)) { + return false; + } return true; } } @@ -3267,6 +3321,7 @@ struct CustomPluginTeller : public Teller { bool operator()(const framework::OpDesc& desc, bool use_no_calib_int8 = false, bool with_dynamic_shape = false, + bool forbid_dynamic_op_enter_into_trt = false, bool use_explicit_quantization = false) override { const std::string op_type = desc.Type(); std::string expect_plugin_name; @@ -3285,6 +3340,9 @@ struct CustomPluginTeller : public Teller { return true; } return false; + if (forbid_dynamic_op_enter_into_trt && IsDynamicShapeOp(desc)) { + return false; + } } }; @@ -3293,8 +3351,10 @@ struct CustomGenericPluginTeller : public Teller { bool operator()(const framework::OpDesc& desc, bool use_no_calib_int8 = false, bool with_dynamic_shape = false, + bool forbid_dynamic_op_enter_into_trt = false, bool use_explicit_quantization = false) override { const std::string op_type = desc.Type(); + auto& op_meta_info_map = OpMetaInfoMap::Instance(); const auto& meta_info_map = op_meta_info_map.GetMap(); if (meta_info_map.count(op_type) > 0) { @@ -3319,15 +3379,20 @@ struct 
CustomGenericPluginTeller : public Teller { } VLOG(3) << op_type << " has no meta info"; return false; + if (forbid_dynamic_op_enter_into_trt && IsDynamicShapeOp(desc)) { + return false; + } } }; bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, bool with_dynamic_shape, + bool forbid_dynamic_op_enter_into_trt, bool use_explicit_quantization) { const std::string op_type = node->Op()->Type(); const framework::OpDesc desc = *node->Op(); + // do not support the op which is labeled the `skip_quant` if ((desc.HasAttr("namescope") && PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) == @@ -3338,6 +3403,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, if ((*default_teller)(desc, use_no_calib_int8, with_dynamic_shape, + forbid_dynamic_op_enter_into_trt, use_explicit_quantization)) { SetOpConverterType(node->Op(), OpConverterType::Default); return true; @@ -3346,6 +3412,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, if ((*generic_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape, + forbid_dynamic_op_enter_into_trt, use_explicit_quantization)) { SetOpConverterType(node->Op(), OpConverterType::GenericPluginCreater); return true; @@ -3354,6 +3421,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, if ((*custom_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape, + forbid_dynamic_op_enter_into_trt, use_explicit_quantization)) { SetOpConverterType(node->Op(), OpConverterType::CustomPluginCreater); return true; @@ -3362,6 +3430,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, if ((*custom_generic_plugin_teller)(desc, use_no_calib_int8, with_dynamic_shape, + forbid_dynamic_op_enter_into_trt, use_explicit_quantization)) { SetOpConverterType(node->Op(), OpConverterType::CustomGenericPluginCreater); return true; diff --git a/paddle/fluid/inference/tensorrt/op_teller.h b/paddle/fluid/inference/tensorrt/op_teller.h index 69a9061ebdb97..f955396b9ac11 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.h +++ b/paddle/fluid/inference/tensorrt/op_teller.h @@ -34,13 +34,14 @@ namespace tensorrt { /* * Single Op teller definition. - * One can override this and define a more complex tell logic, considerring more + * One can override this and define a more complex tell logic, considering more * issues such as op_desc. 
*/ struct Teller { virtual bool operator()(const framework::OpDesc& desc, bool use_no_calib_int8 = false, bool with_dynamic_shape = false, + bool forbid_dynamic_op_enter_into_trt = false, bool use_explicit_quantization = false) = 0; virtual ~Teller() = default; @@ -77,6 +78,7 @@ class OpTeller { bool Tell(const framework::ir::Node* node, bool use_no_calib_int8 = false, bool with_dynamic_shape = false, + bool forbid_dynamic_op_enter_into_trt = false, bool use_explicit_quantization = false); std::unique_ptr& GetDefaultTeller() { return tellers_.at(0); } diff --git a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu index 76d6f1c3fac94..00e0e2e0441e2 100644 --- a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.cu @@ -279,7 +279,7 @@ void AnchorGeneratorPlugin::configurePlugin( const bool* input_is_broadcast, const bool* output_is_broadcast, nvinfer1::PluginFormat float_format, - int max_batct_size) TRT_NOEXCEPT {} + int max_batch_size) TRT_NOEXCEPT {} nvinfer1::IPluginV2Ext* AnchorGeneratorPlugin::clone() const TRT_NOEXCEPT { auto plugin = new AnchorGeneratorPlugin(data_type_, diff --git a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h index 41766db5f0314..72f11c76767eb 100644 --- a/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/anchor_generator_op_plugin.h @@ -84,7 +84,7 @@ class AnchorGeneratorPlugin : public nvinfer1::IPluginV2Ext { const bool* input_is_broadcast, const bool* output_is_broadcast, nvinfer1::PluginFormat float_format, - int max_batct_size) TRT_NOEXCEPT override; + int max_batch_size) TRT_NOEXCEPT override; nvinfer1::IPluginV2Ext* clone() const TRT_NOEXCEPT override; private: @@ -148,10 +148,11 @@ class AnchorGeneratorPluginDynamic : public DynamicPluginTensorRT { AnchorGeneratorPluginDynamic(void const* data, size_t length); ~AnchorGeneratorPluginDynamic(); nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override; - nvinfer1::DimsExprs getOutputDimensions(int outputIndex, - const nvinfer1::DimsExprs* inputs, - int nbInputs, - nvinfer1::IExprBuilder& exprBuilder) + nvinfer1::DimsExprs getOutputDimensions( + int outputIndex, + const nvinfer1::DimsExprs* inputs, + int nbInputs, + nvinfer1::IExprBuilder& exprBuilder) // NOLINT TRT_NOEXCEPT override; bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, diff --git a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu index 828f036041927..f7154f6c0dd01 100644 --- a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.cu @@ -829,7 +829,7 @@ void DeformableConvPlugin::configurePlugin( const bool* input_is_broadcast, const bool* output_is_broadcast, nvinfer1::PluginFormat float_format, - int max_batct_size) TRT_NOEXCEPT { + int max_batch_size) TRT_NOEXCEPT { PADDLE_ENFORCE_EQ( nb_inputs, 3, diff --git a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h index dd0a1d5aa9ccb..5a0fbe7e05c16 100644 --- a/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h +++ 
b/paddle/fluid/inference/tensorrt/plugin/deformable_conv_op_plugin.h @@ -108,7 +108,7 @@ class DeformableConvPlugin : public nvinfer1::IPluginV2Ext { const bool* input_is_broadcast, const bool* output_is_broadcast, nvinfer1::PluginFormat float_format, - int max_batct_size) TRT_NOEXCEPT override; + int max_batch_size) TRT_NOEXCEPT override; nvinfer1::IPluginV2Ext* clone() const TRT_NOEXCEPT override; private: diff --git a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h index e4c76e2d652ee..2d5dde9190103 100644 --- a/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/preln_groupnorm_act_op_plugin.h @@ -144,7 +144,7 @@ class PrelnGroupnormActPluginDynamic : public DynamicPluginTensorRT { const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override { // sizeof(float2) * maxBatchSize * maxNumberOfGroup. float2 - // contians two buffers for sum and squared sum; + // contains two buffers for sum and squared sum; ws_ = sizeof(float) * 2 * in[0].max.d[0] * groups_; } diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h index 0a93559f5ee2c..1260bbb8e2917 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/skip_groupnorm_act_op_plugin.h @@ -139,7 +139,7 @@ class SkipGroupnormActPluginDynamic : public DynamicPluginTensorRT { const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) TRT_NOEXCEPT override { // sizeof(float2) * maxBatchSize * maxNumberOfGroup. float2 - // contians two buffers for sum and squared sum; + // contains two buffers for sum and squared sum; ws_ = sizeof(float) * 2 * in[0].max.d[0] * groups_; } diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc index 93132d4bf34eb..637bd84deaff0 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc @@ -19,53 +19,53 @@ namespace inference { namespace tensorrt { namespace plugin { -inline void Seria(void*& buffer, // NOLINT - const std::vector& input_dims, - nvinfer1::DataType data_type, - nvinfer1::PluginFormat data_format, - bool with_fp16) { +inline void Serialize(void*& buffer, // NOLINT + const std::vector& input_dims, + nvinfer1::DataType data_type, + nvinfer1::PluginFormat data_format, + bool with_fp16) { SerializeValue(&buffer, input_dims); SerializeValue(&buffer, data_type); SerializeValue(&buffer, data_format); SerializeValue(&buffer, with_fp16); } -inline void Deseria(void const*& serial_data, - size_t& serial_length, // NOLINT - std::vector* input_dims, - nvinfer1::DataType* data_type, - nvinfer1::PluginFormat* data_format, - bool* with_fp16) { +inline void Deserialize(void const*& serial_data, // NOLINT + size_t& serial_length, // NOLINT + std::vector* input_dims, + nvinfer1::DataType* data_type, + nvinfer1::PluginFormat* data_format, + bool* with_fp16) { DeserializeValue(&serial_data, &serial_length, input_dims); DeserializeValue(&serial_data, &serial_length, data_type); DeserializeValue(&serial_data, &serial_length, data_format); DeserializeValue(&serial_data, &serial_length, with_fp16); } -inline size_t SeriaSize(const std::vector& input_dims, - nvinfer1::DataType data_type, - nvinfer1::PluginFormat data_format, - bool with_fp16) { 
+inline size_t SerializeSize(const std::vector& input_dims, + nvinfer1::DataType data_type, + nvinfer1::PluginFormat data_format, + bool with_fp16) { return (SerializedSize(input_dims) + SerializedSize(data_type) + SerializedSize(data_format) + SerializedSize(with_fp16)); } void PluginTensorRT::serializeBase(void*& buffer) const { - Seria(buffer, input_dims_, data_type_, data_format_, with_fp16_); + Serialize(buffer, input_dims_, data_type_, data_format_, with_fp16_); } void PluginTensorRT::deserializeBase(void const*& serial_data, size_t& serial_length) { - Deseria(serial_data, - serial_length, - &input_dims_, - &data_type_, - &data_format_, - &with_fp16_); + Deserialize(serial_data, + serial_length, + &input_dims_, + &data_type_, + &data_format_, + &with_fp16_); } size_t PluginTensorRT::getBaseSerializationSize() const { - return SeriaSize(input_dims_, data_type_, data_format_, with_fp16_); + return SerializeSize(input_dims_, data_type_, data_format_, with_fp16_); } bool PluginTensorRT::supportsFormat( @@ -87,21 +87,21 @@ void PluginTensorRT::configureWithFormat(const nvinfer1::Dims* input_dims, } void PluginTensorRTV2Ext::serializeBase(void*& buffer) const { - Seria(buffer, input_dims_, data_type_, data_format_, with_fp16_); + Serialize(buffer, input_dims_, data_type_, data_format_, with_fp16_); } void PluginTensorRTV2Ext::deserializeBase(void const*& serial_data, size_t& serial_length) { - Deseria(serial_data, - serial_length, - &input_dims_, - &data_type_, - &data_format_, - &with_fp16_); + Deserialize(serial_data, + serial_length, + &input_dims_, + &data_type_, + &data_format_, + &with_fp16_); } size_t PluginTensorRTV2Ext::getBaseSerializationSize() const { - return SeriaSize(input_dims_, data_type_, data_format_, with_fp16_); + return SerializeSize(input_dims_, data_type_, data_format_, with_fp16_); } void PluginTensorRTV2Ext::configurePlugin( diff --git a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu index a8bf130978dfd..531c6776fb5e7 100644 --- a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu @@ -421,7 +421,7 @@ void YoloBoxPlugin::configurePlugin(const nvinfer1::Dims* input_dims, const bool* input_is_broadcast, const bool* output_is_broadcast, nvinfer1::PluginFormat float_format, - int max_batct_size) TRT_NOEXCEPT {} + int max_batch_size) TRT_NOEXCEPT {} nvinfer1::IPluginV2Ext* YoloBoxPlugin::clone() const TRT_NOEXCEPT { return new YoloBoxPlugin(data_type_, diff --git a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h index 6c4b6f80dd148..36bc5603b460d 100644 --- a/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h @@ -93,7 +93,7 @@ class YoloBoxPlugin : public nvinfer1::IPluginV2Ext { const bool* input_is_broadcast, const bool* output_is_broadcast, nvinfer1::PluginFormat float_format, - int max_batct_size) TRT_NOEXCEPT override; + int max_batch_size) TRT_NOEXCEPT override; nvinfer1::IPluginV2Ext* clone() const TRT_NOEXCEPT override; private: diff --git a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc index 26cb5166362b2..d4631f7057582 100644 --- a/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc +++ b/paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc @@ -76,7 +76,7 @@ 
paddle::any PluginArgumentMappingContext::Attr( break; }; default: { - LOG(ERROR) << "Can't conver op's attribute [" << attr_name + LOG(ERROR) << "Can't cover op's attribute [" << attr_name << "] to paddle any."; } } diff --git a/paddle/fluid/inference/tensorrt/test_arg_mapping_context.cc b/paddle/fluid/inference/tensorrt/test_arg_mapping_context.cc index 97090518153d1..85dddfea2a7c7 100644 --- a/paddle/fluid/inference/tensorrt/test_arg_mapping_context.cc +++ b/paddle/fluid/inference/tensorrt/test_arg_mapping_context.cc @@ -21,7 +21,7 @@ namespace paddle { namespace inference { namespace tensorrt { -TEST(ArgMappingContexTest, BasicFunction) { +TEST(ArgMappingContextTest, BasicFunction) { paddle::framework::proto::OpDesc op; op.set_type("imaged_op"); auto *input_var = op.add_inputs(); @@ -86,8 +86,8 @@ TEST(ArgMappingContexTest, BasicFunction) { int int_attr = any_cast(context.Attr("int_attr")); EXPECT_EQ(int_attr, 1); - float flaot_attr = any_cast(context.Attr("float_attr")); - EXPECT_EQ(flaot_attr, 1); + float float_attr = any_cast(context.Attr("float_attr")); + EXPECT_EQ(float_attr, 1); std::string string_attr = any_cast(context.Attr("string_attr")); EXPECT_EQ(string_attr, "1"); diff --git a/paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc index 3cb30da55e407..d611b2ff32d5d 100644 --- a/paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc +++ b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.cc @@ -30,7 +30,6 @@ TRTInt8Calibrator::TRTInt8Calibrator( std::string engine_name, const platform::Place place) : batch_size_(batch_size), engine_name_(engine_name) { - int i = 0; VLOG(4) << "Init a new calibrator: " << engine_name_; for (const auto& it : buffers) { phi::DenseTensor temp_tensor; @@ -43,7 +42,6 @@ TRTInt8Calibrator::TRTInt8Calibrator( data_buffers_[input_name] = std::pair( static_cast(temp_tensor.mutable_data(place)), data_size); - i += 1; } } diff --git a/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h index 82bb7a64168b4..43386ca324c54 100644 --- a/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h +++ b/paddle/fluid/inference/tensorrt/trt_int8_calibrator.h @@ -87,7 +87,7 @@ class TRTCalibratorEngine { std::unique_ptr engine_; }; /* - * Manager to control the TensorRT Int8 calibration creation and deltetion. + * Manager to control the TensorRT Int8 calibration creation and deletion. */ class TRTCalibratorEngineManager { public: diff --git a/paddle/fluid/inference/utils/shape_range_info.proto b/paddle/fluid/inference/utils/shape_range_info.proto index 53f018cb59348..9e980de9d0fd5 100644 --- a/paddle/fluid/inference/utils/shape_range_info.proto +++ b/paddle/fluid/inference/utils/shape_range_info.proto @@ -16,7 +16,7 @@ syntax = "proto2"; package paddle.inference.proto; // To support trt dynamic shape, record the runtime shape -// information of all tmp tensors in the Compution graph. +// information of all tmp tensors in the Computation graph. message ShapeRangeInfos { message ShapeRangeInfo { required string name = 1; diff --git a/paddle/fluid/inference/utils/singleton.h b/paddle/fluid/inference/utils/singleton.h index 5c2a1bf563f21..82a50e6042c76 100644 --- a/paddle/fluid/inference/utils/singleton.h +++ b/paddle/fluid/inference/utils/singleton.h @@ -35,7 +35,7 @@ struct Singleton { }; /* - * An registor for any type. + * An Registry for any type. * NOTE not thread-safe. 
*/ template diff --git a/paddle/fluid/inference/utils/table_printer.cc b/paddle/fluid/inference/utils/table_printer.cc index ba7a8d342e352..19b4a94834a17 100644 --- a/paddle/fluid/inference/utils/table_printer.cc +++ b/paddle/fluid/inference/utils/table_printer.cc @@ -57,18 +57,18 @@ std::string TablePrinter::PrintTable() { } TablePrinter::TablePrinter(const std::vector& header) { - size_t terminal_witdh = 500; + size_t terminal_width = 500; #ifdef _WIN32 CONSOLE_SCREEN_BUFFER_INFO csbi; int ret = GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); if (ret && (csbi.dwSize.X != 0)) { - terminal_witdh = csbi.dwSize.X; + terminal_width = csbi.dwSize.X; } #else struct winsize terminal_size; int status = ioctl(STDOUT_FILENO, TIOCGWINSZ, &terminal_size); if (status == 0 && terminal_size.ws_col != 0) { - terminal_witdh = terminal_size.ws_col; + terminal_width = terminal_size.ws_col; } #endif @@ -77,8 +77,8 @@ TablePrinter::TablePrinter(const std::vector& header) { widths_.emplace_back(0); } - terminal_witdh = terminal_witdh - (2 * num_cols) - (num_cols + 1); - int avg_width = static_cast(terminal_witdh / num_cols); // NOLINT + terminal_width = terminal_width - (2 * num_cols) - (num_cols + 1); + int avg_width = static_cast(terminal_width / num_cols); // NOLINT for (size_t i = 0; i < num_cols; ++i) { shares_.emplace_back(avg_width); diff --git a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc index 99af9a45b6dc8..3ba808c82b9a6 100644 --- a/paddle/fluid/ir_adaptor/translator/attribute_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/attribute_translator.cc @@ -70,7 +70,9 @@ class AttributeVisitor { virtual pir::Attribute operator()( const paddle::experimental::Scalar& scalar) { VLOG(10) << "translating scalar"; - IR_THROW("not support translating paddle::experimental::Scalar"); + PADDLE_THROW( + phi::errors::Unimplemented("not support " + "translating paddle::experimental::Scalar")); } virtual pir::Attribute operator()(const std::vector& strs) { diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py index 1cb0ab7a3b01a..6d151b48cea19 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py +++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py @@ -48,7 +48,7 @@ def to_phi_and_fluid_op_name(op_item): op_compat_infos = yaml.safe_load(f) op_name_mappings: Dict[str, str] = {} op_arg_name_mappings: Dict[str, Dict[str, str]] = {} - op_mutable_attribues: Dict[str, Set[str]] = {} + op_mutable_attributes: Dict[str, Set[str]] = {} op_mutable_attribute_infos: Dict[str, Dict[str, List[str]]] = {} for op_compat_item in op_compat_infos: @@ -70,15 +70,15 @@ def insert_new_arg_mappings(op_name: str, arg_mapping: Dict[str, str]): def insert_new_mutable_attributes( op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]] ): - if op_name not in op_mutable_attribues: - op_mutable_attribues[op_name] = set() + if op_name not in op_mutable_attributes: + op_mutable_attributes[op_name] = set() if op_name not in op_mutable_attribute_infos: op_mutable_attribute_infos[op_name] = {} for ( attribute_name, mutable_attribute_info, ) in mutable_attribute_infos.items(): - op_mutable_attribues[op_name].add(attribute_name) + op_mutable_attributes[op_name].add(attribute_name) op_mutable_attribute_infos[op_name][attribute_name] = [] for k, v in mutable_attribute_info.items(): if k == 'tensor_name' or k == 'tensors_name': @@ -164,16 +164,17 @@ def 
insert_new_mutable_attributes( "atol_tensor": "TolTensor", "out": "Out", } + op_arg_name_mappings['fused_softmax_mask_grad'].update({"out": "Softmax"}) op_arg_name_mappings['push_sparse_v2'].update( {"out_grad_in": "Out@GRAD", "out_grad_out": "Out@GRAD"} ) - op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") + op_name_normalizer_template = env.get_template("op_compat_info.cc.j2") with open(output_source_file, 'wt') as f: - op_compat_definition = op_name_normailzer_template.render( + op_compat_definition = op_name_normalizer_template.render( op_name_pairs=op_name_mappings, op_arg_name_pairs=op_arg_name_mappings, - op_mutable_attributes=op_mutable_attribues, + op_mutable_attributes=op_mutable_attributes, op_mutable_attribute_infos=op_mutable_attribute_infos, ) f.write(op_compat_definition) @@ -184,7 +185,7 @@ def insert_new_mutable_attributes( # ===================================== def ParseArguments(): parser = argparse.ArgumentParser( - description='Generate OP Compatiable info Files By Yaml' + description='Generate OP Compatible info Files By Yaml' ) parser.add_argument('--op_compat_yaml_file', type=str) parser.add_argument('--output_source_file', type=str) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index 6e1ec454b6bab..f41a25fe9717c 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -198,9 +198,11 @@ inline pir::Operation* InsertFullOperationForAttributeInput( inline pir::Operation* InsertFullArrayOperationForAttributeInput( pir::IrContext* ctx, pir::Block* block, pir::Attribute attr) { - IR_ENFORCE(attr.isa(), - "Encounter non IntArray type when trying to insert IntArray " - "mutable attribute"); + PADDLE_ENFORCE_EQ( + attr.isa(), + true, + phi::errors::InvalidArgument("Encounter non IntArray type when trying to " + "insert IntArray mutable attribute")); phi::IntArray int_array = attr.dyn_cast().data(); pir::Builder builder(ctx, block); dialect::FullIntArrayOp full_int_array_op = @@ -313,20 +315,24 @@ pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx, std::string legacy_input_name = op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); auto legacy_input_vars = op_desc.Input(legacy_input_name, true); - IR_ENFORCE(legacy_input_vars.size() <= 1, - "Do not support duplicable tensor input, when op have multi " - "kernels. OP is %s", - op_desc.Type()); + PADDLE_ENFORCE_EQ( + legacy_input_vars.size() <= 1, + true, + phi::errors::InvalidArgument("Do not support duplicable tensor input, " + "when op have multi kernels. 
OP is %s.", + op_desc.Type())); if (legacy_input_vars.empty()) { need_inputs_sig.emplace_back(""); continue; } VarDesc* var = op_desc.Block()->FindVarRecursive(legacy_input_vars[0]); - IR_ENFORCE(var != nullptr, - "[op:%s] Input %s should not be null", - op_desc.Type(), - legacy_input_vars[0]); + PADDLE_ENFORCE_NE( + var, + nullptr, + phi::errors::InvalidArgument("[Op:%s] Input %s should not be null", + op_desc.Type(), + legacy_input_vars[0])); if (var->GetType() == paddle::framework::proto::VarType::LOD_TENSOR) { need_inputs_sig.emplace_back("dense"); @@ -334,9 +340,10 @@ pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx, paddle::framework::proto::VarType::SELECTED_ROWS) { need_inputs_sig.emplace_back("selected_rows"); } else { - IR_THROW("Op %d only support densetensor and selected_rows, but not %d", - op_desc.Type(), - var->GetType()); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op %d only support dense tensor and selected_rows, but not %d", + op_desc.Type(), + var->GetType())); } } @@ -364,19 +371,22 @@ pir::OpInfo OpTranscriber::LookUpOpInfo(pir::IrContext* ctx, } } - IR_ENFORCE(!target_op_name.empty(), - "Op %d should have corresponding OpInfo %d", - op_desc.Type(), - target_op_name); + PADDLE_ENFORCE_EQ( + !target_op_name.empty(), + true, + phi::errors::InvalidArgument("Op %d should have corresponding OpInfo %d", + op_desc.Type(), + target_op_name)); target_op_name = GetPrefix(ctx, op_desc) + target_op_name; if (IsInplace(op_desc) && *target_op_name.rbegin() != '_') { target_op_name += "_"; } if (!op_info) { - IR_THROW("Op %d should have corresponding OpInfo %d", - op_desc.Type(), - target_op_name); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op %d should have corresponding OpInfo %d", + op_desc.Type(), + target_op_name)); } return op_info; @@ -429,9 +439,10 @@ pir::Value OpTranscriber::GetAttributeAsInput(pir::IrContext* ctx, op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name); if (!op_desc.HasAttr(legacy_attr_name)) { - IR_THROW("Op %s arg %s should not be zero size", - op_desc.Type(), - legacy_attr_name); + PADDLE_THROW( + phi::errors::InvalidArgument("Op %s arg %s should not be zero size", + op_desc.Type(), + legacy_attr_name)); } paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name); VLOG(10) << "[" << op_desc.Type() << "][attribute]" @@ -532,10 +543,12 @@ std::vector OpTranscriber::GenerateOperationInput( // Vector if (legacy_input_vars.size() == 1) { VarDesc* var = op_desc.Block()->FindVarRecursive(legacy_input_vars[0]); - IR_ENFORCE(var != nullptr, - "[op:%s] Input %s should not be null", - op_desc.Type(), - legacy_input_vars[0]); + PADDLE_ENFORCE_NE( + var, + nullptr, + phi::errors::InvalidArgument("[op:%s] Input %s should not be null", + op_desc.Type(), + legacy_input_vars[0])); if (var->GetType() == paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) { is_vector = false; @@ -544,15 +557,19 @@ std::vector OpTranscriber::GenerateOperationInput( // if src type is Tensor if (!is_vector) { - IR_ENFORCE(legacy_input_vars.size() == 1u, - "Input %s not found when parsing op %s", - info.name, - op_desc.Type()); - IR_ENFORCE(param_map->count(legacy_input_vars[0]), - "Input [%s: %s] of op [%s] not found in param map", - info.name, - legacy_input_vars[0], - op_desc.Type()); + PADDLE_ENFORCE_EQ( + legacy_input_vars.size(), + 1UL, + phi::errors::InvalidArgument("Input %s not found when parsing op %s", + info.name, + op_desc.Type())); + PADDLE_ENFORCE_NE(param_map->count(legacy_input_vars[0]), + 0UL, + phi::errors::InvalidArgument( 
+ "Input [%s: %s] of op [%s] not found in param map", + info.name, + legacy_input_vars[0], + op_desc.Type())); auto defining_info = (*param_map)[legacy_input_vars[0]]; op_inputs.push_back(defining_info.value); @@ -593,10 +610,13 @@ OpTranscriber::GenerateOperationOutput(pir::IrContext* ctx, VLOG(10) << "[output translating]" << "[" << op_desc.Type() << "] optional " << info.name << " :" << info.type_name << " " << legacy_output_name; - IR_ENFORCE(info.optional, - "Op %s arg %s should be optional if it can be empty", - op_desc.Type(), - legacy_output_name); + PADDLE_ENFORCE_EQ( + info.optional, + true, + phi::errors::InvalidArgument( + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_output_name)); op_output_types.emplace_back(nullptr); continue; } @@ -613,10 +633,12 @@ OpTranscriber::GenerateOperationOutput(pir::IrContext* ctx, // Vector if (legacy_output_vars.size() == 1) { VarDesc* var = block->FindVarRecursive(legacy_output_vars[0]); - IR_ENFORCE(var != nullptr, - "[op:%s] Output %s should not be null", - op_desc.Type(), - legacy_output_vars[0]); + PADDLE_ENFORCE_NE( + var, + nullptr, + phi::errors::InvalidArgument("[op:%s] Output %s should not be null", + op_desc.Type(), + legacy_output_vars[0])); if (var->GetType() == paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) { pir::Type translated_var_type = @@ -640,10 +662,12 @@ OpTranscriber::GenerateOperationOutput(pir::IrContext* ctx, auto& var_name = legacy_output_vars[0]; VarDesc* var = block->FindVarRecursive(var_name); - IR_ENFORCE(var != nullptr, - "[op:%s] Output %s should not be null", - op_desc.Type(), - var_name); + PADDLE_ENFORCE_NE( + var, + nullptr, + phi::errors::InvalidArgument("[op:%s] Output %s should not be null", + op_desc.Type(), + var_name)); VLOG(10) << "[output translating]" << "[" << op_desc.Type() << "]" << info.name << " var: " << var_name << " type: " << var->GetType(); @@ -669,10 +693,12 @@ OpTranscriber::GenerateOperationOutput(pir::IrContext* ctx, continue; } VarDesc* var = block->FindVarRecursive(var_name); - IR_ENFORCE(var != nullptr, - "[op:%s] Output %s should not be null", - op_desc.Type(), - var_name); + PADDLE_ENFORCE_NE( + var, + nullptr, + phi::errors::InvalidArgument("[op:%s] Output %s should not be null", + op_desc.Type(), + var_name)); VLOG(10) << "[output translating]" << "[" << op_desc.Type() << "]" << info.name << " var: " << var_name << " type: " << var->GetType(); @@ -842,13 +868,17 @@ struct AssignOpTranscriber : public OpTranscriber { const OpDesc& op_desc) override { std::string target_op_name; - IR_ENFORCE( - op_desc.HasInput("X"), "op %s should have input `X`", op_desc.Type()); + PADDLE_ENFORCE_EQ(op_desc.HasInput("X"), + true, + phi::errors::InvalidArgument( + "op %s should have input `X`", op_desc.Type())); const auto& input_vars = op_desc.Input("X"); - IR_ENFORCE(input_vars.size() == 1, - "op %s should have one input `X`, but got %d.", - op_desc.Type(), - input_vars.size()); + PADDLE_ENFORCE_EQ(input_vars.size() == 1, + true, + phi::errors::InvalidArgument( + "op %s should have one input `X`, but got %d.", + op_desc.Type(), + input_vars.size())); const auto* input_var = op_desc.Block()->FindVarRecursive(input_vars[0]); if (input_var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY) { target_op_name = dialect::AssignArray_Op::name(); @@ -858,7 +888,8 @@ struct AssignOpTranscriber : public OpTranscriber { const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW("Op assign should have corresponding OpInfo %s", 
target_op_name); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op assign should have corresponding OpInfo %s.", target_op_name)); } return op_info; @@ -935,9 +966,10 @@ struct AssignValueOpTranscriber : public OpTranscriber { std::string target_op_name = "pd_op.assign_value"; const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( - "Op assign_value should have corresponding OpInfo " - "pd_op.assign_value"); + PADDLE_ENFORCE(false, + phi::errors::InvalidArgument( + "Op assign_value should have corresponding OpInfo " + "pd_op.assign_value")); } return op_info; @@ -968,7 +1000,8 @@ struct AssignValueOpTranscriber : public OpTranscriber { if (op_desc.HasAttr("shape")) { legacy_attr = op_desc.GetAttr("shape"); } else { - IR_THROW("Op assign_value should have attribute `shape` but not find"); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op assign_value should have attribute `shape` but not find")); } pir::Attribute attr_shape = attribute_translator(attr_info_maps.at("shape").type_name, legacy_attr); @@ -977,7 +1010,8 @@ struct AssignValueOpTranscriber : public OpTranscriber { if (op_desc.HasAttr("dtype")) { legacy_attr = op_desc.GetAttr("dtype"); } else { - IR_THROW("Op assign_value should have attribute `dtype` but not find"); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op assign_value should have attribute `dtype` but not find")); } pir::Attribute attr_dtype = attribute_translator(attr_info_maps.at("dtype").type_name, legacy_attr); @@ -1005,10 +1039,11 @@ struct AssignValueOpTranscriber : public OpTranscriber { } } - IR_ENFORCE( - attribute_map.find("values") != attribute_map.end(), - "Op assign_value should have attribute `**_values` or `values` but " - "not find"); + PADDLE_ENFORCE_NE( + attribute_map.find("values"), + attribute_map.end(), + phi::errors::InvalidArgument("Op assign_value should have attribute " + "`**_values` or `values` but not find")); TranslateOpDistAttribute(op_desc, &attribute_map); @@ -1056,16 +1091,20 @@ pir::Value TranslateDropOutStateIn(pir::IrContext* ctx, // `DropoutState` is a tensor VarDesc* dropout_state = op_desc.Block()->FindVarRecursive(legacy_output_vars[0]); - IR_ENFORCE(dropout_state != nullptr, - "[op:%s] Output %s should not be null", - op_desc.Type(), - legacy_output_vars[0]); + PADDLE_ENFORCE_NE( + dropout_state, + nullptr, + phi::errors::InvalidArgument("[op:%s] Output %s should not be null", + op_desc.Type(), + legacy_output_vars[0])); auto& type_translator = TypeTranslator::instance(); pir::Type translated_var_type = type_translator[dropout_state->GetType()](ctx, *dropout_state); - IR_ENFORCE( + PADDLE_ENFORCE_EQ( translated_var_type.isa(), - "Unexpected: Rnn Op's output DropoutState should be a DenseTensor"); + true, + phi::errors::InvalidArgument( + "Unexpected: Rnn Op's output DropoutState should be a DenseTensor")); auto tensor_type = translated_var_type.dyn_cast(); pir::Builder builder(ctx, block); @@ -1116,9 +1155,10 @@ struct EmbeddingGradOpTranscriber : public OpTranscriber { << target_op_name; auto op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW("Op %d should have corresponding OpInfo %d", - op_desc.Type(), - target_op_name); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op %d should have corresponding OpInfo %d", + op_desc.Type(), + target_op_name)); } return op_info; @@ -1194,7 +1234,10 @@ struct SplitOpTranscriber : public OpTranscriber { std::vector op_inputs; // process first input auto x_input_vars = op_desc.Input("X"); - IR_ENFORCE(x_input_vars.size() == 
1, "x input of split MUST be a tensor"); + PADDLE_ENFORCE_EQ( + x_input_vars.size(), + 1UL, + phi::errors::InvalidArgument("x input of split MUST be a tensor")); auto x_defining_info = (*param_map)[x_input_vars[0]]; op_inputs.push_back(x_defining_info.value); @@ -1224,8 +1267,10 @@ struct SplitOpTranscriber : public OpTranscriber { !op_desc.Input("AxisTensor").empty()) { // get axis from input auto axis_var_list = op_desc.Input("AxisTensor"); - IR_ENFORCE(axis_var_list.size() == 1, - "axis tensor input of split MUST be a tensor"); + PADDLE_ENFORCE_EQ(axis_var_list.size(), + 1UL, + phi::errors::InvalidArgument( + "axis tensor input of split MUST be a tensor")); auto axis_defining_info = (*param_map)[axis_var_list[0]]; op_inputs.push_back(axis_defining_info.value); } else { @@ -1255,6 +1300,16 @@ struct SplitOpTranscriber : public OpTranscriber { return attribute_map; } +#ifdef PADDLE_WITH_DNNL + else if (op_desc.HasAttr("mkldnn_data_type")) { // NOLINT + pir::AttributeMap attribute_map = { + {"mkldnn_data_type", + pir::StrAttribute::get( + ctx, op_desc.GetAttrIfExists("mkldnn_data_type"))}, + }; + return attribute_map; + } +#endif return {}; } @@ -1262,17 +1317,20 @@ struct SplitOpTranscriber : public OpTranscriber { pir::OpInfo LookUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc) override { int num = paddle::get(op_desc.GetAttr("num")); + auto prefix = GetPrefix(ctx, op_desc); std::string target_op_name; if (num > 0) { - target_op_name = "pd_op.split_with_num"; + target_op_name = prefix + "split_with_num"; } else { - target_op_name = "pd_op.split"; + target_op_name = prefix + "split"; } const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW("Op assign_value should have corresponding OpInfo pd_op.split"); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op assign_value should have corresponding OpInfo %s.", + target_op_name)); } return op_info; @@ -1359,12 +1417,12 @@ struct AddNOpTranscriber : public OpTranscriber { GetPrefix(ctx, op_desc) + OpNameCompatibleMapping(op_desc.Type()); if (IsInplace(op_desc)) { target_op_name += "_"; - } else { - target_op_name += "_with_kernel"; } + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW("Op add_n should have corresponding OpInfo %s", target_op_name); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op add_n should have corresponding OpInfo %s", target_op_name)); } return op_info; @@ -1383,9 +1441,9 @@ struct TrilAndTriuOpTranscriber : public OpTranscriber { } const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( - "Op tril_triu should have corresponding OpInfo pd_op.tril or " - "pd_op.triu."); + PADDLE_THROW( + phi::errors::InvalidArgument("Op tril_triu should have corresponding " + "OpInfo pd_op.tril or pd_op.triu.")); } return op_info; @@ -1404,10 +1462,11 @@ struct TrilAndTriuGradOpTranscriber : public OpTranscriber { } const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( - "Op tril_triu_grad should have corresponding OpInfo pd_op.tril_grad " - "or " - "pd_op.triu_grad."); + PADDLE_THROW( + phi::errors::InvalidArgument("Op tril_triu_grad should have " + "corresponding OpInfo pd_op.tril_grad " + "or " + "pd_op.triu_grad.")); } return op_info; @@ -1421,27 +1480,36 @@ ValueInfo GetTensorInfoByVarName(const OpDesc& op_desc, const std::vector& names, TranslationContext* param_map, const std::string& var_name) { - IR_ENFORCE(names.size() == 1, - "Expected op[%s]'s input %s has only 1 variable, but 
got %d", - op_desc.Type(), - var_name, - names.size()); + PADDLE_ENFORCE_EQ( + names.size(), + 1UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s has only 1 variable, but got %d", + op_desc.Type(), + var_name, + names.size())); const auto& name = names[0]; - IR_ENFORCE(param_map->count(name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - name); + PADDLE_ENFORCE_GT( + param_map->count(name), + 0UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s has been parsed", op_desc.Type(), name)); const auto& defining_info = param_map->at(name); pir::Value value = defining_info.value; - IR_ENFORCE( - value, "Expected op[%s]'s input %s is not null", op_desc.Type(), name); + PADDLE_ENFORCE_NE( + value, + nullptr, + phi::errors::PreconditionNotMet( + "Expected op[%s]'s input %s is not null", op_desc.Type(), name)); const pir::Type& type = value.type(); - IR_ENFORCE(type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - name, - type); + PADDLE_ENFORCE_EQ(type.isa(), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + name, + type)); dialect::DenseTensorType tensor_type = type.dyn_cast(); @@ -1469,9 +1537,10 @@ struct MulOpTranscriber : public OpTranscriber { const std::string& target_op_name = paddle::dialect::MatmulOp::name(); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW("Op %d should have corresponding OpInfo %d", - op_desc.Type(), - target_op_name); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op %d should have corresponding OpInfo %d", + op_desc.Type(), + target_op_name)); } return op_info; } @@ -1506,24 +1575,30 @@ struct MulOpTranscriber : public OpTranscriber { const auto& [x_shape, x_tensor_type, x_value] = x_info; - IR_ENFORCE(x_num_col_dims <= static_cast(x_shape.size()), - "Expected op[%s]'s attr `x_num_col_dims` less than or equal to " - "dim of input X %s, but got %d", - op_desc.Type(), - x_shape.size(), - x_num_col_dims); + PADDLE_ENFORCE_EQ( + x_num_col_dims <= static_cast(x_shape.size()), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s attr `x_num_col_dims` less than or equal to " + "dim of input X %s, but got %d", + op_desc.Type(), + x_shape.size(), + x_num_col_dims)); ValueInfo y_info = GetTensorInfoByVarName( op_desc, op_desc.Input("Y", true), param_map, "Y"); const auto& [y_shape, y_tensor_type, y_value] = y_info; - IR_ENFORCE(y_num_col_dims <= static_cast(y_shape.size()), - "Expected op[%s]'s attr `y_num_col_dims` less than or equal to " - "dim of input Y %s, but got %d", - op_desc.Type(), - y_shape.size(), - y_num_col_dims); + PADDLE_ENFORCE_EQ( + y_num_col_dims <= static_cast(y_shape.size()), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s attr `y_num_col_dims` less than or equal to " + "dim of input Y %s, but got %d", + op_desc.Type(), + y_shape.size(), + y_num_col_dims)); pir::Builder builder(ctx, block); @@ -1638,9 +1713,10 @@ struct MulGradOpTranscriber : public OpTranscriber { << target_op_name; const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW("Op %d should have corresponding OpInfo %d", - op_desc.Type(), - target_op_name); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op %d should have corresponding OpInfo %d", + op_desc.Type(), + target_op_name)); } return op_info; } @@ -1675,24 +1751,30 @@ struct MulGradOpTranscriber : public OpTranscriber { const auto& [x_shape, x_tensor_type, x_value] = x_info; - 
IR_ENFORCE(x_num_col_dims <= static_cast(x_shape.size()), - "Expected op[%s]'s attr `x_num_col_dims` less than or equal to " - "dim of input X %s, but got %d", - op_desc.Type(), - x_shape.size(), - x_num_col_dims); + PADDLE_ENFORCE_EQ( + x_num_col_dims <= static_cast(x_shape.size()), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s attr `x_num_col_dims` less than or equal to " + "dim of input X %s, but got %d", + op_desc.Type(), + x_shape.size(), + x_num_col_dims)); ValueInfo y_info = GetTensorInfoByVarName( op_desc, op_desc.Input("Y", true), param_map, "Y"); const auto& [y_shape, y_tensor_type, y_value] = y_info; - IR_ENFORCE(y_num_col_dims <= static_cast(y_shape.size()), - "Expected op[%s]'s attr `y_num_col_dims` less than or equal to " - "dim of input Y %s, but got %d", - op_desc.Type(), - y_shape.size(), - y_num_col_dims); + PADDLE_ENFORCE_EQ( + y_num_col_dims <= static_cast(y_shape.size()), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s attr `y_num_col_dims` less than or equal to " + "dim of input Y %s, but got %d", + op_desc.Type(), + y_shape.size(), + y_num_col_dims)); ValueInfo out_grad_info = GetTensorInfoByVarName( op_desc, op_desc.Input("Out@GRAD", true), param_map, "Out@GRAD"); @@ -1770,16 +1852,20 @@ struct MulGradOpTranscriber : public OpTranscriber { auto gradReshape = [&](const std::string& var_name) { const auto& grad_output = op_desc.Output(var_name); - IR_ENFORCE(grad_output.size() == 1, - "Expected op[%s]'s output %s has only 1 variable, but got %d", - op_desc.Type(), - var_name, - grad_output.size()); + PADDLE_ENFORCE_EQ( + grad_output.size(), + 1UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s output %s has only 1 variable, but got %d", + op_desc.Type(), + var_name, + grad_output.size())); const auto& grad_var_name = grad_output[0]; auto idx_iter = arg_to_idx.find(grad_var_name); if (idx_iter == arg_to_idx.end()) { - IR_THROW("op[%s] should have got its %s", op_desc.Type(), var_name); + PADDLE_THROW(phi::errors::InvalidArgument( + "op[%s] should have got its %s", op_desc.Type(), var_name)); } auto [idx_in_op, idx_in_vec] = idx_iter->second; VLOG(10) << "[output recording]" @@ -1788,26 +1874,32 @@ struct MulGradOpTranscriber : public OpTranscriber { VarDesc* var_desc = op_desc.Block()->FindVarRecursive( op_desc.Input(var_name.substr(0, 1))[0]); - IR_ENFORCE(var_desc != nullptr, - "[op:%s] Input %s should not be null", - op_desc.Type(), - var_name.substr(0, 1)); + PADDLE_ENFORCE_NE( + var_desc, + nullptr, + phi::errors::InvalidArgument("[op:%s] Input %s should not be null", + op_desc.Type(), + var_name.substr(0, 1))); std::vector shape = var_desc->GetShape(); DenseTensorTypeStorage::Dim dim = common::make_ddim(shape); pir::Value value_res = operation->result(idx_in_op); auto reshape_op = builder.Build(value_res, shape); - - IR_ENFORCE(value_res, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - grad_var_name); + PADDLE_ENFORCE_NE(value_res, + nullptr, + phi::errors::PreconditionNotMet( + "Expected op[%s]'s input %s is not null", + op_desc.Type(), + grad_var_name)); pir::Type grad_type = value_res.type(); - IR_ENFORCE(grad_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - grad_var_name, - grad_type); + PADDLE_ENFORCE_EQ( + grad_type.isa(), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + grad_var_name, + grad_type)); dialect::DenseTensorType grad_tensor_type = grad_type.dyn_cast(); @@ -1833,7 +1925,8 @@ struct 
FillConstant2FullTranscriber : public OpTranscriber { const OpDesc& op_desc) override { const auto& op_info = ctx->GetRegisteredOpInfo(dialect::FullOp::name()); if (!op_info) { - IR_THROW("Op fill_constant should have corresponding OpInfo pd_op.full"); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op fill_constant should have corresponding OpInfo pd_op.full")); } return op_info; @@ -1883,7 +1976,7 @@ struct FillConstant2FullTranscriber : public OpTranscriber { } } switch (place_type) { - case -1: + case -1: // NOLINT attribute_map["place"] = paddle::dialect::PlaceAttribute::get( ctx, phi::Place(phi::AllocationType::UNDEFINED)); break; @@ -1914,9 +2007,9 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber { const OpDesc& op_desc) override { const auto& op_info = ctx->GetRegisteredOpInfo("pd_op.full_with_tensor"); if (!op_info) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "Op fill_constant should have corresponding OpInfo " - "pd_op.full_with_tensor"); + "pd_op.full_with_tensor")); } return op_info; @@ -2015,16 +2108,20 @@ struct SelectInputOpTranscriber : public OpTranscriber { std::vector op_inputs = {}; auto Mask_name = op_desc.Input("Mask")[0]; auto& Input_name = op_desc.Input("X"); - IR_ENFORCE(param_map->count(Mask_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - Mask_name); + PADDLE_ENFORCE_GT(param_map->count(Mask_name), + 0UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + Mask_name)); op_inputs.push_back(param_map->at(Mask_name).value); for (auto in_name : Input_name) { - IR_ENFORCE(param_map->count(in_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - in_name); + PADDLE_ENFORCE_GT(param_map->count(in_name), + 0UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + in_name)); op_inputs.push_back(param_map->at(in_name).value); } @@ -2062,9 +2159,10 @@ struct SelectInputOpTranscriber : public OpTranscriber { 0, undefined_prefix.size()) == undefined_prefix) { // do nothing } else { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "select_input only support same type or DenseTensorType with " - "only different dim, but get dtype:[%s, %s], layout:[%s, %s], " + "only different dim, but get dtype:[%s, %s], layout:[%s, " + "%s], " "lod:[%s, %s], offset:[%s, %s].", tensor1.dtype(), tensor2.dtype(), @@ -2073,7 +2171,7 @@ struct SelectInputOpTranscriber : public OpTranscriber { tensor1.lod(), tensor2.lod(), tensor1.offset(), - tensor2.offset()); + tensor2.offset())); } auto undefined_var_type = tensor1; @@ -2083,11 +2181,13 @@ struct SelectInputOpTranscriber : public OpTranscriber { } auto undefine_value = op_inputs[1 + undefined_var_index]; - IR_ENFORCE( + PADDLE_ENFORCE_EQ( undefine_value.defining_op()->isa(), - "undefined_var %s should be generated by assign_value, but got %s", - Input_name[undefined_var_index], - undefine_value.defining_op()); + true, + phi::errors::InvalidArgument("undefined_var %s should be generated " + "by assign_value, but got %s", + Input_name[undefined_var_index], + undefine_value.defining_op())); undefine_value.set_type(target_var_type); undefine_value.defining_op()->set_attribute( @@ -2124,11 +2224,12 @@ struct SelectInputOpTranscriber : public OpTranscriber { tensor1.lod(), tensor1.offset())); } else { - IR_THROW( - "select_input only support same type or DenseTensorType with only " - "different dim, now is %s != %s.", - input1, - input2); + PADDLE_THROW( + 
phi::errors::InvalidArgument("select_input only support same type or " + "DenseTensorType with only " + "different dim, now is %s != %s.", + input1, + input2)); } pir::Operation* operation = pir::Operation::Create( @@ -2152,15 +2253,19 @@ struct SelectOutputOpTranscriber : public OpTranscriber { std::vector op_inputs = {}; auto Mask_name = op_desc.Input("Mask")[0]; auto& Input_name = op_desc.Input("X")[0]; - IR_ENFORCE(param_map->count(Mask_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - Mask_name); + PADDLE_ENFORCE_GT(param_map->count(Mask_name), + 0UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + Mask_name)); op_inputs.push_back(param_map->at(Mask_name).value); - IR_ENFORCE(param_map->count(Input_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - Input_name); + PADDLE_ENFORCE_GT(param_map->count(Input_name), + 0UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + Input_name)); op_inputs.push_back(param_map->at(Input_name).value); pir::AttributeMap attribute_map; @@ -2169,8 +2274,10 @@ struct SelectOutputOpTranscriber : public OpTranscriber { OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types; auto Out_names = op_desc.Output("Out"); - IR_ENFORCE(Out_names.size() == 2, - "Expected SelectOutput's output size is 2."); + PADDLE_ENFORCE_EQ(Out_names.size(), + 2UL, + phi::errors::InvalidArgument( + "Expected SelectOutput's output size is 2.")); for (size_t idx = 0; idx < Out_names.size(); idx++) { VarDesc* var = op_desc.Block()->FindVarRecursive(Out_names[idx]); arg_to_idx[var->Name()] = {idx, 0}; @@ -2199,23 +2306,31 @@ pir::Value TranslateNumClassesForOneHot(pir::IrContext* ctx, if (op_desc.HasInput(legacy_tensor_name) && !op_desc.Input(legacy_tensor_name).empty()) { legacy_vars = op_desc.Input(legacy_tensor_name); - IR_ENFORCE(legacy_vars.size() == 1, - "depth_tensor input of one hot MUST be a tensor"); + PADDLE_ENFORCE_EQ(legacy_vars.size(), + 1UL, + phi::errors::InvalidArgument( + "depth_tensor input of one hot MUST be a tensor")); auto var_name = legacy_vars[0]; - IR_ENFORCE(legacy_vars.size() == 1, - "depth_tensor input of one hot MUST be a tensor"); - IR_ENFORCE(param_map->count(legacy_vars[0]), - "%s should be existed in one_hot_v2 as input depth_tensor.", - legacy_vars[0]); + PADDLE_ENFORCE_EQ(legacy_vars.size(), + 1UL, + phi::errors::InvalidArgument( + "depth_tensor input of one hot MUST be a tensor")); + PADDLE_ENFORCE_NE( + param_map->count(legacy_vars[0]), + 0UL, + phi::errors::InvalidArgument( + "%s should be existed in one_hot_v2 as input depth_tensor.", + legacy_vars[0])); auto defining_info = param_map->at(legacy_vars[0]); return defining_info.value; } auto& attribute_translator = AttributeTranslator::instance(); if (!op_desc.HasAttr(legacy_attr_name)) { - IR_THROW("Op %s arg %s should not be zero size", - op_desc.Type(), - legacy_attr_name); + PADDLE_THROW( + phi::errors::InvalidArgument("Op %s arg %s should not be zero size", + op_desc.Type(), + legacy_attr_name)); } paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name); VLOG(10) << "[" << op_desc.Type() << "][attribute]" @@ -2240,14 +2355,18 @@ struct OneHotTranscriber : public OpTranscriber { pir::Attribute TranslateDtypeForArange(pir::IrContext* ctx, const OpDesc& op_desc, const OpAttributeInfo& attr_info) { - IR_ENFORCE(op_desc.Input("Start").size() == 1, - "[op:%s] Input [Start]'s size should be equal to 1", - 
op_desc.Type()); + PADDLE_ENFORCE_EQ( + op_desc.Input("Start").size(), + 1UL, + phi::errors::InvalidArgument( + "[op:%s] Input [Start]'s size should be equal to 1", op_desc.Type())); auto var_desc = op_desc.Block()->FindVarRecursive(op_desc.Input("Start")[0]); - IR_ENFORCE(var_desc != nullptr, - "[op:%s] Input %s should not be null", - op_desc.Type(), - op_desc.Input("Start")[0]); + PADDLE_ENFORCE_NE( + var_desc, + nullptr, + phi::errors::InvalidArgument("[op:%s] Input %s should not be null", + op_desc.Type(), + op_desc.Input("Start")[0])); auto start_proto_dtype = var_desc->GetDataType(); auto start_phi_dtype = phi::TransToPhiDataType(start_proto_dtype); auto dtype_attr = @@ -2311,15 +2430,20 @@ struct ElementwiseTranscriber : public OpTranscriber { } auto x_names = op_desc.Input("X", true); - IR_ENFORCE(x_names.size() == 1, - "Expected op[%s]'s input X has only 1 variable, but got %d", - op_desc.Type(), - x_names.size()); + PADDLE_ENFORCE_EQ( + x_names.size(), + 1UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input X has only 1 variable, but got %d", + op_desc.Type(), + x_names.size())); auto x_name = x_names[0]; - IR_ENFORCE(param_map->count(x_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - x_name); + PADDLE_ENFORCE_GT(param_map->count(x_name), + 0UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + x_name)); auto x_defining_info = param_map->at(x_name); if (x_defining_info.generated_by_vector) { InsertSliceOperationForTarget( @@ -2327,30 +2451,39 @@ struct ElementwiseTranscriber : public OpTranscriber { x_defining_info = param_map->at(x_name); } pir::Value x_value = x_defining_info.value; - IR_ENFORCE(x_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - x_name); + PADDLE_ENFORCE_NE( + x_value, + nullptr, + phi::errors::PreconditionNotMet( + "Expected op[%s]'s input %s is not null", op_desc.Type(), x_name)); pir::Type x_type = x_value.type(); - IR_ENFORCE(x_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - x_name, - x_type); + PADDLE_ENFORCE_EQ( + x_type.isa(), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + x_name, + x_type)); dialect::DenseTensorType x_tensor_type = x_type.dyn_cast(); std::vector x_shape = common::vectorize(x_tensor_type.dims()); auto y_names = op_desc.Input("Y", true); - IR_ENFORCE(y_names.size() == 1, - "Expected op[%s]'s input Y has only 1 variable, but got %d", - op_desc.Type(), - y_names.size()); + PADDLE_ENFORCE_EQ( + y_names.size(), + 1UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input Y has only 1 variable, but got %d", + op_desc.Type(), + y_names.size())); auto y_name = y_names[0]; - IR_ENFORCE(param_map->count(y_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - y_name); + PADDLE_ENFORCE_GT(param_map->count(y_name), + 0UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + y_name)); auto y_defining_info = param_map->at(y_name); if (y_defining_info.generated_by_vector) { InsertSliceOperationForTarget( @@ -2358,16 +2491,20 @@ struct ElementwiseTranscriber : public OpTranscriber { y_defining_info = param_map->at(y_name); } pir::Value y_value = y_defining_info.value; - IR_ENFORCE(y_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - y_name); + PADDLE_ENFORCE_NE( + y_value, + nullptr, + phi::errors::PreconditionNotMet( + "Expected op[%s]'s 
input %s is not null", op_desc.Type(), y_name)); pir::Type y_type = y_value.type(); - IR_ENFORCE(y_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - y_name, - y_type); + PADDLE_ENFORCE_EQ( + y_type.isa(), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + y_name, + y_type)); dialect::DenseTensorType y_tensor_type = y_type.dyn_cast(); std::vector y_shape = common::vectorize(y_tensor_type.dims()); @@ -2381,11 +2518,14 @@ struct ElementwiseTranscriber : public OpTranscriber { // x.rank=y.rank return {x_value, y_value}; } - IR_ENFORCE(append_size > 0, - "Expected op[%s] have append size > 0 with axis=%d but got %d", - op_desc.Type(), - axis, - append_size); + PADDLE_ENFORCE_GT( + append_size, + 0UL, + phi::errors::InvalidArgument( + "Expected op[%s] have append size > 0 with axis=%d but got %d", + op_desc.Type(), + axis, + append_size)); pir::Builder builder(ctx, block); pir::Value y_new; @@ -2427,9 +2567,9 @@ struct GradAddOpTranscriber : public ElementwiseTranscriber { } const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "Op assign_value should have corresponding OpInfo " - "pd_op.assign_value_"); + "pd_op.assign_value_")); } return op_info; @@ -2454,16 +2594,19 @@ struct ElementwiseGradTranscriber : public OpTranscriber { if (y_grad_output.size() < 1) { return; } - IR_ENFORCE( - y_grad_output.size() == 1, - "Expected op[%s]'s output Y@GRAD has only 1 variable, but got %d", - op_desc.Type(), - y_grad_output.size()); + PADDLE_ENFORCE_EQ( + y_grad_output.size(), + 1UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s output Y@GRAD has only 1 variable, but got %d", + op_desc.Type(), + y_grad_output.size())); const auto& y_grad_var_name = y_grad_output[0]; auto idx_iter = arg_to_idx.find(y_grad_var_name); if (idx_iter == arg_to_idx.end()) { - IR_THROW("op[%s] should have got its y_grad", op_desc.Type()); + PADDLE_THROW(phi::errors::InvalidArgument( + "op[%s] should have got its y_grad", op_desc.Type())); } auto [idx_in_op, idx_in_vec] = idx_iter->second; VLOG(10) << "[output recording]" @@ -2472,22 +2615,28 @@ struct ElementwiseGradTranscriber : public OpTranscriber { auto y_names = op_desc.Input("Y", true); auto y_name = y_names[0]; - IR_ENFORCE(param_map->count(y_name) > 0, - "Expected op[%s]'s input %s has been parsed", - op_desc.Type(), - y_name); + PADDLE_ENFORCE_GT(param_map->count(y_name), + 0UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s has been parsed", + op_desc.Type(), + y_name)); auto y_defining_info = param_map->at(y_name); pir::Value y_value = y_defining_info.value; - IR_ENFORCE(y_value, - "Expected op[%s]'s input %s is not null", - op_desc.Type(), - y_name); + PADDLE_ENFORCE_NE( + y_value, + nullptr, + phi::errors::PreconditionNotMet( + "Expected op[%s]'s input %s is not null", op_desc.Type(), y_name)); pir::Type y_type = y_value.type(); - IR_ENFORCE(y_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - y_name, - y_type); + PADDLE_ENFORCE_EQ( + y_type.isa(), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + y_name, + y_type)); dialect::DenseTensorType y_tensor_type = y_type.dyn_cast(); @@ -2495,11 +2644,14 @@ struct ElementwiseGradTranscriber : public OpTranscriber { // if y_grad' shape is same with y, we don't need a reshape pir::Type y_grad_type = value.type(); 
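The op_translator.cc hunks above and below all apply the same mechanical conversion: a condition-style IR_ENFORCE(cond, msg, args...) becomes a typed PADDLE_ENFORCE_EQ / PADDLE_ENFORCE_NE / PADDLE_ENFORCE_GT check carrying an explicit phi::errors category, and IR_THROW becomes PADDLE_THROW. A minimal sketch of the two forms, using an illustrative input-count check rather than any specific call from this diff:

// Old style: boolean condition plus a printf-like message.
//   IR_ENFORCE(inputs.size() == 1,
//              "op %s expects exactly one input, but got %d",
//              op_desc.Type(),
//              inputs.size());

// New style: the comparison and the error category (InvalidArgument,
// PreconditionNotMet, Unimplemented, ...) are both spelled out at the call site.
PADDLE_ENFORCE_EQ(
    inputs.size(),
    1UL,
    phi::errors::InvalidArgument(
        "op %s expects exactly one input, but got %d",
        op_desc.Type(),
        inputs.size()));

// Unconditional failures follow the same pattern:
PADDLE_THROW(phi::errors::Unimplemented(
    "op %s is not supported by this transcriber", op_desc.Type()));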
- IR_ENFORCE(y_grad_type.isa(), - "Expected op[%s]'s input %s is DenseTensor but got %s", - op_desc.Type(), - y_grad_var_name, - y_grad_type); + PADDLE_ENFORCE_EQ( + y_grad_type.isa(), + true, + phi::errors::InvalidArgument( + "Expected op[%s]'s input %s is DenseTensor but got %s", + op_desc.Type(), + y_grad_var_name, + y_grad_type)); dialect::DenseTensorType y_grad_tensor_type = y_grad_type.dyn_cast(); if (y_grad_tensor_type.dims() == y_tensor_type.dims()) { @@ -2526,9 +2678,10 @@ struct SetValueOpTranscriber : public OpTranscriber { op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name); if (!op_desc.HasAttr(legacy_attr_name)) { - IR_THROW("Op %s arg %s should not be zero size", - op_desc.Type(), - legacy_attr_name); + PADDLE_THROW( + phi::errors::InvalidArgument("Op %s arg %s should not be zero size", + op_desc.Type(), + legacy_attr_name)); } framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name); VLOG(10) << "[" << op_desc.Type() << "][attribute]" @@ -2548,9 +2701,9 @@ struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber { std::string target_op_name = dialect::SetValueWithTensorOp::name(); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "Op set_value should have corresponding OpInfo " - "pd_op.set_value_with_tensor"); + "pd_op.set_value_with_tensor")); } return op_info; @@ -2568,13 +2721,17 @@ struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber { const OpInputInfo& info, pir::Block* block) -> pir::Value { std::vector legacy_input_vars; - IR_ENFORCE(op_desc.HasInput("ValueTensor"), - "[set_value] should have ValueTensor"); + PADDLE_ENFORCE_EQ( + op_desc.HasInput("ValueTensor"), + true, + phi::errors::InvalidArgument("[set_value] should have ValueTensor")); legacy_input_vars = op_desc.Input("ValueTensor", true); - IR_ENFORCE( - legacy_input_vars.size() == 1u, - "[set_value][ValueTensor] should only have 1 variable, but got %d", - legacy_input_vars.size()); + PADDLE_ENFORCE_EQ( + legacy_input_vars.size(), + 1UL, + phi::errors::InvalidArgument("[set_value][ValueTensor] should only " + "have 1 variable, but got %d", + legacy_input_vars.size())); auto var_name = legacy_input_vars[0]; auto defining_info = (*param_map)[var_name]; if (defining_info.generated_by_vector) { @@ -2593,9 +2750,9 @@ struct SetValueGradOpTranscriber : public SetValueWithTensorOpTranscriber { std::string target_op_name = dialect::SetValueWithTensorGradOp::name(); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "Op set_value_grad should have corresponding OpInfo " - "pd_op.set_value_with_tensor_grad"); + "pd_op.set_value_with_tensor_grad")); } return op_info; @@ -2670,10 +2827,12 @@ struct FusedFeedForwardOpTranscriber : public OpTranscriber { ctx, param_map, op_desc, operation, arg_to_idx); if (op_desc.HasOutput("Out")) { const auto& output_vars = op_desc.Output("Out"); - IR_ENFORCE(output_vars.size() == 1, - "Expected op[%s]'s Out has only 1 var but got %s", - op_desc.Type(), - output_vars.size()); + PADDLE_ENFORCE_EQ(output_vars.size(), + 1UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s Out has only 1 var but got %s", + op_desc.Type(), + output_vars.size())); auto output_var = output_vars[0]; auto fused_feedforward_op = operation->dyn_cast(); @@ -2689,9 +2848,9 @@ struct ShareBufferOpTranscriber : public OpTranscriber { std::string target_op_name = 
dialect::ShareDataOp::name(); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "Op share_buffer should have corresponding OpInfo " - "pd_op.share_data"); + "pd_op.share_data")); } return op_info; @@ -2702,7 +2861,7 @@ struct RandIntOpTranscriber : public OpTranscriber { std::tuple GenerateOperationOutput( pir::IrContext* ctx, const OpDesc& op_desc, - const OpOutputInfoList& output_infos) { + const OpOutputInfoList& output_infos) override { OpOutputMapping arg_to_idx; OpOutputTypeList op_output_types = {}; @@ -2713,10 +2872,11 @@ struct RandIntOpTranscriber : public OpTranscriber { const auto& legacy_output_vars = op_desc.Output(legacy_output_name); auto& var_name = legacy_output_vars[0]; VarDesc* var = block->FindVarRecursive(var_name); - IR_ENFORCE(var != nullptr, - "[op:%s] Output %s should not be null", - op_desc.Type(), - var_name); + PADDLE_ENFORCE_NE( + var, + nullptr, + phi::errors::InvalidArgument( + "[op:%s] Output %s should not be null", op_desc.Type(), var_name)); int dtype_attr_val = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); paddle::framework::proto::VarType::Type var_type = @@ -2726,7 +2886,7 @@ struct RandIntOpTranscriber : public OpTranscriber { paddle::dialect::DenseTensorTypeStorage::Dim dim = common::make_ddim(var->GetShape()); paddle::dialect::DenseTensorTypeStorage::DataLayout layout = - paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED; + paddle::dialect::DenseTensorTypeStorage::DataLayout::NCHW; paddle::dialect::DenseTensorTypeStorage::LoD lod = {}; size_t offset = 0; pir::Type translated_var_type = paddle::dialect::DenseTensorType::get( @@ -2831,9 +2991,9 @@ struct FusedElemwiseAddActivationGradOpTranscriber const OpDesc& op_desc) override { const auto inter_out_grad = op_desc.Output("IntermediateOut@GRAD"); if (inter_out_grad.size() > 0) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "pd_op.fused_elemwise_add_activation_grad doesn't have " - "Intermediate_out_grad output"); + "Intermediate_out_grad output")); } return OpTranscriber::LookUpOpInfo(ctx, op_desc); @@ -2851,10 +3011,11 @@ struct MatrixRankOpTranscriber : public OpTranscriber { } const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( - "Op matrix_rank should have corresponding OpInfo pd_op.matrix_rank " - "or " - "pd_op.matrix_rank_tol."); + PADDLE_THROW( + phi::errors::InvalidArgument("Op matrix_rank should have " + "corresponding OpInfo pd_op.matrix_rank " + "or " + "pd_op.matrix_rank_tol.")); } return op_info; } @@ -2866,9 +3027,9 @@ struct LodArrayLengthOpTranscriber : public OpTranscriber { std::string target_op_name = dialect::ArrayLengthOp::name(); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "Op lod_array_length should have corresponding OpInfo " - "pd_op.array_length"); + "pd_op.array_length")); } return op_info; @@ -2886,17 +3047,24 @@ struct LodArrayLengthOpTranscriber : public OpTranscriber { const OpInputInfo& info, pir::Block* block) -> pir::Value { VLOG(10) << "[" << op_desc.Type() << "][input `array`]"; - IR_ENFORCE(op_desc.HasInput("X"), - "Op lod_array_length should have input `X` but not found"); + PADDLE_ENFORCE_EQ( + op_desc.HasInput("X"), + true, + phi::errors::InvalidArgument( + "Op lod_array_length should have input `X` but not found")); const auto& vars = op_desc.Input("X"); - IR_ENFORCE(vars.size() == 1, - "Input `X` should be one 
variable %s", - op_desc.Type()); + PADDLE_ENFORCE_EQ( + vars.size(), + 1UL, + phi::errors::InvalidArgument("Input `X` should be one variable %s", + op_desc.Type())); VLOG(10) << "[" << op_desc.Type() << "][input `x`] from " << vars[0]; const VarDesc* var_desc = op_desc.Block()->FindVarRecursive(vars[0]); - IR_ENFORCE(var_desc != nullptr, - "VarDesc `%s` should be exist in legacy program", - vars[0]); + PADDLE_ENFORCE_NE( + var_desc, + nullptr, + phi::errors::InvalidArgument( + "VarDesc `%s` should be exist in legacy program", vars[0])); auto defining_value = pir::Value(nullptr); if (param_map->count(var_desc->Name())) { VLOG(10) << "[" << op_desc.Type() << "][input `x`] var: " << vars[0] @@ -2919,9 +3087,9 @@ struct WriteArrayOpTranscriber : public OpTranscriber { std::string target_op_name = dialect::ArrayWrite_Op::name(); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "Op write_to_array should have corresponding OpInfo " - "pd_op.array_write_"); + "pd_op.array_write_")); } return op_info; @@ -2939,17 +3107,24 @@ struct WriteArrayOpTranscriber : public OpTranscriber { const OpInputInfo& info, pir::Block* block) -> pir::Value { VLOG(10) << "[" << op_desc.Type() << "][input `array`]"; - IR_ENFORCE(op_desc.HasOutput("Out"), - "Op write_to_array should have output `Out` but not found"); + PADDLE_ENFORCE_EQ( + op_desc.HasOutput("Out"), + true, + phi::errors::InvalidArgument( + "Op write_to_array should have output `Out` but not found")); const auto& vars = op_desc.Output("Out"); - IR_ENFORCE(vars.size() == 1, - "Output `Out` should be one variable %s", - op_desc.Type()); + PADDLE_ENFORCE_EQ( + vars.size(), + 1UL, + phi::errors::InvalidArgument("Output `Out` should be one variable %s", + op_desc.Type())); VLOG(10) << "[" << op_desc.Type() << "][input `array`] from " << vars[0]; const VarDesc* var_desc = op_desc.Block()->FindVarRecursive(vars[0]); - IR_ENFORCE(var_desc != nullptr, - "VarDesc `%s` should be exist in legacy program", - vars[0]); + PADDLE_ENFORCE_NE( + var_desc, + nullptr, + phi::errors::InvalidArgument( + "VarDesc `%s` should be exist in legacy program", vars[0])); auto defining_value = pir::Value(nullptr); if (param_map->count(var_desc->Name())) { VLOG(10) << "[" << op_desc.Type() << "][input `array`] var: " << vars[0] @@ -2972,9 +3147,9 @@ struct ReadArrayOpTranscriber : public OpTranscriber { std::string target_op_name = dialect::ArrayReadOp::name(); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "Op read_from_array should have corresponding OpInfo " - "pd_op.read_array"); + "pd_op.read_array")); } return op_info; @@ -2986,30 +3161,38 @@ struct SliceOpTranscriber : public OpTranscriber { const OpDesc& op_desc) override { std::string target_op_name = dialect::SliceOp::name(); - IR_ENFORCE(op_desc.HasInput("Input"), - "op %s should have input `Input`", - op_desc.Type()); + PADDLE_ENFORCE_EQ(op_desc.HasInput("Input"), + true, + phi::errors::InvalidArgument( + "op %s should have input `Input`", op_desc.Type())); const auto& input_vars = op_desc.Input("Input"); - IR_ENFORCE(input_vars.size() == 1, - "op %s should have one input `Input`, but got %d.", - op_desc.Type(), - input_vars.size()); + PADDLE_ENFORCE_EQ(input_vars.size(), + 1UL, + phi::errors::InvalidArgument( + "op %s should have one input `Input`, but got %d.", + op_desc.Type(), + input_vars.size())); const auto* input_var = 
op_desc.Block()->FindVarRecursive(input_vars[0]); if (input_var->GetType() == framework::proto::VarType::LOD_TENSOR_ARRAY) { - IR_ENFORCE(op_desc.HasOutput("Out"), - "op %s should have input `Out`", - op_desc.Type()); + PADDLE_ENFORCE_EQ(op_desc.HasOutput("Out"), + true, + phi::errors::InvalidArgument( + "op %s should have input `Out`", op_desc.Type())); const auto& output_vars = op_desc.Output("Out"); - IR_ENFORCE(output_vars.size() == 1, - "op %s should have one input `Out`, but got %d.", - op_desc.Type(), - output_vars.size()); + PADDLE_ENFORCE_EQ(output_vars.size(), + 1UL, + phi::errors::InvalidArgument( + "op %s should have one input `Out`, but got %d.", + op_desc.Type(), + output_vars.size())); const auto* output_var = op_desc.Block()->FindVarRecursive(output_vars[0]); - IR_ENFORCE(output_var != nullptr, - "op %s should have non-empty output `%s`.", - op_desc.Type(), - output_vars[0]); + PADDLE_ENFORCE_NE(output_var, + nullptr, + phi::errors::InvalidArgument( + "op %s should have non-empty output `%s`.", + op_desc.Type(), + output_vars[0])); if (output_var->GetType() == framework::proto::VarType::LOD_TENSOR) { target_op_name = dialect::SliceArrayDenseOp::name(); @@ -3020,7 +3203,8 @@ struct SliceOpTranscriber : public OpTranscriber { const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW("Op slice should have corresponding OpInfo %s", target_op_name); + PADDLE_THROW(phi::errors::InvalidArgument( + "Op slice should have corresponding OpInfo %s", target_op_name)); } return op_info; @@ -3037,10 +3221,11 @@ struct LegacyMatmulOpTranscriber : public OpTranscriber { } float v = PADDLE_GET_CONST(float, op_desc.GetAttr(attr_name)); if (abs(v - expected_value) > 1e-6f) { - IR_THROW("Expected op[%s]'s attr %s is not %f", - op_desc.Type(), - attr_name, - v); + PADDLE_THROW( + phi::errors::InvalidArgument("Expected op[%s]'s attr %s is not %f", + op_desc.Type(), + attr_name, + v)); } }; @@ -3051,9 +3236,9 @@ struct LegacyMatmulOpTranscriber : public OpTranscriber { std::string target_op_name = dialect::MatmulOp::name(); const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - IR_THROW( + PADDLE_THROW(phi::errors::InvalidArgument( "Op read_from_array should have corresponding OpInfo " - "pd_op.read_array"); + "pd_op.read_array")); } return op_info; @@ -3073,14 +3258,18 @@ struct LegacyMatmulOpTranscriber : public OpTranscriber { } const auto& output_vars = op_desc.Output("Out"); - IR_ENFORCE(output_vars.size() == 1, - "Expected op[%s]'s output `Out` has only 1 variable, but got %d", - op_desc.Type(), - output_vars.size()); + PADDLE_ENFORCE_EQ( + output_vars.size(), + 1UL, + phi::errors::InvalidArgument( + "Expected op[%s]'s output `Out` has only 1 variable, but got %d", + op_desc.Type(), + output_vars.size())); auto idx_iter = arg_to_idx.find(output_vars[0]); if (idx_iter == arg_to_idx.end()) { - IR_THROW("op[%s] should have got its `Out`", op_desc.Type()); + PADDLE_THROW(phi::errors::InvalidArgument( + "op[%s] should have got its `Out`", op_desc.Type())); } auto [idx_in_op, idx_in_vec] = idx_iter->second; VLOG(10) << "[output recording]" diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 608d24a60b577..86828d0dc50d2 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -309,7 +309,7 @@ void ProgramTranslator::TranslateIfOperation( TranslationContext* translation_ctx, 
pir::Block* dst_block, bool for_bwd) { - VLOG(8) << "=============>Start to translate if op:" << op; + LOG_FIRST_N(INFO, 1) << "Translate ConditionalBlockOp"; auto& type_translator = TypeTranslator::instance(); auto cond_op_cond = op->Input("Cond")[0]; @@ -347,7 +347,9 @@ void ProgramTranslator::TranslateIfOperation( pir::AttributeMap attribute_map; std::vector if_op_output_types; for (auto var_desc : cond_op_output_vars) { - IR_ENFORCE(var_desc != nullptr, "[control flow] Output should not be null"); + PADDLE_ENFORCE_NOT_NULL(var_desc, + phi::errors::PreconditionNotMet( + "[control flow] Output should not be null")); pir::Type translated_var_type = type_translator[var_desc->GetType()](ctx_, *var_desc); if_op_output_types.emplace_back(translated_var_type); @@ -479,7 +481,7 @@ void ProgramTranslator::TranslateWhileOperation( const OpDesc* op, TranslationContext* translation_ctx, pir::Block* dst_block) { - VLOG(8) << "=============>Start to translate while op:" << op; + LOG_FIRST_N(INFO, 1) << "Translate WhileOp"; auto& sub_block = legacy_program_->Block(op->GetBlockAttrId("sub_block")); auto& inputs = op->Output("Out"); auto& cond_var = op->Input("Condition")[0]; @@ -684,10 +686,12 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) { pir::Block::Iterator insert_pos = std::find( block->begin(), block->end(), *defining_op_result.owner()); - IR_ENFORCE( - insert_pos != block->end(), - "Parameter %s must have corresponding its defining operation", - var_name); + PADDLE_ENFORCE_NE(insert_pos, + block->end(), + phi::errors::InvalidArgument( + "Parameter %s must have corresponding its " + "defining operation", + var_name)); insert_pos++; block->insert(insert_pos, op); diff --git a/paddle/fluid/ir_adaptor/translator/type_translator.cc b/paddle/fluid/ir_adaptor/translator/type_translator.cc index 7cd297cf46b62..4378ef5285ceb 100644 --- a/paddle/fluid/ir_adaptor/translator/type_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/type_translator.cc @@ -30,8 +30,48 @@ using DenseTensorType = paddle::dialect::DenseTensorType; using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; using SelectedRowsType = paddle::dialect::SelectedRowsType; using SelectedRowsTypeStorage = paddle::dialect::SelectedRowsTypeStorage; +using DataLayout = DenseTensorTypeStorage::DataLayout; +using LoD = DenseTensorTypeStorage::LoD; TypeTranslator::TypeTranslator() { + const auto& HandleTensor = [&](pir::IrContext* ctx, + const VarDesc& var_desc) -> pir::Type { + VLOG(10) << "[vartype translating]" + << "[" << var_desc.Name() << "] from LOD_TENSOR"; + const pir::Type dtype = + this->operator[](var_desc.GetDataType())(ctx, var_desc); + const auto dim = common::make_ddim(var_desc.GetShape()); + const auto layout = DataLayout::NCHW; + const LoD lod = {}; + const size_t offset = 0; + return DenseTensorType::get(ctx, dtype, dim, layout, lod, offset); + }; + const auto& HandleTensorArray = [&](pir::IrContext* ctx, + const VarDesc& var_desc) -> pir::Type { + VLOG(10) << "[vartype translating]" + << "[" << var_desc.Name() << "] from LOD_TENSOR_ARRAY"; + const pir::Type dtype = + this->operator[](var_desc.GetDataType())(ctx, var_desc); + const auto dims = common::make_ddim(var_desc.GetShape()); + const auto layout = DataLayout::NCHW; + return paddle::dialect::DenseTensorArrayType::get(ctx, dtype, dims, layout); + }; + + const auto& HandleSelectedRows = [&](pir::IrContext* ctx, + const VarDesc& var_desc) -> pir::Type { + VLOG(10) << "[vartype translating]" + << "[" << var_desc.Name() << "] 
from SELECTED_ROWS"; + const pir::Type dtype = + this->operator[](var_desc.GetDataType())(ctx, var_desc); + const auto dim = common::make_ddim(var_desc.GetShape()); + const auto layout = DataLayout::NCHW; + const LoD lod = {}; + const size_t offset = 0; + pir::Type SelectedRows = + SelectedRowsType::get(ctx, dtype, dim, layout, lod, offset); + return SelectedRows; + }; + handlers = { {VarType::BOOL, [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type { @@ -81,52 +121,9 @@ TypeTranslator::TypeTranslator() { [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type { return pir::Complex128Type::get(ctx); }}, - {VarType::LOD_TENSOR, - [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type { - VLOG(10) << "[vartype translating]" - << "[" << var_desc.Name() << "] from LOD_TENSOR"; - - pir::Type dtype = - this->operator[](var_desc.GetDataType())(ctx, var_desc); - DenseTensorTypeStorage::Dim dim = - common::make_ddim(var_desc.GetShape()); - DenseTensorTypeStorage::DataLayout layout = - DenseTensorTypeStorage::DataLayout::UNDEFINED; - DenseTensorTypeStorage::LoD lod = {}; - size_t offset = 0; - return DenseTensorType::get(ctx, dtype, dim, layout, lod, offset); - }}, - {VarType::LOD_TENSOR_ARRAY, - [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type { - VLOG(10) << "[vartype translating]" - << "[" << var_desc.Name() << "] from LOD_TENSOR_ARRAY"; - pir::Type dtype = - this->operator[](var_desc.GetDataType())(ctx, var_desc); - phi::DDim dims = common::make_ddim(var_desc.GetShape()); - DenseTensorTypeStorage::DataLayout layout = - DenseTensorTypeStorage::DataLayout::UNDEFINED; - - return paddle::dialect::DenseTensorArrayType::get( - ctx, dtype, dims, layout); - }}, - {VarType::SELECTED_ROWS, - [&](pir::IrContext* ctx, const VarDesc& var_desc) -> pir::Type { - VLOG(10) << "[vartype translating]" - << "[" << var_desc.Name() << "] from SELECTED_ROWS"; - - pir::Type dtype = - this->operator[](var_desc.GetDataType())(ctx, var_desc); - - SelectedRowsTypeStorage::Dim dim = - common::make_ddim(var_desc.GetShape()); - SelectedRowsTypeStorage::DataLayout layout = - SelectedRowsTypeStorage::DataLayout::UNDEFINED; - SelectedRowsTypeStorage::LoD lod = {}; - size_t offset = 0; - pir::Type SelectedRows = - SelectedRowsType::get(ctx, dtype, dim, layout, lod, offset); - return SelectedRows; - }}, + {VarType::LOD_TENSOR, HandleTensor}, + {VarType::LOD_TENSOR_ARRAY, HandleTensorArray}, + {VarType::SELECTED_ROWS, HandleSelectedRows}, }; } diff --git a/paddle/fluid/jit/compilation_unit.cc b/paddle/fluid/jit/compilation_unit.cc index 110f012c8e361..be22dfc104165 100644 --- a/paddle/fluid/jit/compilation_unit.cc +++ b/paddle/fluid/jit/compilation_unit.cc @@ -41,7 +41,7 @@ const jit::EngineMap &CompilationUnit::EngineMap() const { return engine_map_; } std::shared_ptr CompilationUnit::Clone(void *stream) { auto x = std::make_shared(); for (auto &it : engine_map_) { - x->SetEngine(it.first, std::move(it.second->Clone(stream))); + x->SetEngine(it.first, it.second->Clone(stream)); } return x; } diff --git a/paddle/fluid/jit/engine/interpreter_engine.cc b/paddle/fluid/jit/engine/interpreter_engine.cc index 5650b45980f69..e8f622641c33b 100644 --- a/paddle/fluid/jit/engine/interpreter_engine.cc +++ b/paddle/fluid/jit/engine/interpreter_engine.cc @@ -86,7 +86,6 @@ std::vector InterpreterEngine::operator()( // the latter can be moved to python side. 
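The CompilationUnit::Clone change above, and the matching PredictorEngine::Clone change just below, drop a redundant std::move: Clone(stream) already returns a temporary, so passing the call result straight through moves (or elides) without the extra cast and avoids -Wredundant-move / -Wpessimizing-move warnings. A self-contained sketch of the same pattern, where Engine and SetEngine are stand-ins rather than the classes from this diff:

#include <memory>

struct Engine {
  std::unique_ptr<Engine> Clone() const { return std::make_unique<Engine>(); }
};

void SetEngine(std::unique_ptr<Engine> engine);  // declaration only, for the sketch

void CloneInto(const Engine& src) {
  // Redundant: the call result is already an rvalue, so std::move adds nothing.
  //   SetEngine(std::move(src.Clone()));
  // Preferred, as in the diff:
  SetEngine(src.Clone());
}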
auto &feed_names = info_->InputArgNames(); - auto &fetch_names = info_->OutputArgNames(); paddle::framework::FetchList outs = inner_interpreter_->Run(feed_names); std::vector outputs; diff --git a/paddle/fluid/jit/engine/predictor_engine.cc b/paddle/fluid/jit/engine/predictor_engine.cc index 847018e07e51c..a753adc51a540 100644 --- a/paddle/fluid/jit/engine/predictor_engine.cc +++ b/paddle/fluid/jit/engine/predictor_engine.cc @@ -66,8 +66,8 @@ PredictorEngine::PredictorEngine( predictor)) {} std::unique_ptr PredictorEngine::Clone(void *stream) { - auto *x = new PredictorEngine( - info_, scope_, place_, std::move(predictor_->Clone(stream))); + auto *x = + new PredictorEngine(info_, scope_, place_, predictor_->Clone(stream)); return std::unique_ptr(x); } diff --git a/paddle/fluid/jit/property.cc b/paddle/fluid/jit/property.cc index 687468df83a3d..37c426bb5401b 100644 --- a/paddle/fluid/jit/property.cc +++ b/paddle/fluid/jit/property.cc @@ -99,7 +99,7 @@ std::unordered_map> Property::Values() { case ValueProto::STRING: *var->GetMutable() = GetString(n); break; - case ValueProto::FLOATS: + case ValueProto::FLOATS: // NOLINT *var->GetMutable>() = GetFloats(n); break; case ValueProto::INTS: diff --git a/paddle/fluid/memory/allocation/CMakeLists.txt b/paddle/fluid/memory/allocation/CMakeLists.txt index 1cde959d49d56..c3e51e508b103 100644 --- a/paddle/fluid/memory/allocation/CMakeLists.txt +++ b/paddle/fluid/memory/allocation/CMakeLists.txt @@ -11,6 +11,7 @@ set(ALLOCATOR_SRCS allocator_strategy.cc allocator_facade.cc auto_growth_best_fit_allocator.cc + auto_growth_best_fit_allocator_v2.cc virtual_memory_auto_growth_best_fit_allocator.cc retry_allocator.cc memory_block.cc diff --git a/paddle/fluid/memory/allocation/allocator_facade.cc b/paddle/fluid/memory/allocation/allocator_facade.cc index eff0a1891ed7b..028fd3425dc84 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.cc +++ b/paddle/fluid/memory/allocation/allocator_facade.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h" +#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" #include "paddle/fluid/memory/allocation/cpu_allocator.h" #include "paddle/fluid/memory/allocation/naive_best_fit_allocator.h" #include "paddle/fluid/memory/allocation/retry_allocator.h" @@ -39,8 +40,10 @@ #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_context.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) #include "paddle/phi/backends/gpu/cuda/cuda_graph.h" +#elif defined(PADDLE_WITH_HIP) +#include "paddle/phi/backends/gpu/rocm/hip_graph.h" #endif #if CUDA_VERSION >= 10020 @@ -49,6 +52,10 @@ #include "paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h" #include "paddle/fluid/platform/dynload/cuda_driver.h" #endif + +#ifdef PADDLE_WITH_HIP +#include "paddle/fluid/memory/allocation/cuda_malloc_async_allocator.h" // NOLINT +#endif #endif #ifdef PADDLE_WITH_XPU @@ -97,6 +104,12 @@ PADDLE_DEFINE_EXPORTED_bool(use_cuda_managed_memory, "managed memory, only available for auto_growth " "strategy"); +PADDLE_DEFINE_EXPORTED_bool( + use_auto_growth_v2, + false, + "Whether to use AutoGrowthBestFitAllocatorV2 for auto_growth " + "strategy"); + COMMON_DECLARE_string(allocator_strategy); COMMON_DECLARE_uint64(auto_growth_chunk_size_in_mb); COMMON_DECLARE_bool(use_auto_growth_pinned_allocator); @@ -107,7 
+120,7 @@ namespace paddle { namespace memory { namespace allocation { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) class CUDAGraphAllocator : public Allocator, public std::enable_shared_from_this { @@ -158,7 +171,7 @@ class CUDAGraphAllocator #endif static bool IsCUDAGraphCapturing() { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()); #else return false; @@ -189,6 +202,7 @@ class AllocatorFacadePrivate { strategy_ = GetAllocatorStrategy(); is_stream_safe_cuda_allocator_used_ = false; is_cuda_malloc_async_allocator_used_ = false; + VLOG(2) << "selected allocator strategy:" << int(strategy_) << std::endl; switch (strategy_) { case AllocatorStrategy::kNaiveBestFit: { InitNaiveBestFitCPUAllocator(); @@ -232,7 +246,7 @@ class AllocatorFacadePrivate { // Note(Ruibiao): For GPU multi-stream case without CUDA graph // capturing, the 'allocators_' map(place -> Allocator) hold the - // StreamSafeCUDAAllocator relate to defaultstream (i.e., the stream + // StreamSafeCUDAAllocator relate to default stream (i.e., the stream // directly got from DeviceContext), while the 'cuda_allocators_' map // (place -> map(stream -> Allocator)) hold the StreamSafeCUDAAllocator // relate to non-default stream (i.e., the stream users pass in). The @@ -328,7 +342,7 @@ class AllocatorFacadePrivate { CheckAllocThreadSafe(); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // No need to wrap CUDAGraphAllocator for StreamSafeCUDAAllocator if (!is_stream_safe_cuda_allocator_used_ && UNLIKELY(IsCUDAGraphCapturing())) { @@ -880,11 +894,22 @@ class AllocatorFacadePrivate { << FLAGS_auto_growth_chunk_size_in_mb; #if defined(PADDLE_WITH_HIP) auto cuda_allocator = CreateCUDAAllocator(p); - cuda_allocators_[p][stream] = std::make_shared( - cuda_allocator, - platform::GpuMinChunkSize(), - chunk_size, - allow_free_idle_chunk_); + if (FLAGS_use_auto_growth_v2) { + cuda_allocators_[p][stream] = + std::make_shared( + cuda_allocator, + platform::GpuMinChunkSize(), + p, + chunk_size, + allow_free_idle_chunk_); + } else { + cuda_allocators_[p][stream] = + std::make_shared( + cuda_allocator, + platform::GpuMinChunkSize(), + chunk_size, + allow_free_idle_chunk_); + } #endif #if defined(PADDLE_WITH_CUDA) @@ -911,12 +936,22 @@ class AllocatorFacadePrivate { cuda_allocator, platform::GpuMinChunkSize(), p); } else { auto cuda_allocator = CreateCUDAAllocator(p); - cuda_allocators_[p][stream] = - std::make_shared( - cuda_allocator, - platform::GpuMinChunkSize(), - /*chunk_size=*/chunk_size, - allow_free_idle_chunk_); + if (FLAGS_use_auto_growth_v2) { + cuda_allocators_[p][stream] = + std::make_shared( + cuda_allocator, + platform::GpuMinChunkSize(), + p, + /*chunk_size=*/chunk_size, + allow_free_idle_chunk_); + } else { + cuda_allocators_[p][stream] = + std::make_shared( + cuda_allocator, + platform::GpuMinChunkSize(), + /*chunk_size=*/chunk_size, + allow_free_idle_chunk_); + } } #else auto cuda_allocator = CreateCUDAAllocator(p); @@ -951,9 +986,21 @@ class AllocatorFacadePrivate { VLOG(10) << "not use AlignedAllocator with alignment: " << alignment; underlying_allocator = cuda_allocator; } - - cuda_allocators_[p][stream] = std::make_shared( - underlying_allocator, alignment, chunk_size, allow_free_idle_chunk_); + if (FLAGS_use_auto_growth_v2) { + cuda_allocators_[p][stream] = + std::make_shared( + underlying_allocator, + alignment, + p, + chunk_size, + 
allow_free_idle_chunk_); + } else { + cuda_allocators_[p][stream] = + std::make_shared(underlying_allocator, + alignment, + chunk_size, + allow_free_idle_chunk_); + } #endif #endif } @@ -966,11 +1013,20 @@ class AllocatorFacadePrivate { << FLAGS_auto_growth_chunk_size_in_mb; #if defined(PADDLE_WITH_HIP) auto cuda_allocator = CreateCUDAAllocator(p); - allocators_[p] = std::make_shared( - cuda_allocator, - platform::GpuMinChunkSize(), - /*chunk_size=*/chunk_size, - allow_free_idle_chunk); + if (FLAGS_use_auto_growth_v2) { + allocators_[p] = std::make_shared( + cuda_allocator, + platform::GpuMinChunkSize(), + p, + /*chunk_size=*/chunk_size, + allow_free_idle_chunk); + } else { + allocators_[p] = std::make_shared( + cuda_allocator, + platform::GpuMinChunkSize(), + /*chunk_size=*/chunk_size, + allow_free_idle_chunk); + } #endif #if defined(PADDLE_WITH_CUDA) @@ -997,11 +1053,20 @@ class AllocatorFacadePrivate { cuda_allocator, platform::GpuMinChunkSize(), p); } else { auto cuda_allocator = CreateCUDAAllocator(p); - allocators_[p] = std::make_shared( - cuda_allocator, - platform::GpuMinChunkSize(), - /*chunk_size=*/chunk_size, - allow_free_idle_chunk); + if (FLAGS_use_auto_growth_v2) { + allocators_[p] = std::make_shared( + cuda_allocator, + platform::GpuMinChunkSize(), + p, + /*chunk_size=*/chunk_size, + allow_free_idle_chunk); + } else { + allocators_[p] = std::make_shared( + cuda_allocator, + platform::GpuMinChunkSize(), + /*chunk_size=*/chunk_size, + allow_free_idle_chunk); + } } #else @@ -1037,8 +1102,17 @@ class AllocatorFacadePrivate { VLOG(10) << "not use AlignedAllocator with alignment: " << alignment; underlying_allocator = cuda_allocator; } - allocators_[p] = std::make_shared( - underlying_allocator, alignment, chunk_size, allow_free_idle_chunk); + if (FLAGS_use_auto_growth_v2) { + allocators_[p] = + std::make_shared(underlying_allocator, + alignment, + p, + chunk_size, + allow_free_idle_chunk); + } else { + allocators_[p] = std::make_shared( + underlying_allocator, alignment, chunk_size, allow_free_idle_chunk); + } #endif #endif } @@ -1119,7 +1193,7 @@ class AllocatorFacadePrivate { allocator = std::make_shared(allocator); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void WrapCUDAGraphAllocator() { for (auto& item : allocators_) { auto& allocator = item.second; @@ -1289,7 +1363,11 @@ class AllocatorFacadePrivate { auto alignment = phi::DeviceManager::GetMinChunkSize(p); custom_device_allocators_[p][stream] = std::make_shared( - custom_allocator, alignment, chunk_size, allow_free_idle_chunk_); + custom_allocator, + alignment, + chunk_size, + allow_free_idle_chunk_, + phi::DeviceManager::GetExtraPaddingSize(p)); } void InitAutoGrowthCustomDeviceAllocator(platform::CustomPlace p, @@ -1303,7 +1381,8 @@ class AllocatorFacadePrivate { custom_allocator, phi::DeviceManager::GetMinChunkSize(p), /*chunk_size=*/chunk_size, - allow_free_idle_chunk); + allow_free_idle_chunk, + phi::DeviceManager::GetExtraPaddingSize(p)); } void WrapStreamSafeCustomDeviceAllocatorForDefault() { @@ -1505,7 +1584,7 @@ AllocatorFacade& AllocatorFacade::Instance() { } AllocatorFacadePrivate* AllocatorFacade::GetPrivate() const { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // if we use cuda_malloc_async_allocator, we don't need to open a private pool // for each graph if (UNLIKELY(IsCUDAGraphCapturing()) && @@ -1696,7 +1775,7 @@ void AllocatorFacade::SetDefaultStream(const platform::CUDAPlace& place, } } -#ifdef PADDLE_WITH_CUDA +#if 
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(int64_t id) { PADDLE_ENFORCE_EQ(GetAllocatorStrategy(), AllocatorStrategy::kAutoGrowth, diff --git a/paddle/fluid/memory/allocation/allocator_facade.h b/paddle/fluid/memory/allocation/allocator_facade.h index f80fcac1b2a38..de26eae6eb4ba 100644 --- a/paddle/fluid/memory/allocation/allocator_facade.h +++ b/paddle/fluid/memory/allocation/allocator_facade.h @@ -49,11 +49,12 @@ class AllocatorFacade { const AllocatorFacade& operator=(const AllocatorFacade& o) = delete; ~AllocatorFacade(); - static AllocatorFacade& Instance(); + TEST_API static AllocatorFacade& Instance(); AllocatorFacadePrivate* GetPrivate() const; - const std::shared_ptr& GetAllocator(const platform::Place& place); + TEST_API const std::shared_ptr& GetAllocator( + const platform::Place& place); void* GetBasePtr(const std::shared_ptr& allocation); @@ -88,13 +89,13 @@ class AllocatorFacade { void RecordStream(std::shared_ptr allocation, gpuStream_t stream); void EraseStream(std::shared_ptr allocation, gpuStream_t stream); - const std::shared_ptr& GetAllocator(const platform::Place& place, - gpuStream_t stream); + TEST_API const std::shared_ptr& GetAllocator( + const platform::Place& place, gpuStream_t stream); gpuStream_t GetStream(const std::shared_ptr& allocation) const; void SetDefaultStream(const platform::CUDAPlace& place, gpuStream_t stream); #endif -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void PrepareMemoryPoolForCUDAGraph(int64_t id); void RemoveMemoryPoolOfCUDAGraph(int64_t id); #endif @@ -104,8 +105,8 @@ class AllocatorFacade { phi::stream::stream_t stream); void RecordStream(std::shared_ptr allocation, phi::stream::stream_t stream); - const std::shared_ptr& GetAllocator(const platform::Place& place, - phi::stream::stream_t stream); + TEST_API const std::shared_ptr& GetAllocator( + const platform::Place& place, phi::stream::stream_t stream); phi::stream::stream_t GetStream( const std::shared_ptr& allocation) const; void SetDefaultStream(const platform::CustomPlace& place, @@ -115,7 +116,7 @@ class AllocatorFacade { private: AllocatorFacade(); AllocatorFacadePrivate* m_; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) std::unordered_map> cuda_graph_map_; std::unordered_map cuda_graph_ref_cnt_; diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc index a00b02ab9e01d..2dcc1295fec25 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.cc @@ -20,6 +20,7 @@ #include "paddle/fluid/memory/allocation/aligned_allocator.h" #include "paddle/fluid/platform/flags.h" #include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/phi/backends/device_manager.h" PADDLE_DEFINE_EXPORTED_READONLY_bool( free_idle_chunk, @@ -40,7 +41,6 @@ PADDLE_DEFINE_EXPORTED_READONLY_bool( PADDLE_DEFINE_EXPORTED_READONLY_bool(print_allocator_trace_info, false, "print trace memory info"); - namespace paddle { namespace memory { namespace allocation { @@ -49,11 +49,13 @@ AutoGrowthBestFitAllocator::AutoGrowthBestFitAllocator( const std::shared_ptr &underlying_allocator, size_t alignment, size_t chunk_size, - bool allow_free_idle_chunk) + bool allow_free_idle_chunk, + int extra_padding_size) : underlying_allocator_(underlying_allocator), alignment_(alignment), 
chunk_size_(std::max(AlignedSize(chunk_size, alignment), alignment)), - allow_free_idle_chunk_(allow_free_idle_chunk) { + allow_free_idle_chunk_(allow_free_idle_chunk), + extra_padding_size_(extra_padding_size) { total_alloc_times_ = 0; total_alloc_size_ = 0; total_free_times_ = 0; @@ -66,8 +68,11 @@ phi::Allocation *AutoGrowthBestFitAllocator::AllocateImpl( platform::RecordEvent record("AutoGrowthBestFitAllocator::Allocate", platform::TracerEventType::UserDefined, 9 /*level*/); - size_t size = AlignedSize(unaligned_size, alignment_); - VLOG(10) << "Allocate " << unaligned_size << " bytes, aligned to " << size; + + size_t size = AlignedSize(unaligned_size + extra_padding_size_, alignment_); + + VLOG(10) << "Allocate " << unaligned_size << " bytes, aligned to " << size + << ", extra size " << extra_padding_size_; std::lock_guard guard(spinlock_); auto iter = free_blocks_.lower_bound(std::make_pair(size, nullptr)); diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h index 138f4a98c4db5..572ca695cef9a 100644 --- a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h @@ -33,7 +33,8 @@ class AutoGrowthBestFitAllocator : public Allocator { const std::shared_ptr &underlying_allocator, size_t alignment, size_t chunk_size = 0, - bool allow_free_idle_chunk = true); + bool allow_free_idle_chunk = true, + int extra_padding_size = 0); bool IsAllocThreadSafe() const override { return true; } @@ -47,7 +48,7 @@ class AutoGrowthBestFitAllocator : public Allocator { return FreeIdleChunks(); } - private: + protected: uint64_t FreeIdleChunks(); void Trace() const; @@ -93,6 +94,7 @@ class AutoGrowthBestFitAllocator : public Allocator { size_t alignment_; size_t chunk_size_; bool allow_free_idle_chunk_; + int extra_padding_size_; // stat info size_t total_alloc_times_; diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc new file mode 100644 index 0000000000000..4565effc375b3 --- /dev/null +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.cc @@ -0,0 +1,170 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
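The AutoGrowthBestFitAllocator change above adds an extra_padding_size_ that is folded into the request before alignment: size = AlignedSize(unaligned_size + extra_padding_size_, alignment_). A standalone sketch of that arithmetic, assuming AlignedSize is the usual round-up-to-a-multiple helper (the real helper lives in aligned_allocator.h and is not shown in this diff):

#include <cstddef>
#include <cstdio>

static size_t AlignedSizeSketch(size_t size, size_t alignment) {
  return ((size + alignment - 1) / alignment) * alignment;  // round up
}

int main() {
  const size_t alignment = 256;     // illustrative allocation alignment
  const size_t extra_padding = 32;  // illustrative device-reported padding
  const size_t unaligned = 1000;    // bytes requested by the caller

  // What AllocateImpl now computes: pad first, then align.
  size_t size = AlignedSizeSketch(unaligned + extra_padding, alignment);
  std::printf("request=%zu padded=%zu aligned=%zu\n",
              unaligned, unaligned + extra_padding, size);  // 1000 1032 1280
  return 0;
}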
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" + +#include +#include // NOLINT + +#include "paddle/fluid/memory/allocation/aligned_allocator.h" +#include "paddle/fluid/platform/cuda_device_guard.h" +#include "paddle/fluid/platform/device/gpu/gpu_info.h" +#include "paddle/fluid/platform/flags.h" +#include "paddle/fluid/platform/profiler/event_tracing.h" +#include "paddle/phi/backends/device_manager.h" + +PD_DECLARE_bool(free_idle_chunk); +PD_DECLARE_bool(free_when_no_cache_hit); + +namespace paddle { +namespace memory { +namespace allocation { + +AutoGrowthBestFitAllocatorV2::AutoGrowthBestFitAllocatorV2( + const std::shared_ptr &underlying_allocator, + size_t alignment, + platform::CUDAPlace place, + size_t chunk_size, + bool allow_free_idle_chunk, + int extra_padding_size) + : AutoGrowthBestFitAllocator(underlying_allocator, + alignment, + chunk_size, + true, + extra_padding_size), + place_(place) {} + +phi::Allocation *AutoGrowthBestFitAllocatorV2::AllocateImpl( + size_t unaligned_size) { + platform::RecordEvent record("AutoGrowthBestFitAllocatorV2::Allocate", + platform::TracerEventType::UserDefined, + 9 /*level*/); + + size_t size = AlignedSize(unaligned_size + extra_padding_size_, alignment_); + + VLOG(10) << "Allocate " << unaligned_size << " bytes, aligned to " << size + << ", extra size " << extra_padding_size_; + + std::lock_guard guard(spinlock_); + + BlockIt block_it; + if (AutoGrowthBestFitAllocatorV2State::GetInstance().IsWarmup()) { + auto iter = free_blocks_.lower_bound(std::make_pair(size, nullptr)); + if (iter != free_blocks_.end() && iter->second->size_ >= unaligned_size && + iter->second->size_ <= size) { + block_it = iter->second; + free_blocks_.erase(iter); + block_it->is_free_ = false; + VLOG(10) << "Allocate " << size << " bytes from chunk size " + << block_it->size_ << " by strict_matching_state."; + } else { + size_t actual_avail, actual_total; + { + platform::CUDADeviceGuard guard(place_.device); +#ifdef PADDLE_WITH_HIP + auto result = hipMemGetInfo(&actual_avail, &actual_total); +#else + auto result = cudaMemGetInfo(&actual_avail, &actual_total); +#endif + if (result != gpuSuccess) { + actual_avail = 0; + } + } + + if (actual_avail < size) { + FreeIdleChunks(); + } + + chunks_.emplace_back(static_unique_ptr_cast( + underlying_allocator_->Allocate(size))); + + auto *chunk = &(*chunks_.rbegin()); + size = chunk->allocation_->size(); + uint8_t *p = reinterpret_cast(chunk->allocation_->ptr()); + auto &blocks = chunk->blocks_; + blocks.emplace_back(p, size, false, chunk); + block_it = --(blocks.end()); + VLOG(2) << "Not found and reallocate " << size << "(" + << static_cast(p) << ") by strict_matching_state."; + } + } else { + if (is_first_switch_to_regular_) { + FreeIdleChunks(); + is_first_switch_to_regular_ = false; + } + auto iter = free_blocks_.lower_bound(std::make_pair(size, nullptr)); + + if (iter != free_blocks_.end()) { + block_it = iter->second; + free_blocks_.erase(iter); + auto *chunk = block_it->chunk_; + size_t remaining_size = block_it->size_ - size; + VLOG(10) << "Allocate " << size << " bytes from chunk size " + << block_it->size_ << ", remaining " << remaining_size; + if (remaining_size == 0) { + block_it->is_free_ = false; + } else { + auto remaining_free_block = chunk->blocks_.insert( + block_it, Block(block_it->ptr_, remaining_size, true, chunk)); + free_blocks_.emplace(std::make_pair(remaining_size, block_it->ptr_), + remaining_free_block); + 
block_it->ptr_ = + reinterpret_cast(block_it->ptr_) + remaining_size; + block_it->size_ = size; + block_it->is_free_ = false; + } + } else { + if (FLAGS_free_when_no_cache_hit) { + FreeIdleChunks(); + } + size_t realloc_size = std::max(size, chunk_size_); + + try { + chunks_.emplace_back(static_unique_ptr_cast( + underlying_allocator_->Allocate(realloc_size))); + } catch (BadAlloc &ex) { + if (FLAGS_free_when_no_cache_hit) throw ex; + FreeIdleChunks(); + chunks_.emplace_back(static_unique_ptr_cast( + underlying_allocator_->Allocate(realloc_size))); + } + + auto *chunk = &(*chunks_.rbegin()); + realloc_size = chunk->allocation_->size(); + uint8_t *p = reinterpret_cast(chunk->allocation_->ptr()); + auto &blocks = chunk->blocks_; + + size_t remaining_size = realloc_size - size; + if (remaining_size > 0) { + blocks.emplace_back(p, remaining_size, true, chunk); + free_blocks_.emplace(std::make_pair(remaining_size, p), + --(blocks.end())); + } + blocks.emplace_back(p + remaining_size, size, false, chunk); + block_it = --(blocks.end()); + VLOG(2) << "Not found and reallocate " << realloc_size << "(" + << static_cast(p) << "), and remaining " + << remaining_size; + } + } + ++total_alloc_times_; + total_alloc_size_ += size; + VLOG(10) << "Alloc " << block_it->size_ << " bytes, ptr = " << block_it->ptr_; + return new BlockAllocation(block_it); +} + +} // namespace allocation +} // namespace memory +} // namespace paddle +#endif diff --git a/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h new file mode 100644 index 0000000000000..82d818e1c1a47 --- /dev/null +++ b/paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h @@ -0,0 +1,71 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
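The AllocateImpl body above has two regimes: a warmup phase that only reuses strictly matching free blocks, and a regular phase that does a best-fit lookup via free_blocks_.lower_bound and splits the surplus off a larger block. A standalone sketch of the regular phase's best-fit-and-split bookkeeping, with plain offsets standing in for Paddle's Block/Chunk structures (all names below are illustrative):

#include <cstddef>
#include <cstdio>
#include <map>
#include <utility>

int main() {
  // Free blocks keyed by (size, offset); the mapped bool marks "is free".
  std::map<std::pair<size_t, size_t>, bool> free_blocks;
  free_blocks[{4096, 0}] = true;  // one 4 KiB free block at offset 0
  const size_t want = 1500;

  auto it = free_blocks.lower_bound({want, 0});  // smallest block >= want
  if (it != free_blocks.end()) {
    size_t block_size = it->first.first;
    size_t offset = it->first.second;
    free_blocks.erase(it);
    size_t remaining = block_size - want;
    if (remaining > 0) {
      // Keep the head as a smaller free block and hand out the tail,
      // mirroring how AllocateImpl inserts remaining_free_block.
      free_blocks[{remaining, offset}] = true;
    }
    std::printf("allocated %zu bytes at offset %zu, %zu left free\n",
                want, offset + remaining, remaining);
  }
  return 0;
}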
+ +#pragma once +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include +#include +#include +#include // NOLINT +#include + +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator.h" +#include "paddle/fluid/memory/allocation/spin_lock.h" + +namespace paddle { +namespace memory { +namespace allocation { + +class AutoGrowthBestFitAllocatorV2 : public AutoGrowthBestFitAllocator { + public: + AutoGrowthBestFitAllocatorV2( + const std::shared_ptr &underlying_allocator, + size_t alignment, + platform::CUDAPlace place, + size_t chunk_size = 0, + bool allow_free_idle_chunk = true, + int extra_padding_size = 0); + + protected: + phi::Allocation *AllocateImpl(size_t size) override; + + private: + platform::CUDAPlace place_; + bool is_first_switch_to_regular_{true}; +}; + +class AutoGrowthBestFitAllocatorV2State { + public: + AutoGrowthBestFitAllocatorV2State() = default; + + ~AutoGrowthBestFitAllocatorV2State() {} + + void SetWarmup(bool warmup) { is_warmup_ = warmup; } + + bool IsWarmup() { return is_warmup_; } + + static AutoGrowthBestFitAllocatorV2State &GetInstance() { + static AutoGrowthBestFitAllocatorV2State instance; + return instance; + } + + private: + bool is_warmup_{true}; +}; + +} // namespace allocation +} // namespace memory +} // namespace paddle +#endif diff --git a/paddle/fluid/memory/allocation/buddy_allocator.cc b/paddle/fluid/memory/allocation/buddy_allocator.cc index a582955c5d81d..7d4d09c6cd28d 100644 --- a/paddle/fluid/memory/allocation/buddy_allocator.cc +++ b/paddle/fluid/memory/allocation/buddy_allocator.cc @@ -60,8 +60,10 @@ BuddyAllocator::BuddyAllocator( #endif } #endif + VLOG(1) << "min_chunk_size_: " << min_chunk_size_ - << ", max_chunk_size_:" << max_chunk_size_; + << ", max_chunk_size_:" << max_chunk_size_ + << ", extra_padding_size_: " << extra_padding_size_; } BuddyAllocator::~BuddyAllocator() { @@ -86,15 +88,9 @@ inline size_t align(size_t size, size_t alignment) { void* BuddyAllocator::Alloc(size_t unaligned_size) { // adjust allocation alignment - size_t size = align(unaligned_size + sizeof(MemoryBlock::Desc) + extra_padding_size_, min_chunk_size_); -#ifdef PADDLE_WITH_CUSTOM_DEVICE - if (use_custom_device_) { - size = align(unaligned_size + extra_padding_size_, min_chunk_size_); - } -#endif VLOG(10) << "alloc: " << unaligned_size << ", padding for desc: " << sizeof(MemoryBlock::Desc) << ", extra padding: " << extra_padding_size_ diff --git a/paddle/fluid/memory/allocation/cuda_ipc_allocator.cc b/paddle/fluid/memory/allocation/cuda_ipc_allocator.cc index df62c112681b1..be3f578f4942f 100644 --- a/paddle/fluid/memory/allocation/cuda_ipc_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_ipc_allocator.cc @@ -47,17 +47,16 @@ std::shared_ptr GetIpcBasePtr(std::string handle) { // The IpcMemHandle can only open once for the same handle, // so here we cache it here. void *baseptr = nullptr; - auto ipc_handle = - reinterpret_cast(handle.c_str()); - PADDLE_ENFORCE_GPU_SUCCESS(cudaIpcOpenMemHandle( - &baseptr, *ipc_handle, cudaIpcMemLazyEnablePeerAccess)); + auto ipc_handle = reinterpret_cast(handle.c_str()); + PADDLE_ENFORCE_GPU_SUCCESS(gpuIpcOpenMemHandle( + &baseptr, *ipc_handle, gpuIpcMemLazyEnablePeerAccess)); // Close ipc handle on the same device. int device_id = platform::GetCurrentDeviceId(); // Add deleter to close ipc handle. 
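The header above exposes AutoGrowthBestFitAllocatorV2State as a process-wide singleton so other code can tell the V2 allocator when warmup is over. The diff does not show where Paddle flips the flag; a hypothetical call site would look like the sketch below (it compiles only inside the Paddle tree, and only when PADDLE_WITH_CUDA or PADDLE_WITH_HIP is defined, since the class sits behind that guard):

#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h"

// Hypothetical helper: switch the V2 allocator from its strict-matching
// warmup mode to the regular grow-and-split mode.
void FinishAllocatorWarmup() {
  using paddle::memory::allocation::AutoGrowthBestFitAllocatorV2State;
  AutoGrowthBestFitAllocatorV2State::GetInstance().SetWarmup(false);
}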
auto sp = std::shared_ptr(baseptr, [handle, device_id](void *ptr) { platform::CUDADeviceGuard guard(device_id); std::lock_guard lock(ipc_mutex_); - PADDLE_ENFORCE_GPU_SUCCESS(cudaIpcCloseMemHandle(ptr)); + PADDLE_ENFORCE_GPU_SUCCESS(gpuIpcCloseMemHandle(ptr)); ipc_handle_to_baseptr_.erase(handle); VLOG(6) << "cudaIpcCloseMemHandle for ptr:" << "\t" << ptr; diff --git a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc index cdc3f60da7c7e..7e0c513f5c81c 100644 --- a/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_malloc_async_allocator.cc @@ -27,7 +27,11 @@ #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/enforce.h" +#if defined(PADDLE_WITH_CUDA) #include "paddle/phi/backends/gpu/cuda/cuda_graph.h" +#elif defined(PADDLE_WITH_HIP) +#include "paddle/phi/backends/gpu/rocm/hip_graph.h" +#endif namespace paddle { namespace memory { @@ -47,11 +51,11 @@ void CUDAMallocAsyncAllocation::RecordStreamWithNoGraphCapturing( if (event_map_.find(stream) == event_map_.end()) { gpuEvent_t event; PADDLE_ENFORCE_GPU_SUCCESS( - cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event, stream)); + gpuEventCreateWithFlags(&event, gpuEventDisableTiming)); + PADDLE_ENFORCE_GPU_SUCCESS(gpuEventRecord(event, stream)); event_map_[stream] = event; } else { - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(event_map_[stream], stream)); + PADDLE_ENFORCE_GPU_SUCCESS(gpuEventRecord(event_map_[stream], stream)); } } @@ -93,16 +97,16 @@ bool CUDAMallocAsyncAllocation::CanBeFreed(bool synchronize) { for (auto it = event_map_.begin(); it != event_map_.end();) { gpuEvent_t& event = it->second; if (synchronize) { - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(event)); + PADDLE_ENFORCE_GPU_SUCCESS(gpuEventSynchronize(event)); } else { - gpuError_t err = cudaEventQuery(event); - if (err == cudaErrorNotReady) { + gpuError_t err = gpuEventQuery(event); + if (err == gpuErrorNotReady) { VLOG(9) << "Event " << event << " for " << ptr() << " is not completed"; return false; } PADDLE_ENFORCE_GPU_SUCCESS(err); } - PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(event)); + PADDLE_ENFORCE_GPU_SUCCESS(gpuEventDestroy(event)); VLOG(8) << "Destroy event " << event; it = event_map_.erase(it); } @@ -117,7 +121,7 @@ CUDAMallocAsyncAllocator::CUDAMallocAsyncAllocator( place_(place), default_stream_(default_stream) { PADDLE_ENFORCE_GPU_SUCCESS( - cudaStreamCreateWithPriority(&memory_stream_, cudaStreamNonBlocking, 0)); + gpuStreamCreateWithPriority(&memory_stream_, gpuStreamNonBlocking, 0)); } bool CUDAMallocAsyncAllocator::IsAllocThreadSafe() const { return true; } diff --git a/paddle/fluid/memory/allocation/cuda_managed_allocator.cc b/paddle/fluid/memory/allocation/cuda_managed_allocator.cc index 77ca495cacbc7..36659fdbadce2 100644 --- a/paddle/fluid/memory/allocation/cuda_managed_allocator.cc +++ b/paddle/fluid/memory/allocation/cuda_managed_allocator.cc @@ -65,7 +65,7 @@ phi::Allocation* CUDAManagedAllocator::AllocateImpl(size_t size) { std::string err_msg; if (UNLIKELY(is_limited)) { - int64_t limit_size_mb = limit_size >> 20; + int64_t limit_size_mb = limit_size >> 20; // NOLINT err_msg = string::Sprintf( "Or set environment variable `FLAGS_gpu_memory_limit_mb` to a larger " "value. 
Currently `FLAGS_gpu_memory_limit_mb` is %d, so the maximum " diff --git a/paddle/fluid/memory/allocation/custom_allocator.cc b/paddle/fluid/memory/allocation/custom_allocator.cc index b4c3ebe1b2926..36848ff9cf0b0 100644 --- a/paddle/fluid/memory/allocation/custom_allocator.cc +++ b/paddle/fluid/memory/allocation/custom_allocator.cc @@ -16,6 +16,10 @@ #include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/profiler.h" +#include "paddle/fluid/platform/profiler/trace_event.h" + +COMMON_DECLARE_bool(custom_device_mem_record); namespace paddle { namespace memory { @@ -33,6 +37,14 @@ void CustomAllocator::FreeImpl(phi::Allocation* allocation) { phi::DeviceManager::GetDeviceWithPlace(place_)->MemoryDeallocate( allocation->ptr(), allocation->size()); } + if (FLAGS_custom_device_mem_record) { + DEVICE_MEMORY_STAT_UPDATE( + Reserved, place_.GetDeviceId(), -allocation->size()); + platform::RecordMemEvent(allocation->ptr(), + place_, + allocation->size(), + platform::TracerMemEventType::ReservedFree); + } delete allocation; } @@ -42,6 +54,11 @@ phi::Allocation* CustomAllocator::AllocateImpl(size_t size) { void* ptr = phi::DeviceManager::GetDeviceWithPlace(place_)->MemoryAllocate(size); if (LIKELY(ptr)) { + if (FLAGS_custom_device_mem_record) { + DEVICE_MEMORY_STAT_UPDATE(Reserved, place_.GetDeviceId(), size); + platform::RecordMemEvent( + ptr, place_, size, platform::TracerMemEventType::ReservedAllocate); + } return new Allocation(ptr, size, place_); } diff --git a/paddle/fluid/memory/allocation/memory_block.cc b/paddle/fluid/memory/allocation/memory_block.cc index 0f0a81cf9d118..26a2310c17e27 100644 --- a/paddle/fluid/memory/allocation/memory_block.cc +++ b/paddle/fluid/memory/allocation/memory_block.cc @@ -43,7 +43,9 @@ MemoryBlock* MemoryBlock::GetRightBuddy(MetadataCache* cache) { return cache->LoadDesc(this)->right_buddy; } -void MemoryBlock::Split(MetadataCache* cache, size_t size) { +void MemoryBlock::Split(MetadataCache* cache, + size_t size, + size_t extra_padding_size) { auto desc = cache->LoadDesc(this); // make sure the split fits PADDLE_ENFORCE_GE(desc->total_size, @@ -54,8 +56,10 @@ void MemoryBlock::Split(MetadataCache* cache, size_t size) { desc->total_size, size)); + size_t pay_load_size = sizeof(MemoryBlock::Desc) + extra_padding_size; + // bail out if there is no room for another partition - if (desc->total_size - size <= sizeof(MemoryBlock::Desc)) { + if (desc->total_size - size <= pay_load_size) { return; } @@ -71,13 +75,13 @@ void MemoryBlock::Split(MetadataCache* cache, size_t size) { cache->Save(static_cast(right_partition), MemoryBlock::Desc(FREE_CHUNK, desc->index, - remaining_size - sizeof(MemoryBlock::Desc), + remaining_size - pay_load_size, remaining_size, this, new_block_right_buddy)); desc->right_buddy = static_cast(right_partition); - desc->size = size - sizeof(MemoryBlock::Desc); + desc->size = size - pay_load_size; desc->total_size = size; desc->UpdateGuards(); diff --git a/paddle/fluid/memory/allocation/memory_block.h b/paddle/fluid/memory/allocation/memory_block.h index 1ddf88ce8b47c..631fca44f5157 100644 --- a/paddle/fluid/memory/allocation/memory_block.h +++ b/paddle/fluid/memory/allocation/memory_block.h @@ -50,7 +50,7 @@ struct MemoryBlock { MemoryBlock* GetRightBuddy(MetadataCache* cache); // Split the allocation into left/right blocks. 
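With the MemoryBlock::Split change above, the per-block overhead grows from sizeof(MemoryBlock::Desc) to sizeof(MemoryBlock::Desc) + extra_padding_size, and both the "can we split at all" test and the right buddy's usable size are computed from that payload. A standalone numeric sketch of the bookkeeping (desc_size and extra_padding are made-up values):

#include <cstddef>
#include <cstdio>

int main() {
  const size_t desc_size = 64;      // stand-in for sizeof(MemoryBlock::Desc)
  const size_t extra_padding = 32;  // illustrative device extra padding
  const size_t payload = desc_size + extra_padding;

  const size_t total_size = 4096;   // current block, metadata included
  const size_t size = 1024;         // bytes wanted from this block

  if (total_size - size <= payload) {
    std::puts("no room for a right buddy: do not split");
    return 0;
  }
  const size_t remaining_total = total_size - size;
  std::printf("right buddy: total=%zu usable=%zu; left block usable=%zu\n",
              remaining_total, remaining_total - payload, size - payload);
  return 0;
}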
- void Split(MetadataCache* cache, size_t size); + void Split(MetadataCache* cache, size_t size, size_t extra_padding_size = 0); // Merge left and right blocks together. void Merge(MetadataCache* cache, MemoryBlock* right_buddy); diff --git a/paddle/fluid/memory/allocation/mmap_allocator.cc b/paddle/fluid/memory/allocation/mmap_allocator.cc index 3b371ed20e59c..f9647032a6a59 100644 --- a/paddle/fluid/memory/allocation/mmap_allocator.cc +++ b/paddle/fluid/memory/allocation/mmap_allocator.cc @@ -54,11 +54,14 @@ struct CountInfo { std::atomic refcount; }; -void AllocateMemoryMap( - std::string filename, int flags, size_t size, void **map_ptr_, int *fd_) { +void AllocateMemoryMap(std::string filename, + int *shared_fd, + int flags, + size_t size, + void **map_ptr_) { // TODO(@ZHUI): support win32 int file_flags = 0; - int fd = -1; + int fd = *shared_fd; if (flags & MAPPED_SHAREDMEM) { file_flags = O_RDWR | O_CREAT; } else { @@ -71,7 +74,7 @@ void AllocateMemoryMap( file_flags &= ~O_CREAT; } - if (!(flags & MAPPED_FROMFD)) { + if (!(flags & MAPPED_FROMFD) && fd == -1) { if (flags & MAPPED_SHAREDMEM) { fd = shm_open(filename.c_str(), file_flags, (mode_t)0600); PADDLE_ENFORCE_NE( @@ -83,14 +86,12 @@ void AllocateMemoryMap( VLOG(6) << "shm_open: " << filename; MemoryMapFdSet::Instance().Insert(filename); } - } else { - fd = -1; } PADDLE_ENFORCE_EQ(ftruncate(fd, size), 0, platform::errors::Unavailable( - "Fruncate a file to a specified length failed!")); + "Truncate a file to a specified length failed!")); if (flags & MAPPED_SHAREDMEM) { *map_ptr_ = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); @@ -98,41 +99,47 @@ void AllocateMemoryMap( *map_ptr_ = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, 0); } + if (flags & MAPPED_UNLINK) { + VLOG(6) << "shm_unlink: " << filename; + shm_unlink(filename.c_str()); + } + PADDLE_ENFORCE_NE(*map_ptr_, MAP_FAILED, platform::errors::Unavailable( "Memory map failed when create shared memory.")); - if (flags & MAPPED_KEEPFD) { - *fd_ = fd; + *shared_fd = fd; + VLOG(6) << "keep fd: " << *shared_fd; } else { PADDLE_ENFORCE_NE(::close(fd), -1, platform::errors::Unavailable( - "Error closing memory maped file <", filename, ">")); + "Error closing memory mapped file <", filename, ">")); - *fd_ = -1; + *shared_fd = -1; } } std::shared_ptr AllocateRefcountedMemoryMapAllocation(std::string filename, + int shared_fd, int flags, size_t size, int buffer_id) { - int fd = -1; + int fd = shared_fd; void *base_ptr = nullptr; if (buffer_id == -1) { - AllocateMemoryMap(filename, flags, size + mmap_alignment, &base_ptr, &fd); + AllocateMemoryMap(filename, &fd, flags, size + mmap_alignment, &base_ptr); VLOG(4) << "Create and mmap a new shm: " << filename; } else { base_ptr = MemoryMapAllocationPool::Instance().GetById(buffer_id).mmap_ptr_; VLOG(4) << "Get a cached shm " << filename; } - void *aliged_base_ptr = + void *aligned_base_ptr = static_cast(static_cast(base_ptr) + mmap_alignment); return std::make_shared( - aliged_base_ptr, size, filename, flags, fd, buffer_id); + aligned_base_ptr, size, filename, fd, flags, buffer_id); } RefcountedMemoryMapAllocation::RefcountedMemoryMapAllocation( @@ -145,11 +152,22 @@ RefcountedMemoryMapAllocation::RefcountedMemoryMapAllocation( : MemoryMapAllocation(ptr, size, ipc_name, fd, flags) { // must reset base ptr first. 
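The mmap_allocator.cc hunk above changes AllocateMemoryMap to take int *shared_fd so an already-open descriptor can be reused (shm_open only runs when the incoming fd is -1), unlinks the shared-memory name right after mapping when MAPPED_UNLINK is set, and hands the descriptor back to the caller when MAPPED_KEEPFD is set. A standalone POSIX sketch of that fd handling, not Paddle's function (the name and the simplified flag behavior are illustrative):

#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <cstdio>

void* MapShared(const char* name, int* shared_fd, size_t size) {
  int fd = *shared_fd;
  if (fd == -1) {  // no descriptor handed in: create/open the segment
    fd = shm_open(name, O_RDWR | O_CREAT, 0600);
    if (fd == -1) return nullptr;
  }
  if (ftruncate(fd, static_cast<off_t>(size)) != 0) return nullptr;
  void* ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
  if (ptr == MAP_FAILED) return nullptr;
  shm_unlink(name);   // like MAPPED_UNLINK: the name is gone, the fd stays valid
  *shared_fd = fd;    // like MAPPED_KEEPFD: the caller now owns the fd
  return ptr;
}

int main() {
  int fd = -1;  // -1 means "open a new segment"
  void* p = MapShared("/paddle_demo_shm", &fd, 4096);
  std::printf("mapped %p, fd %d\n", p, fd);
  if (p) munmap(p, 4096);
  if (fd != -1) close(fd);
  return 0;
}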
buffer_id_ = buffer_id; + fd_ = fd; + flags_ = flags; resetBaseptr(); initializeRefercount(); } void MemoryMapAllocation::close() { + if (!closed_fd_) { + closed_fd_ = true; + if (flags_ & MAPPED_KEEPFD) { + PADDLE_ENFORCE_NE(::close(fd_), + -1, + platform::errors::Unavailable( + "Error closing file descriptor <", fd_, ">")); + } + } if (closed_) { return; } @@ -193,6 +211,15 @@ void RefcountedMemoryMapAllocation::close() { void *data = map_ptr_; CountInfo *info = reinterpret_cast(data); --info->refcount; + if (flags_ & MAPPED_KEEPFD) { + closed_fd_ = true; + PADDLE_ENFORCE_NE(::close(fd_), + -1, + platform::errors::Unavailable( + "Error closing file descriptor <", fd_, ">")); + VLOG(6) << "close fd: " << fd_; + } + if (FLAGS_use_shm_cache && buffer_id_ != -1) { return; } else { @@ -260,6 +287,7 @@ std::shared_ptr AllocateMemoryMapWriterAllocation( const std::string &ipc_name = GetIPCName(); int flags = O_RDWR | O_CREAT; int fd = shm_open(ipc_name.c_str(), flags, 0600); + PADDLE_ENFORCE_NE(fd, -1, platform::errors::Unavailable( @@ -267,7 +295,7 @@ std::shared_ptr AllocateMemoryMapWriterAllocation( PADDLE_ENFORCE_EQ(ftruncate(fd, size), 0, platform::errors::Unavailable( - "Fruncate a file to a specified length failed!")); + "Truncate a file to a specified length failed!")); void *ptr = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); PADDLE_ENFORCE_NE(ptr, @@ -283,7 +311,6 @@ std::shared_ptr RebuildMemoryMapReaderAllocation( const std::string &ipc_name, size_t size) { int flags = O_RDWR | O_CREAT; flags &= ~O_CREAT; - int fd = shm_open(ipc_name.c_str(), flags, 0600); PADDLE_ENFORCE_NE(fd, -1, @@ -337,7 +364,7 @@ MemoryMapAllocationPool *MemoryMapAllocationPool::pool_ = nullptr; void MemoryMapAllocationPool::Insert(const MemoryMapInfo &memory_map) { std::lock_guard guard(mtx_); memory_map_allocations_.push_back(memory_map); - VLOG(4) << this << "Intsert a new shm: " << memory_map.file_name_; + VLOG(4) << this << "Insert a new shm: " << memory_map.file_name_; } int MemoryMapAllocationPool::FindFromCache(const int &flag, diff --git a/paddle/fluid/memory/allocation/mmap_allocator.h b/paddle/fluid/memory/allocation/mmap_allocator.h index 412e3a3545769..64a3ae9de7658 100644 --- a/paddle/fluid/memory/allocation/mmap_allocator.h +++ b/paddle/fluid/memory/allocation/mmap_allocator.h @@ -44,13 +44,17 @@ enum MappedModes { class MemoryMapAllocation : public Allocation { public: - explicit MemoryMapAllocation(void *ptr, size_t size, std::string ipc_name) + explicit MemoryMapAllocation(void *ptr, + size_t size, + std::string ipc_name, + int fd) : Allocation(ptr, size, platform::CPUPlace()), ipc_name_(std::move(ipc_name)), + fd_(fd), map_ptr_(ptr), map_size_(size) {} explicit MemoryMapAllocation( - void *ptr, size_t size, std::string ipc_name, int flags, int fd) + void *ptr, size_t size, std::string ipc_name, int fd, int flags) : Allocation(ptr, size, platform::CPUPlace()), ipc_name_(std::move(ipc_name)), fd_(fd), @@ -59,6 +63,7 @@ class MemoryMapAllocation : public Allocation { map_size_(size) {} inline const std::string &ipc_name() const { return ipc_name_; } + inline const int shared_fd() const { return fd_; } virtual void close(); @@ -71,6 +76,7 @@ class MemoryMapAllocation : public Allocation { void *map_ptr_ = nullptr; size_t map_size_ = 0; bool closed_ = false; + bool closed_fd_ = false; }; class RefcountedMemoryMapAllocation : public MemoryMapAllocation { @@ -93,11 +99,15 @@ class RefcountedMemoryMapAllocation : public MemoryMapAllocation { void resetBaseptr(); }; -void 
AllocateMemoryMap( - std::string filename, int flags, size_t size, void **base_ptr_, int *fd_); +void AllocateMemoryMap(std::string filename, + int *shared_fd, + int flags, + size_t size, + void **base_ptr_); std::shared_ptr AllocateRefcountedMemoryMapAllocation(std::string filename, + int shared_fd, int flags, size_t size, int buffer_id = -1); @@ -111,11 +121,13 @@ class MemoryMapWriterAllocation : public Allocation { ipc_name_(std::move(ipc_name)) {} inline const std::string &ipc_name() const { return ipc_name_; } + inline const int shared_fd() const { return fd_; } ~MemoryMapWriterAllocation() override; private: std::string ipc_name_; + int fd_ = -1; }; class MemoryMapReaderAllocation : public Allocation { @@ -127,11 +139,13 @@ class MemoryMapReaderAllocation : public Allocation { ipc_name_(std::move(ipc_name)) {} inline const std::string &ipc_name() const { return ipc_name_; } + inline const int shared_fd() const { return fd_; } ~MemoryMapReaderAllocation() override; private: std::string ipc_name_; + int fd_ = -1; }; std::shared_ptr AllocateMemoryMapWriterAllocation( diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc index 612ba0798d2c0..b53e951f516f0 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc @@ -23,9 +23,9 @@ #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler.h" -#include "paddle/fluid/string/printf.h" -#include "paddle/fluid/string/split.h" #include "paddle/phi/common/place.h" +#include "paddle/utils/string/printf.h" +#include "paddle/utils/string/split.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #endif @@ -298,7 +298,7 @@ void *Alloc(const platform::CUDAPlace &place, auto *buddy_allocator = GetGPUBuddyAllocator(place.device); auto *ptr = buddy_allocator->Alloc(size); if (ptr == nullptr) { - platform::CUDADeviceGuard(place.device); + platform::CUDADeviceGuard guard(place.device); size_t avail, total; platform::GpuMemoryUsage(&avail, &total); PADDLE_THROW(platform::errors::ResourceExhausted( @@ -459,6 +459,9 @@ class BuddyAllocatorList { phi::DeviceManager::SetDevice(device_type_, dev_id); platform::CustomPlace place(device_type_, dev_id); + VLOG(10) << "Init BuddyAllocator on " << place + << " with GetExtraPaddingSize " + << phi::DeviceManager::GetExtraPaddingSize(place); allocators_[dev_id] = std::make_unique( std::unique_ptr( new detail::CustomAllocator(device_type_, dev_id)), diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc index 48b18f07456c6..dfcb90dffecb1 100644 --- a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc +++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc @@ -18,8 +18,10 @@ #include "paddle/fluid/platform/profiler/event_tracing.h" #include "paddle/phi/backends/gpu/gpu_info.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) #include "paddle/phi/backends/gpu/cuda/cuda_graph.h" +#elif defined(PADDLE_WITH_HIP) +#include "paddle/phi/backends/gpu/rocm/hip_graph.h" #endif namespace paddle { @@ -48,7 +50,7 @@ void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) { [this] { phi::backends::gpu::SetDeviceId(place_.device); }); std::lock_guard lock_guard(outstanding_event_map_lock_); -#ifdef PADDLE_WITH_CUDA +#if 
defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { graph_capturing_stream_set_.insert(stream); return; @@ -66,7 +68,7 @@ void StreamSafeCUDAAllocation::EraseStream(gpuStream_t stream) { } bool StreamSafeCUDAAllocation::CanBeFreed() { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { return graph_capturing_stream_set_.empty() && outstanding_event_map_.empty(); @@ -86,7 +88,7 @@ bool StreamSafeCUDAAllocation::CanBeFreed() { gpuError_t err = cudaEventQuery(event); if (err == cudaErrorNotReady) { VLOG(9) << "Event " << event << " for " << ptr() << " is not completed"; - // Erase the completded event before "it" + // Erase the completed event before "it" outstanding_event_map_.erase(outstanding_event_map_.begin(), it); return false; } @@ -96,7 +98,7 @@ bool StreamSafeCUDAAllocation::CanBeFreed() { gpuError_t err = hipEventQuery(event); if (err == hipErrorNotReady) { VLOG(9) << "Event " << event << " for " << ptr() << " is not completed"; - // Erase the completded event before "it" + // Erase the completed event before "it" outstanding_event_map_.erase(outstanding_event_map_.begin(), it); return false; } @@ -234,7 +236,7 @@ void StreamSafeCUDAAllocator::FreeImpl(phi::Allocation* allocation) { uint64_t StreamSafeCUDAAllocator::ReleaseImpl(const platform::Place& place) { if (UNLIKELY(in_cuda_graph_capturing_)) { - VLOG(7) << "Memory release forbidden in CUDA Graph Captruing"; + VLOG(7) << "Memory release forbidden in CUDA Graph Capturing"; return 0; } @@ -249,8 +251,8 @@ uint64_t StreamSafeCUDAAllocator::ReleaseImpl(const platform::Place& place) { } void StreamSafeCUDAAllocator::ProcessUnfreedAllocations() { - // NOTE(Ruibiao): This condition is to reduce lock competion. It does not need - // to be thread-safe since here occasional misjudgments are permissible. + // NOTE(Ruibiao): This condition is to reduce lock completion. It does not + // need to be thread-safe since here occasional misjudgments are permissible. if (unfreed_allocations_.empty()) { return; } diff --git a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h index 31508a1079922..527455028b698 100644 --- a/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h +++ b/paddle/fluid/memory/allocation/stream_safe_cuda_allocator.h @@ -54,7 +54,7 @@ class StreamSafeCUDAAllocation : public Allocation { std::map outstanding_event_map_; gpuStream_t owning_stream_; SpinLock outstanding_event_map_lock_; - // To compatiable with CUDA Graph, hold the allocator shared_ptr so that + // To compatible with CUDA Graph, hold the allocator shared_ptr so that // Allocator will not deconstruct before Allocation std::shared_ptr allocator_; }; diff --git a/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.cc b/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.cc index ce63ab807e01e..218068aeb9c97 100644 --- a/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.cc +++ b/paddle/fluid/memory/allocation/stream_safe_custom_device_allocator.cc @@ -215,8 +215,8 @@ uint64_t StreamSafeCustomDeviceAllocator::ReleaseImpl( } void StreamSafeCustomDeviceAllocator::ProcessUnfreedAllocations() { - // NOTE(Ruibiao): This condition is to reduce lock competion. It does not need - // to be thread-safe since here occasional misjudgments are permissible. 
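The stream-safe allocator hunks above extend the CUDA-graph guards to ROCm while keeping the event-per-stream pattern: RecordStream records an event on the user's stream, and CanBeFreed polls it before the block goes back to the pool. A standalone sketch of that pattern in the CUDA runtime spelling (the cuda_malloc_async_allocator.cc hunk earlier switches such calls to the gpu* aliases so the HIP build can share them):

#include <cuda_runtime.h>
#include <cstdio>

// Returns true once all work submitted to the recorded stream before the
// event has finished, i.e. the block behind it may be handed back.
bool CanBeFreedSketch(cudaEvent_t event) {
  cudaError_t err = cudaEventQuery(event);
  if (err == cudaErrorNotReady) return false;  // stream still busy
  return err == cudaSuccess;
}

int main() {
  cudaStream_t stream;
  cudaEvent_t event;
  cudaStreamCreate(&stream);
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, stream);  // what RecordStream does per stream
  cudaStreamSynchronize(stream);
  std::printf("can be freed: %d\n", CanBeFreedSketch(event));
  cudaEventDestroy(event);
  cudaStreamDestroy(stream);
  return 0;
}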
+ // NOTE(Ruibiao): This condition is to reduce lock completion. It does not + // need to be thread-safe since here occasional misjudgments are permissible. if (unfreed_allocations_.empty()) { return; } diff --git a/paddle/fluid/memory/allocation/stream_safe_xpu_allocator.cc b/paddle/fluid/memory/allocation/stream_safe_xpu_allocator.cc index 7f48ef5ab5007..9809b1e5358c4 100644 --- a/paddle/fluid/memory/allocation/stream_safe_xpu_allocator.cc +++ b/paddle/fluid/memory/allocation/stream_safe_xpu_allocator.cc @@ -175,8 +175,8 @@ uint64_t StreamSafeXPUAllocator::ReleaseImpl(const platform::Place& place) { } void StreamSafeXPUAllocator::ProcessUnfreedAllocations() { - // NOTE(Ruibiao): This condition is to reduce lock competion. It does not need - // to be thread-safe since here occasional misjudgments are permissible. + // NOTE(Ruibiao): This condition is to reduce lock completion. It does not + // need to be thread-safe since here occasional misjudgments are permissible. if (unfreed_allocations_.empty()) { return; } diff --git a/paddle/fluid/memory/allocation/system_allocator.cc b/paddle/fluid/memory/allocation/system_allocator.cc index 4ca1f21c563fc..a6e19b84ba8d1 100644 --- a/paddle/fluid/memory/allocation/system_allocator.cc +++ b/paddle/fluid/memory/allocation/system_allocator.cc @@ -41,6 +41,7 @@ limitations under the License. */ #include "paddle/fluid/platform/profiler/mem_tracing.h" COMMON_DECLARE_bool(use_pinned_memory); +COMMON_DECLARE_bool(custom_device_mem_record); COMMON_DECLARE_double(fraction_of_gpu_memory_to_use); COMMON_DECLARE_uint64(initial_gpu_memory_in_mb); COMMON_DECLARE_uint64(reallocate_gpu_memory_in_mb); @@ -208,7 +209,8 @@ void* CUDAPinnedAllocator::Alloc(size_t* index, size_t size) { if (size > usable) { LOG(WARNING) << "Cannot malloc " << size / 1024.0 / 1024.0 << " MB pinned memory." 
- << ", available " << usable / 1024.0 / 1024.0 << " MB"; + << ", available " << usable / 1024.0 / 1024.0 + << " MB"; // NOLINT return nullptr; } @@ -297,6 +299,11 @@ void* CustomAllocator::Alloc(size_t* index, size_t size) { VLOG(4) << "CustomAllocator::Alloc " << p << " size " << size; *index = 0; plug_alloc_size += size; + if (FLAGS_custom_device_mem_record) { + DEVICE_MEMORY_STAT_UPDATE(Reserved, dev_id_, size); + platform::RecordMemEvent( + p, place, size, platform::TracerMemEventType::ReservedAllocate); + } } else { size_t avail, total; @@ -331,6 +338,11 @@ void CustomAllocator::Free(void* p, size_t size, size_t index) { auto place = platform::CustomPlace(dev_type_, dev_id_); auto device = phi::DeviceManager::GetDeviceWithPlace(place); device->MemoryDeallocate(p, size); + if (FLAGS_custom_device_mem_record) { + DEVICE_MEMORY_STAT_UPDATE(Reserved, dev_id_, size); + platform::RecordMemEvent( + p, place, size, platform::TracerMemEventType::ReservedFree); + } } bool CustomAllocator::UseGpu() const { return true; } diff --git a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc index 0c5bfe7bd1a90..52399df8ce5ff 100644 --- a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.cc @@ -22,9 +22,8 @@ namespace paddle { namespace memory { namespace allocation { -bool NeedSplit(size_t block_size, size_t alignment, size_t allock_size) { - return block_size > (allock_size * 2) || - (block_size - allock_size) > alignment; +bool NeedSplit(size_t block_size, size_t alignment, size_t alloc_size) { + return block_size > (alloc_size * 2) || (block_size - alloc_size) > alignment; } VirtualMemoryAutoGrowthBestFitAllocator:: diff --git a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h index ce5cbdeb12593..b8c7e38da00b8 100644 --- a/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h +++ b/paddle/fluid/memory/allocation/virtual_memory_auto_growth_best_fit_allocator.h @@ -46,7 +46,7 @@ struct BlockAllocation : public Allocation { * Like AutoGrowthBestFitAllocator, VirtualMemoryAutoGrowthBestFitAllocator will * gradually apply to GPU for video memory as the model uses more video memory. * However, the difference is that VirtualMemoryAutoGrowthBestFitAllocator uses - * nviaid's virtual memory management technology and obtains the virtual memory + * NVIDIA's virtual memory management technology and obtains the virtual memory * address. If the video memory applied for twice is continuous, we can combine * the two video memories later. This combination can greatly reduce * fragmentation. 
diff --git a/paddle/fluid/memory/malloc.h b/paddle/fluid/memory/malloc.h index a9286499ec24c..dc25b85c8b040 100644 --- a/paddle/fluid/memory/malloc.h +++ b/paddle/fluid/memory/malloc.h @@ -71,7 +71,7 @@ struct ThrustAllocator { place_ = place; stream_ = stream; } - ~ThrustAllocator() { VLOG(2) << "destory allocator"; } + ~ThrustAllocator() { VLOG(2) << "destroy allocator"; } char* allocate(std::ptrdiff_t num_bytes) { VLOG(2) << "allocate " << num_bytes << " bytes"; auto storage = memory::AllocShared( diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index 7cdf93514c52c..6ba7b4ac1d613 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -638,12 +638,12 @@ void Copy(phi::Place dst_place, // NOTE: only for (CPUPlace, CUDAPlace and CUDAPinnedPlace) -> (CPUPlace). template <> -void Copy(phi::CPUPlace dst_place, - void* dst, - phi::Place src_place, - const void* src, - size_t num, - void* stream) { +TEST_API void Copy(phi::CPUPlace dst_place, + void* dst, + phi::Place src_place, + const void* src, + size_t num, + void* stream) { Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream); } @@ -835,11 +835,11 @@ TEST_API void Copy(phi::Place dst_place, // NOTE: Only for (CPUPlace and PinnedPlace) -> (CPUPlace). template <> -void Copy(phi::CPUPlace dst_place, - void* dst, - phi::Place src_place, - const void* src, - size_t num) { +TEST_API void Copy(phi::CPUPlace dst_place, + void* dst, + phi::Place src_place, + const void* src, + size_t num) { Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num); } @@ -872,12 +872,12 @@ void Copy(phi::Place dst_place, } template <> -void Copy(phi::CPUPlace dst_place, - void* dst, - phi::Place src_place, - const void* src, - size_t num, - void* stream) { +TEST_API void Copy(phi::CPUPlace dst_place, + void* dst, + phi::Place src_place, + const void* src, + size_t num, + void* stream) { Copy(phi::Place(dst_place.GetType()), dst, src_place, src, num, stream); } diff --git a/paddle/fluid/memory/memcpy.h b/paddle/fluid/memory/memcpy.h index c8d9208c48219..b0a9234817f0a 100644 --- a/paddle/fluid/memory/memcpy.h +++ b/paddle/fluid/memory/memcpy.h @@ -31,7 +31,7 @@ namespace memory { * */ template -void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); +TEST_API void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); /** * \brief Copy memory from one place to another place. 
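The memcpy.h/memcpy.cc hunks above only add TEST_API to existing Copy declarations and specializations; the call shape itself is unchanged. For reference, a host-to-host use of the five-argument overload might look like the sketch below (the call site is illustrative and assumes the CPUPlace-to-CPUPlace specialization provided in memcpy.cc):

#include <array>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/common/place.h"

void CopyHostBuffer() {
  std::array<float, 4> src{1.f, 2.f, 3.f, 4.f};
  std::array<float, 4> dst{};
  // Synchronous host-to-host copy: dst place, dst ptr, src place, src ptr, bytes.
  paddle::memory::Copy(phi::CPUPlace(), dst.data(),
                       phi::CPUPlace(), src.data(),
                       src.size() * sizeof(float));
}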
@@ -51,7 +51,7 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num); * */ template -void Copy( +TEST_API void Copy( DstPlace, void* dst, SrcPlace, const void* src, size_t num, void* stream); } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/stats.cc b/paddle/fluid/memory/stats.cc index 39b01c46f389e..2d66a5b6838b0 100644 --- a/paddle/fluid/memory/stats.cc +++ b/paddle/fluid/memory/stats.cc @@ -36,7 +36,7 @@ class StatRegistry { auto it = stat_map_.find(GetStatKey(stat_type, dev_id)); if (it == stat_map_.end()) { PADDLE_THROW(platform::errors::InvalidArgument( - "The STAT type \"%s\" for device %d has not been regeistered.", + "The STAT type \"%s\" for device %d has not been registered.", stat_type.c_str(), dev_id)); } @@ -171,7 +171,7 @@ int RegisterAllStats() { return 0; } -UNUSED static int regiester_all_stats = RegisterAllStats(); +UNUSED static int register_all_stats = RegisterAllStats(); } // namespace memory } // namespace paddle diff --git a/paddle/fluid/memory/stats.h b/paddle/fluid/memory/stats.h index b6d722b62a4b0..78d20d968c968 100644 --- a/paddle/fluid/memory/stats.h +++ b/paddle/fluid/memory/stats.h @@ -42,7 +42,7 @@ struct ThreadLocalStatBase { friend std::ostream& operator<<(std::ostream& os, const ThreadLocalStatBase& stat) { - os << "{cuerrent : " << stat.current << ", peak : " << stat.peak << "}"; + os << "{current : " << stat.current << ", peak : " << stat.peak << "}"; return os; } }; @@ -136,7 +136,7 @@ void HostMemoryStatUpdate(const std::string& stat_type, void LogDeviceMemoryStats(const platform::Place& place, const std::string& op_name); -#define DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, id) \ +#define DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, id) \ case id: \ stat = paddle::memory::Stat< \ paddle::memory::DeviceMemoryStat##item##id>::GetInstance(); \ @@ -146,22 +146,22 @@ void LogDeviceMemoryStats(const platform::Place& place, [&] { \ paddle::memory::StatBase* stat = nullptr; \ switch (id) { \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 0); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 1); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 2); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 3); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 4); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 5); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 6); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 7); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 8); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 9); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 10); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 11); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 12); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 13); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 14); \ - DEVICE_MEMORY_STAT_FUNC_SWITHCH_CASE(item, 15); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 0); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 1); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 2); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 3); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 4); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 5); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 6); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 7); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 8); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 9); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 10); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 11); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 12); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 13); \ + 
DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 14); \ + DEVICE_MEMORY_STAT_FUNC_SWITCH_CASE(item, 15); \ default: \ PADDLE_THROW(paddle::platform::errors::OutOfRange( \ "Only support device id between [0, 15] for device memory stats," \ diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 5d03c833a87c7..280f24bdd6fa6 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -35,8 +35,6 @@ if (WITH_PSCORE) add_subdirectory(pscore) endif() -add_subdirectory(amp) - add_subdirectory(reader) if (NOT WIN32) diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index b848697128731..1e01f587f7464 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -26,7 +26,7 @@ limitations under the License. */ #include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/infermeta/backward.h" @@ -94,7 +94,7 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker { // paddle::Tensor dx = this->GetSingleInputGrad("X"); // auto* dx_ptr = this->GetOutputPtr(&dx); // std::string dx_name = this->GetOutputName(dx); -// VLOG(6) << "Runing hardswish_grad composite func"; +// VLOG(6) << "Running hardswish_grad composite func"; // prim::hardswish_grad(x, out_grad, dx_ptr); // this->RecoverOutputName(dx, dx_name); // } @@ -394,19 +394,19 @@ REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor); /* ========================== register checkpoint ===========================*/ REGISTER_OP_VERSION(leaky_relu) .AddCheckpoint( - R"ROC(fix leaky_relu, bahavior changed when alpha < 0 or alpha > 1)ROC", + R"ROC(fix leaky_relu, behavior changed when alpha < 0 or alpha > 1)ROC", paddle::framework::compatible::OpVersionDesc() .BugfixWithBehaviorChanged( - "leaky_relu calculate formula before checkponit: out = max(x, " + "leaky_relu calculate formula before checkpoint: out = max(x, " "alpha * x); after checkpoint: out = x if x > 0 else alpha * " "x")); REGISTER_OP_VERSION(hard_shrink) .AddCheckpoint( - R"ROC(fix hard_shrink, bahavior changed when threshold<0)ROC", + R"ROC(fix hard_shrink, behavior changed when threshold<0)ROC", paddle::framework::compatible::OpVersionDesc() .BugfixWithBehaviorChanged( - "hard_shrink calculate formula before checkponit: out = x * " + "hard_shrink calculate formula before checkpoint: out = x * " "((x < -threshold) + (x > threshold)); after checkpoint: out = " "x * (((x < -threshold) + (x > threshold)) > 0)")); diff --git a/paddle/fluid/operators/activation_op.h b/paddle/fluid/operators/activation_op.h index 8280c817b706a..38432f8768f59 100644 --- a/paddle/fluid/operators/activation_op.h +++ b/paddle/fluid/operators/activation_op.h @@ -371,7 +371,7 @@ struct AbsGradGradFunctor : public BaseActivationFunctor { // TODO(dengkaipeng): double gradient calculation for Square/Sqrt need // DOut(dy) as input(not output), tensor extraction is different from -// others. Impliment extraction kernel separately here. +// others. Implement extraction kernel separately here. 
inline void ExtractDoubleGradTensorWithInputDOut( const framework::ExecutionContext& ctx, const phi::DenseTensor** X, diff --git a/paddle/fluid/operators/amp/CMakeLists.txt b/paddle/fluid/operators/amp/CMakeLists.txt deleted file mode 100644 index cbd9c8b2768b4..0000000000000 --- a/paddle/fluid/operators/amp/CMakeLists.txt +++ /dev/null @@ -1,6 +0,0 @@ -include(operators) -if(WITH_UNITY_BUILD) - # Load Unity Build rules for operators in paddle/fluid/operators/amp. - include(unity_build_rule.cmake) -endif() -register_operators() diff --git a/paddle/fluid/operators/amp/alloc_float_status_op.cc b/paddle/fluid/operators/amp/alloc_float_status_op.cc deleted file mode 100644 index 2c1b4b201e5c3..0000000000000 --- a/paddle/fluid/operators/amp/alloc_float_status_op.cc +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class AllocFloatStatusOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasOutput("FloatStatus"), - "Output", - "FloatStatus", - "alloc_float_status"); - ctx->SetOutputDim("FloatStatus", {8}); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); - } -}; - -class AllocFloatStatusMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddOutput("FloatStatus", - "(Tensor) of shape {8} that holds the float status."); - AddComment(R"DOC( - Produces a float Tensor that holds the float status -)DOC"); - } -}; - -template -class AllocFloatStatusKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Operator alloc_float_status is not supported on CPU")); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -using CPU = phi::CPUContext; - -REGISTER_OPERATOR( - alloc_float_status, - ops::AllocFloatStatusOp, - ops::AllocFloatStatusMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -PD_REGISTER_STRUCT_KERNEL( - alloc_float_status, CPU, ALL_LAYOUT, ops::AllocFloatStatusKernel, float) {} diff --git a/paddle/fluid/operators/amp/clear_float_status_op.cc b/paddle/fluid/operators/amp/clear_float_status_op.cc deleted file mode 100644 index d595a26e5575a..0000000000000 --- a/paddle/fluid/operators/amp/clear_float_status_op.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class ClearFloatStatusOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasOutput("FloatStatusOut"), - "Output", - "FloatStatusOut", - "clear_float_status"); - ctx->SetOutputDim("FloatStatusOut", ctx->GetInputDim("FloatStatus")); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); - } -}; - -class ClearFloatStatusMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("FloatStatus", - "(Tensor) of shape {8} that holds the float status."); - AddOutput( - "FloatStatusOut", - "(Tensor) of shape {8} that holds the float status, which is cleared."); - AddComment(R"DOC( - Clear the float status -)DOC"); - } -}; - -template -class ClearFloatStatusKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Operator clear_float_status is not supported on CPU")); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR( - clear_float_status, - ops::ClearFloatStatusOp, - ops::ClearFloatStatusMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -PD_REGISTER_STRUCT_KERNEL( - clear_float_status, CPU, ALL_LAYOUT, ops::ClearFloatStatusKernel, float) {} diff --git a/paddle/fluid/operators/amp/get_float_status_op.cc b/paddle/fluid/operators/amp/get_float_status_op.cc deleted file mode 100644 index 8700d82976f01..0000000000000 --- a/paddle/fluid/operators/amp/get_float_status_op.cc +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -class GetFloatStatusOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasOutput("FloatStatusOut"), - "Output", - "FloatStatusOut", - "get_float_status"); - ctx->SetOutputDim("FloatStatusOut", ctx->GetInputDim("FloatStatus")); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace()); - } -}; - -class GetFloatStatusMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("FloatStatus", - "(Tensor) of shape {8} that holds the float status."); - AddOutput("FloatStatusOut", - "(Tensor) of shape {8} that holds the get float status."); - AddComment(R"DOC( - Get the float status -)DOC"); - } -}; - -template -class GetFloatStatusKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "Operator get_float_status is not supported on CPU")); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -using CPU = phi::CPUContext; - -REGISTER_OPERATOR( - get_float_status, - ops::GetFloatStatusOp, - ops::GetFloatStatusMaker, - paddle::framework::EmptyGradOpMaker, - paddle::framework::EmptyGradOpMaker); - -PD_REGISTER_STRUCT_KERNEL( - get_float_status, CPU, ALL_LAYOUT, ops::GetFloatStatusKernel, float) {} diff --git a/paddle/fluid/operators/amp/unity_build_rule.cmake b/paddle/fluid/operators/amp/unity_build_rule.cmake deleted file mode 100644 index fa460e33c8068..0000000000000 --- a/paddle/fluid/operators/amp/unity_build_rule.cmake +++ /dev/null @@ -1,10 +0,0 @@ -# This file records the Unity Build compilation rules. -# The source files in a `register_unity_group` called are compiled in a unity -# file. -# Generally, the combination rules in this file do not need to be modified. -# If there are some redefined error in compiling with the source file which -# in combination rule, you can remove the source file from the following rules. -register_unity_group(cc check_finite_and_unscale_op.cc - update_loss_scaling_op.cc) -register_unity_group(cu check_finite_and_unscale_op.cu - update_loss_scaling_op.cu) diff --git a/paddle/fluid/operators/assign_value_op.h b/paddle/fluid/operators/assign_value_op.h index 2a6a31ba03004..5ba8b9367e64e 100644 --- a/paddle/fluid/operators/assign_value_op.h +++ b/paddle/fluid/operators/assign_value_op.h @@ -29,7 +29,7 @@ typename std::enable_if::value>::type CopyVectorToTensor( const char* value_name, phi::DenseTensor* out, const framework::ExecutionContext& ctx) { - // phi::DenseTensore dtype is vector, it will be converted to + // phi::DenseTensor dtype is vector, it will be converted to // vector. // at the same time, we can not use vector to hold the value, because // the c++ use bit value to replace byte value. 
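As an aside on the assign_value_op.h comment above (which appears to concern std::vector<bool> being converted to std::vector<int>): a minimal, self-contained C++ sketch, not taken from the patch and with illustrative names and values only, of why a bit-packed std::vector<bool> cannot serve as a raw byte buffer while widening to std::vector<int> can.

#include <cstring>
#include <vector>

int main() {
  std::vector<bool> flags = {true, false, true};
  // std::vector<bool> is a bit-packed specialization: it exposes no
  // contiguous bool storage (no data() returning bool*), so its contents
  // cannot be memcpy'd into a tensor buffer directly.
  std::vector<int> widened(flags.begin(), flags.end());  // one int per flag
  int dst[3];
  std::memcpy(dst, widened.data(), widened.size() * sizeof(int));  // OK
  return dst[0] + dst[2];  // 2
}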
diff --git a/paddle/fluid/operators/attention_lstm_op.cc b/paddle/fluid/operators/attention_lstm_op.cc index 9624f752b780f..6a0775e6331a7 100644 --- a/paddle/fluid/operators/attention_lstm_op.cc +++ b/paddle/fluid/operators/attention_lstm_op.cc @@ -488,7 +488,7 @@ class AttentionLSTMKernel : public framework::OpKernel { // gate act: sigmoid act_gate(D3, lstm_out_data, lstm_out_data); - // candicate act: tanh + // candidate act: tanh act_cand(D, lstm_out_data + D3, lstm_out_data + D3); // a = forget * prev_cell diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index fd05b018bbfb6..996c6af070631 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -308,11 +308,11 @@ void BatchNormOpMaker::Make() { "to true or is_test true. the behavior is equivalent. " "In train mode, when setting use_global_stats True, the " "global mean and variance are also used during train time, " - "the BN acts as scaling and shiffting.") + "the BN acts as scaling and shifting.") .SetDefault(false); AddAttr("trainable_statistics", "(bool, default false) Whether to calculate mean and variance " - "in test mode. If setting true in test mode, mean and variace " + "in test mode. If setting true in test mode, mean and variance " "will be calculated by current batch statistics.") .SetDefault(false); AddComment(R"DOC( @@ -586,7 +586,7 @@ class BatchNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { auto use_global_stats = this->Attr("use_global_stats"); auto trainable_statistics = this->Attr("trainable_statistics"); - VLOG(3) << "Runing batch_norm composite func"; + VLOG(3) << "Running batch_norm composite func"; prim::batch_norm_grad(x, scale, bias, diff --git a/paddle/fluid/operators/beam_search_decode_op_def.h b/paddle/fluid/operators/beam_search_decode_op_def.h index 390f728322322..d358d8255fcf3 100644 --- a/paddle/fluid/operators/beam_search_decode_op_def.h +++ b/paddle/fluid/operators/beam_search_decode_op_def.h @@ -27,7 +27,7 @@ using LoDTensorArray = framework::LoDTensorArray; // all the lod have 2 levels. // The first is source level, the second is sentence level. -// source level describe how many prefixes (branchs) for each source sentence +// source level describe how many prefixes (branches) for each source sentence // (beam). sentence level describe how these candidates belong to the prefixes. const size_t kSourceLevel = 0; const size_t kSentenceLevel = 1; diff --git a/paddle/fluid/operators/chunk_eval_op.h b/paddle/fluid/operators/chunk_eval_op.h index 22b3accba8639..baad8719db37f 100644 --- a/paddle/fluid/operators/chunk_eval_op.h +++ b/paddle/fluid/operators/chunk_eval_op.h @@ -199,7 +199,7 @@ class ChunkEvalKernel : public framework::OpKernel { const int64_t* inference_data = inference->data(); const int64_t* label_data = label->data(); T* precision_data = precision->mutable_data(place); - T* racall_data = recall->mutable_data(place); + T* recall_data = recall->mutable_data(place); T* f1_data = f1->mutable_data(place); int64_t* num_infer_chunks_data = num_infer_chunks->mutable_data(place); @@ -280,14 +280,14 @@ class ChunkEvalKernel : public framework::OpKernel { ? 0 : static_cast(*num_correct_chunks_data) / (*num_infer_chunks_data); - *racall_data = !(*num_label_chunks_data) + *recall_data = !(*num_label_chunks_data) ? 0 : static_cast(*num_correct_chunks_data) / (*num_label_chunks_data); *f1_data = !(*num_correct_chunks_data) ? 
0 - : 2 * (*precision_data) * (*racall_data) / - ((*precision_data) + (*racall_data)); + : 2 * (*precision_data) * (*recall_data) / + ((*precision_data) + (*recall_data)); } void EvalOneSeq(const int64_t* output, diff --git a/paddle/fluid/operators/cinn/cinn_launch_context.cc b/paddle/fluid/operators/cinn/cinn_launch_context.cc index f75e77a075177..734987ce92235 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_context.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_context.cc @@ -44,9 +44,9 @@ #include "paddle/fluid/operators/cinn/cinn_op_helper.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/printf.h" #include "paddle/pir/include/core/program.h" #include "paddle/pir/include/core/value.h" +#include "paddle/utils/string/printf.h" #include "paddle/utils/string/string_helper.h" COMMON_DECLARE_string(static_runtime_data_save_path); @@ -412,10 +412,10 @@ std::unique_ptr CinnLaunchContext::BuildCompiledProgram( // build a map that links the name of a Paddle variable to its VarDesc const std::unordered_set& nodes = graph.Nodes(); - std::unordered_map original_vardescs; + std::unordered_map original_var_descs; for (auto* node : nodes) { if (node->IsVar() && node->Var()) { - original_vardescs.emplace(node->Name(), node->Var()); + original_var_descs.emplace(node->Name(), node->Var()); } } @@ -433,8 +433,8 @@ std::unique_ptr CinnLaunchContext::BuildCompiledProgram( framework::VarDesc* var_desc = block->Var(var_name); var_desc->SetType(framework::proto::VarType::LOD_TENSOR); - auto res = original_vardescs.find(var_name); - if (res != original_vardescs.end()) { + auto res = original_var_descs.find(var_name); + if (res != original_var_descs.end()) { auto* ori_desc = res->second; var_desc->SetPersistable(ori_desc->Persistable()); var_desc->SetIsParameter(ori_desc->IsParameter()); diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.cc b/paddle/fluid/operators/cinn/cinn_launch_op.cc index e6afd9277583b..9edb7348b125c 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op.cc +++ b/paddle/fluid/operators/cinn/cinn_launch_op.cc @@ -22,8 +22,8 @@ #include "paddle/cinn/runtime/cinn_runtime.h" #include "paddle/cinn/runtime/flags.h" #include "paddle/common/flags.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/core/generator.h" +#include "paddle/utils/string/string_helper.h" #if defined(PADDLE_WITH_CUDA) COMMON_DECLARE_bool(cudnn_deterministic); diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.h b/paddle/fluid/operators/cinn/cinn_launch_op.h index c9e9d9222b6a7..2ce23dc965b31 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op.h +++ b/paddle/fluid/operators/cinn/cinn_launch_op.h @@ -43,7 +43,7 @@ using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject; namespace details { -// Tranform Paddle place to CINN target +// Transform Paddle place to CINN target const ::cinn::common::Target& PlaceToCinnTarget(const platform::Place& place); // Print detailed compilation result of graph for debug diff --git a/paddle/fluid/operators/collective/c_allreduce_avg_op.cc b/paddle/fluid/operators/collective/c_allreduce_avg_op.cc new file mode 100644 index 0000000000000..3343406a02b6c --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_avg_op.cc @@ -0,0 +1,45 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace paddle { +namespace framework { +class OpDesc; +} // namespace framework +namespace imperative { +class OpBase; +} // namespace imperative +} // namespace paddle + +namespace paddle { +namespace operators { + +class CAllReduceAvgOpMaker : public CAllReduceOpMaker { + protected: + std::string GetName() const override { return "Avg"; } +}; + +DECLARE_INPLACE_OP_INFERER(AllreduceAvgInplaceInferer, {"X", "Out"}); + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_allreduce_avg, + ops::CAllReduceOp, + ops::CAllReduceAvgOpMaker, + ops::AllreduceAvgInplaceInferer) diff --git a/paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc b/paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc new file mode 100644 index 0000000000000..d3f0b45f64432 --- /dev/null +++ b/paddle/fluid/operators/collective/c_allreduce_avg_op.cu.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/collective/c_allreduce_op.h" + +namespace paddle { +namespace operators { +DEFINE_C_ALLREDUCE_CUDA_KERNEL(CAllReduceAvg, kRedAvg) +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +PD_REGISTER_STRUCT_KERNEL(c_allreduce_avg, + GPU, + ALL_LAYOUT, + ops::CAllReduceAvgCUDAKernel, + float, + double, + int, + int64_t, + plat::float16, + plat::bfloat16) {} diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 95e02e35adfc4..55ca03c0bc626 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -48,7 +48,7 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm); namespace paddle { namespace operators { -enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd }; +enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd, kRedAvg }; class CAllReduceOp : public framework::OperatorWithKernel { public: @@ -391,7 +391,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel { stream = ctx.cuda_device_context().stream(); } VLOG(10) << "all reduce buffer:" << sendbuff << ", numel:" << numel - << ", redtype:" << static_cast(red_type) + << ", reduce type:" << static_cast(red_type) << ", dtype:" << dtype << ", comm:" << comm << ", stream:" << stream; @@ -413,6 +413,12 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel { nccl_red_type = ncclProd; break; +#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 + case kRedAvg: + nccl_red_type = ncclAvg; + break; +#endif + default: PADDLE_THROW(platform::errors::InvalidArgument( "Invalid reduce type: %d", red_type)); diff --git a/paddle/fluid/operators/collective/c_reduce_avg_op.cc b/paddle/fluid/operators/collective/c_reduce_avg_op.cc new file mode 100644 index 0000000000000..53ce6e221a9f8 --- /dev/null +++ b/paddle/fluid/operators/collective/c_reduce_avg_op.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/collective/c_reduce_op.h" + +namespace paddle { +namespace framework { +class OpDesc; +template +class EmptyGradOpMaker; +} // namespace framework +namespace imperative { +class OpBase; +} // namespace imperative +} // namespace paddle + +namespace paddle { +namespace operators { + +class CReduceAvgOpMaker : public CReduceOpMaker { + protected: + std::string GetName() const override { return "Avg"; } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_WITHOUT_GRADIENT(c_reduce_avg, + ops::CReduceOp, + ops::CReduceAvgOpMaker); diff --git a/paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc b/paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc new file mode 100644 index 0000000000000..07d2cc748900e --- /dev/null +++ b/paddle/fluid/operators/collective/c_reduce_avg_op.cu.cc @@ -0,0 +1,35 @@ +/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/collective/c_reduce_op.h" + +namespace paddle { +namespace operators { +DEFINE_C_REDUCE_CUDA_KERNEL(CReduceAvg, kRedAvg); +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +PD_REGISTER_STRUCT_KERNEL(c_reduce_avg, + GPU, + ALL_LAYOUT, + ops::CReduceAvgCUDAKernel, + float, + double, + int, + int64_t, + plat::float16, + plat::bfloat16) {} diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h index e8e240c9b5525..d90fb88fe8f3f 100644 --- a/paddle/fluid/operators/collective/c_reduce_op.h +++ b/paddle/fluid/operators/collective/c_reduce_op.h @@ -50,7 +50,7 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm); namespace paddle { namespace operators { -enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd }; +enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd, kRedAvg }; class CReduceOp : public framework::OperatorWithKernel { public: @@ -304,6 +304,12 @@ class CReduceOpCUDAKernel : public framework::OpKernel { nccl_red_type = ncclProd; break; +#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 + case kRedAvg: + nccl_red_type = ncclAvg; + break; +#endif + default: PADDLE_ENFORCE_EQ(true, false, diff --git a/paddle/fluid/operators/collective/c_scatter_op.cc b/paddle/fluid/operators/collective/c_scatter_op.cc index 162f4d1478584..40b6eeacf8030 100644 --- a/paddle/fluid/operators/collective/c_scatter_op.cc +++ b/paddle/fluid/operators/collective/c_scatter_op.cc @@ -68,7 +68,7 @@ class CScatterOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(0); AddAttr("root", "(int default 0) root id for broadcasting.") .SetDefault(0); - AddAttr("nranks", "(int default 1) number of ranks.").SetDefault(0); + AddAttr("nranks", "(int default 0) number of ranks.").SetDefault(0); AddAttr( "use_calc_stream", "(bool default false) eject CUDA operations to calculation stream.") diff --git 
a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu index 38133a70f839d..e65ebafad7235 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op.cu @@ -19,13 +19,13 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/cross_entropy.h" #include "paddle/phi/kernels/funcs/math.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/softmax_impl.h" +#include "paddle/utils/string/string_helper.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/common/flags.h" diff --git a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc index 9bdac4888c109..499b25e65974b 100644 --- a/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/collective/c_softmax_with_cross_entropy_op_xpu.cc @@ -18,7 +18,6 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/xpu/bkcl_helper.h" -#include "paddle/fluid/string/string_helper.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/axis_utils.h" @@ -26,6 +25,7 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/funcs/softmax_impl.h" #include "paddle/phi/kernels/xpu/elementwise.h" #include "paddle/phi/kernels/xpu/reduce.h" +#include "paddle/utils/string/string_helper.h" #if defined(PADDLE_WITH_XPU_BKCL) #include "paddle/common/flags.h" @@ -83,8 +83,8 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { const auto& logits_dims = logits->dims(); const int axis = logits_dims.size() - 1; - const int N = phi::funcs::SizeToAxis(axis, logits_dims); - const int D = phi::funcs::SizeFromAxis(axis, logits_dims); + const int64_t N = phi::funcs::SizeToAxis(axis, logits_dims); + const int64_t D = phi::funcs::SizeFromAxis(axis, logits_dims); phi::DenseTensor logits_2d, softmax_2d; framework::TensorCopy( @@ -151,8 +151,8 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { N, 0.0); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant"); - const int start_index = rank * D; - const int end_index = start_index + D; + const int64_t start_index = rank * D; + const int64_t end_index = start_index + D; const auto& label_type = framework::TransToProtoVarType(labels->dtype()); if (label_type == framework::proto::VarType::INT32) { ret = xpu::mask_label_by_index( @@ -224,7 +224,7 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor { opts.reduce_op = distributed::ReduceOp::SUM; pg->AllReduce(in_out, in_out, opts)->Synchronize(); - int dims[4] = {N, D, N, 1}; + int64_t dims[4] = {N, D, N, 1}; ret = xpu::broadcast_div( dev_ctx.x_context(), reinterpret_cast(softmax_2d.data()), @@ -313,8 +313,8 @@ struct CSoftmaxWithCrossEntropyFunctor { const auto& logits_dims = logits->dims(); const int axis = logits_dims.size() - 1; - const int N = phi::funcs::SizeToAxis(axis, logits_dims); - const int D = phi::funcs::SizeFromAxis(axis, logits_dims); + const int64_t N = phi::funcs::SizeToAxis(axis, logits_dims); + const int64_t D = phi::funcs::SizeFromAxis(axis, logits_dims); phi::DenseTensor logits_2d, softmax_2d; framework::TensorCopy( @@ -390,8 +390,8 @@ struct CSoftmaxWithCrossEntropyFunctor { N, 0.0); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant"); - const int start_index = rank * D; - const int end_index = start_index + D; + const int64_t start_index = rank * D; + const int64_t end_index = start_index + D; const auto& label_type = framework::TransToProtoVarType(labels->dtype()); if (label_type == framework::proto::VarType::INT32) { ret = xpu::mask_label_by_index( @@ -485,7 +485,7 @@ struct CSoftmaxWithCrossEntropyFunctor { } { - int dims[4] = {N, D, N, 1}; + int64_t dims[4] = {N, D, N, 1}; ret = xpu::broadcast_div( dev_ctx.x_context(), reinterpret_cast(softmax_2d.data()), @@ -540,11 +540,11 @@ class CSoftmaxWithCrossEntropyGrad : public framework::OpKernel { } const auto softmax_dims = softmax->dims(); const int axis = softmax_dims.size() - 1; - const int N = phi::funcs::SizeToAxis(axis, softmax_dims); - const int D = phi::funcs::SizeFromAxis(axis, softmax_dims); + const int64_t N = phi::funcs::SizeToAxis(axis, softmax_dims); + const int64_t D = phi::funcs::SizeFromAxis(axis, softmax_dims); - const int start_index = rank * D; - const int end_index = start_index + D; + const int64_t start_index = rank * D; + const int64_t end_index = start_index + D; const auto& label_type = framework::TransToProtoVarType(labels->dtype()); int ret = 0; diff --git a/paddle/fluid/operators/collective/gen_bkcl_id_op.cc b/paddle/fluid/operators/collective/gen_bkcl_id_op.cc index 581e6183fe74d..fc765e3bde983 100644 --- a/paddle/fluid/operators/collective/gen_bkcl_id_op.cc +++ b/paddle/fluid/operators/collective/gen_bkcl_id_op.cc @@ 
-26,7 +26,7 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/gen_comm_id_helper.h" #include "paddle/fluid/platform/place.h" -#include "paddle/fluid/string/split.h" +#include "paddle/utils/string/split.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/common_infer_shape_functions.cc b/paddle/fluid/operators/common_infer_shape_functions.cc index 52836ead345a1..1c13f873818f4 100644 --- a/paddle/fluid/operators/common_infer_shape_functions.cc +++ b/paddle/fluid/operators/common_infer_shape_functions.cc @@ -166,7 +166,7 @@ void BinaryOpBroadcastInferShape(framework::InferShapeContext *ctx) { "For binary broadcastable operator, if X is " "Sparse(VarType.SELECTED_ROWS" "), Y must be scalar, and the size of Y should be 1. " - "But reveived the size of Y = %s.", + "But received the size of Y = %s.", y_dims.size())); PADDLE_ENFORCE_EQ( y_dims[0], @@ -175,7 +175,7 @@ void BinaryOpBroadcastInferShape(framework::InferShapeContext *ctx) { "For binary broadcastable operator, if X is " "Sparse(VarType.SELECTED_ROWS" "), Y must be scalar, the first dimension of Y should be 1. " - "But reveived the first dimension of Y = %s.", + "But received the first dimension of Y = %s.", y_dims[0])); } else if (ctx->GetInputsVarType(x_name).front() != framework::proto::VarType::LOD_TENSOR) { diff --git a/paddle/fluid/operators/common_infer_shape_functions.h b/paddle/fluid/operators/common_infer_shape_functions.h index 5ce21b1de529b..a61686f3f7544 100644 --- a/paddle/fluid/operators/common_infer_shape_functions.h +++ b/paddle/fluid/operators/common_infer_shape_functions.h @@ -34,12 +34,13 @@ framework::DDim BroadcastTwoDims(const framework::DDim& x_dims, int axis = -1); } // shape input(0) -> output(0) without change. 
-void UnaryOpUnchangedInferShape(framework::InferShapeContext* ctx); +TEST_API void UnaryOpUnchangedInferShape(framework::InferShapeContext* ctx); // shape input(0) -> output(0) without change, check if axis in range [-Rank(x), // Rank(x)-1] -void UnaryOpUnchangedInferShapeCheckAxis(framework::InferShapeContext* ctx); +TEST_API void UnaryOpUnchangedInferShapeCheckAxis( + framework::InferShapeContext* ctx); // broadcast input(0) and input(1) -> output(0) -void BinaryOpBroadcastInferShape(framework::InferShapeContext* ctx); +TEST_API void BinaryOpBroadcastInferShape(framework::InferShapeContext* ctx); } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/compat/conv2d_transpose_bias.pbtxt b/paddle/fluid/operators/compat/conv2d_transpose_bias.pbtxt new file mode 100644 index 0000000000000..bce4fc9f0e114 --- /dev/null +++ b/paddle/fluid/operators/compat/conv2d_transpose_bias.pbtxt @@ -0,0 +1,69 @@ +type: "conv2d_transpose_bias" +def { + inputs { + name: "Input" + } + inputs { + name: "Filter" + } + inputs { + name: "Bias" + } + outputs { + name: "Output" + } + attrs { + name: "output_padding" + type: INTS + } + attrs { + name: "output_size" + type: INTS + } + attrs { + name: "groups" + type: INT + } + attrs { + name: "dilations" + type: INTS + } + attrs { + name: "strides" + type: INTS + } + attrs { + name: "paddings" + type: INTS + } + attrs { + name: "padding_algorithm" + type: STRING + } + attrs { + name: "data_format" + type: STRING + } +} +extra { + attrs { + name: "force_fp32_output" + type: BOOLEAN + } + attrs { + name: "fuse_relu" + type: BOOLEAN + } + attrs { + name: "fuse_activation" + type: STRING + } + attrs { + name: "fuse_alpha" + type: FLOAT + } + attrs { + name: "fuse_beta" + type: FLOAT + } +} diff --git a/paddle/fluid/operators/controlflow/conditional_block_op_helper.h b/paddle/fluid/operators/controlflow/conditional_block_op_helper.h index 1db6159201eb6..dc585a409ee82 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op_helper.h +++ b/paddle/fluid/operators/controlflow/conditional_block_op_helper.h @@ -20,7 +20,7 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/controlflow/conditional_block_op.h" #include "paddle/fluid/operators/controlflow/op_variant.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/operators/controlflow/pylayer_op.cc b/paddle/fluid/operators/controlflow/pylayer_op.cc index c4b06f326a703..bd83c99a0c62d 100644 --- a/paddle/fluid/operators/controlflow/pylayer_op.cc +++ b/paddle/fluid/operators/controlflow/pylayer_op.cc @@ -26,11 +26,12 @@ namespace { // NOLINT enum class PyLayerBlockIndex { kFORWARD = 0, kBACKWARD = 1, kNONE = 2 }; } // namespace -const char PyLayerOp::kInputs[] = "Input"; -const char PyLayerOp::kOutputs[] = "Out"; -const char PyLayerOp::kScope[] = "Scope"; -const char PyLayerOp::kSkipEagerDeletionVars[] = "skip_eager_deletion_vars"; -const char PyLayerOp::kBlocks[] = "blocks"; +const char PyLayerOp::kInputs[] = "Input"; // NOLINT +const char PyLayerOp::kOutputs[] = "Out"; // NOLINT +const char PyLayerOp::kScope[] = "Scope"; // NOLINT +const char PyLayerOp::kSkipEagerDeletionVars[] = + "skip_eager_deletion_vars"; // NOLINT +const char PyLayerOp::kBlocks[] = "blocks"; // NOLINT void PyLayerOp::CreateInterpreter( const platform::Place &dev_place, diff --git a/paddle/fluid/operators/controlflow/pylayer_op_helper.h 
b/paddle/fluid/operators/controlflow/pylayer_op_helper.h index 1295a6cba60a0..8dcb3997927d3 100644 --- a/paddle/fluid/operators/controlflow/pylayer_op_helper.h +++ b/paddle/fluid/operators/controlflow/pylayer_op_helper.h @@ -20,7 +20,7 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/controlflow/op_variant.h" #include "paddle/fluid/operators/controlflow/pylayer_op.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/operators/controlflow/recurrent_op_helper.h b/paddle/fluid/operators/controlflow/recurrent_op_helper.h index 752a0a1f764eb..37573cc617643 100644 --- a/paddle/fluid/operators/controlflow/recurrent_op_helper.h +++ b/paddle/fluid/operators/controlflow/recurrent_op_helper.h @@ -24,7 +24,7 @@ #include "paddle/fluid/operators/controlflow/op_variant.h" #include "paddle/fluid/operators/recurrent_op.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/operators/controlflow/unity_build_rule.cmake b/paddle/fluid/operators/controlflow/unity_build_rule.cmake index 594ae3a36cf1d..4b88de66fd2f9 100644 --- a/paddle/fluid/operators/controlflow/unity_build_rule.cmake +++ b/paddle/fluid/operators/controlflow/unity_build_rule.cmake @@ -6,15 +6,11 @@ # in combination rule, you can remove the source file from the following rules. register_unity_group( cc - compare_all_op.cc - compare_op.cc conditional_block_infer_op.cc feed_op.cc fetch_op.cc fetch_v2_op.cc get_places_op.cc - logical_op.cc - bitwise_op.cc tensor_array_read_write_op.cc while_op.cc) register_unity_group(cu logical_op.cu bitwise_op.cu compare_op.cu diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index f0b4cb1529421..5c758bbf7ff42 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -113,7 +113,7 @@ class WhileOp : public framework::OperatorBase { const framework::VariableNameMap &output_var_names = op->Outputs(); for (auto &ipt : input_var_names) { for (const std::string &var_name : ipt.second) { - if (StrInVaraiableNameMap(var_name, output_var_names)) { + if (StrInVariableNameMap(var_name, output_var_names)) { no_copy_var_names.insert(var_name); } } diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc index 8ddce0da7faac..80b4abe763123 100644 --- a/paddle/fluid/operators/controlflow/while_op_helper.cc +++ b/paddle/fluid/operators/controlflow/while_op_helper.cc @@ -16,7 +16,7 @@ #include -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" namespace paddle { namespace framework { @@ -89,7 +89,7 @@ static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op, platform::errors::PreconditionNotMet( "Backward output gradient number does not match forward input number." 
"The number of forward input number is %d and the number of backward " - "output geadient number is %d.", + "output gradient number is %d.", fwd_input.size(), in_grads.size())); @@ -239,8 +239,8 @@ bool GetCondData(const phi::DenseTensor &cond) { return cpu_cond->data()[0]; } -bool StrInVaraiableNameMap(const std::string &name, - const framework::VariableNameMap &var_names) { +bool StrInVariableNameMap(const std::string &name, + const framework::VariableNameMap &var_names) { for (auto &ipt : var_names) { if (std::find(ipt.second.begin(), ipt.second.end(), name) != ipt.second.end()) { diff --git a/paddle/fluid/operators/controlflow/while_op_helper.h b/paddle/fluid/operators/controlflow/while_op_helper.h index 7aa4b6418b6bc..7b4d912745d61 100644 --- a/paddle/fluid/operators/controlflow/while_op_helper.h +++ b/paddle/fluid/operators/controlflow/while_op_helper.h @@ -56,8 +56,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( bool GetCondData(const phi::DenseTensor &cond); -bool StrInVaraiableNameMap(const std::string &, - const framework::VariableNameMap &); +bool StrInVariableNameMap(const std::string &, + const framework::VariableNameMap &); void TransferVariablePlace(const framework::Scope *scope, const std::string &var_name, diff --git a/paddle/fluid/operators/crop_op.h b/paddle/fluid/operators/crop_op.h index fdb2c538fd8a3..7d0d4f06392fa 100644 --- a/paddle/fluid/operators/crop_op.h +++ b/paddle/fluid/operators/crop_op.h @@ -89,12 +89,7 @@ void CropFunction(const framework::ExecutionContext& context) { out_dims[0] = x->dims()[0]; } out->mutable_data(out_dims, context.GetPlace()); - auto x_stride = common::stride(x->dims()); auto offsets = GetOffsets(context); - int64_t offset = 0; - for (size_t i = 0; i < offsets.size(); ++i) { - offset += (x_stride[i] * offsets[i]); - } auto x_tensor = EigenTensor::From(*x); auto out_tensor = EigenTensor::From(*out); diff --git a/paddle/fluid/operators/cross_entropy_op.cc b/paddle/fluid/operators/cross_entropy_op.cc index 3a90012e1763a..cc2b4b4252835 100644 --- a/paddle/fluid/operators/cross_entropy_op.cc +++ b/paddle/fluid/operators/cross_entropy_op.cc @@ -239,7 +239,7 @@ class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { "represents the cross entropy loss."); AddAttr("soft_label", "(bool, default false), a flag indicating whether to " - "interpretant the given labels as soft labels.") + "interpret the given labels as soft labels.") .SetDefault(false); AddAttr("ignore_index", "(int, default -100), Specifies a target value that is" @@ -268,10 +268,10 @@ computation. $Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}$ - Please make sure that in this case the summuation of each row of Label + Please make sure that in this case the summation of each row of Label equals one. -3) One-hot cross-entropy with vecterized Input(Label): +3) One-hot cross-entropy with vectorized Input(Label): As a special case of 2), when each row of Input(Label) has only one non-zero element (equals 1), soft-label cross-entropy degenerates to a one-hot cross-entropy with one-hot label representation. 
diff --git a/paddle/fluid/operators/cross_entropy_op.h b/paddle/fluid/operators/cross_entropy_op.h index d755cb1639572..5b76cc9a65a2b 100644 --- a/paddle/fluid/operators/cross_entropy_op.h +++ b/paddle/fluid/operators/cross_entropy_op.h @@ -62,9 +62,9 @@ class CrossEntropyOpKernel : public framework::OpKernel { }; template -class XeSoftlabelGradFunctor { +class XeSoftLabelGradFunctor { public: - XeSoftlabelGradFunctor(T* dx, + XeSoftLabelGradFunctor(T* dx, const T* dy, // NOLINT const T* x, // NOLINT const T* label, // NOLINT @@ -137,7 +137,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { int64_t class_num = x->dims()[rank - 1]; int64_t ignore_index = ctx.Attr("ignore_index"); if (ctx.Attr("soft_label")) { - XeSoftlabelGradFunctor functor(dx_data, + XeSoftLabelGradFunctor functor(dx_data, dy->data(), x->data(), label->data(), diff --git a/paddle/fluid/operators/cuda_graph_with_in_out.h b/paddle/fluid/operators/cuda_graph_with_in_out.h index 3f65450d30c0e..7547bdd436395 100644 --- a/paddle/fluid/operators/cuda_graph_with_in_out.h +++ b/paddle/fluid/operators/cuda_graph_with_in_out.h @@ -16,21 +16,21 @@ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h" #endif namespace paddle { namespace operators { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) class CUDAGraphWithInOuts { public: template CUDAGraphWithInOuts(Callable &&callable, platform::CUDAPlace place, const std::vector &in_ptrs, - cudaStreamCaptureMode mode, + gpuStreamCaptureMode mode, int64_t pool_id) { in_indices_.resize(in_ptrs.size()); ins_.reserve(in_ptrs.size()); @@ -102,7 +102,7 @@ static std::unique_ptr CaptureCUDAGraph( const framework::ExecutionContext &ctx, const std::vector &input_names, const std::vector &output_names, - cudaStreamCaptureMode mode, + gpuStreamCaptureMode mode, int64_t pool_id) { std::vector inputs; for (const auto &name : input_names) { diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index e61512924f81d..a082dbbcb8bcb 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -152,7 +152,7 @@ the cell input ct-1 and the previous layer input xt given matrices W, R and bias which is computed based on the current input and the previous hidden state. Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication, -X represensts a matrix multiplication +X represents a matrix multiplication )DOC"); diff --git a/paddle/fluid/operators/cudnn_rnn_cache.h b/paddle/fluid/operators/cudnn_rnn_cache.h index 6cd7160e0ae26..9b6774af5832a 100644 --- a/paddle/fluid/operators/cudnn_rnn_cache.h +++ b/paddle/fluid/operators/cudnn_rnn_cache.h @@ -22,7 +22,8 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -struct CudnnRNNCache { +class CudnnRNNCache { + public: CudnnRNNCache() { x_desc_ = NULL; y_desc_ = NULL; @@ -30,8 +31,13 @@ struct CudnnRNNCache { ~CudnnRNNCache() { release(); } cudnnRNNDescriptor_t rnn_desc_; +#if CUDNN_VERSION >= 90000 + cudnnRNNDataDescriptor_t x_desc_; + cudnnRNNDataDescriptor_t y_desc_; +#else cudnnTensorDescriptor_t *x_desc_; cudnnTensorDescriptor_t *y_desc_; +#endif cudnnTensorDescriptor_t hx_desc_; cudnnTensorDescriptor_t cx_desc_; @@ -93,7 +99,37 @@ struct CudnnRNNCache { const auto numDirections = is_bidirec_ ? 2 : 1; auto cudnn_size = cudnn_type == CUDNN_DATA_FLOAT ? sizeof(float) : sizeof(double); +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnCreateRNNDataDescriptor(&x_desc_)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnCreateRNNDataDescriptor(&y_desc_)); + + std::vector seq_length_array(batch_size_); + for (int i = 0; i < batch_size_; ++i) { + seq_length_array[i] = seq_length_; + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDataDescriptor( + x_desc_, + cudnn_type, + CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED, + seq_length_, + batch_size_, + input_size_, + reinterpret_cast(seq_length_array.data()), + nullptr)); + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDataDescriptor( + y_desc_, + cudnn_type, + CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED, + seq_length_, + batch_size_, + hidden_size_ * numDirections, + reinterpret_cast(seq_length_array.data()), + nullptr)); +#else x_desc_ = new cudnnTensorDescriptor_t[seq_length_]; y_desc_ = new cudnnTensorDescriptor_t[seq_length_]; std::vector dims = {batch_size_, input_size_, 1}; @@ -114,6 +150,7 @@ struct CudnnRNNCache { PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor( y_desc_[i], cudnn_type, 3, dims_y.data(), strides_y.data())); } +#endif std::vector dims_hx = { num_layers_ * numDirections, batch_size_, hidden_size_}; @@ -185,7 +222,24 @@ struct CudnnRNNCache { PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_)); - +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v8( + rnn_desc_, + CUDNN_RNN_ALGO_STANDARD, + CUDNN_LSTM, + CUDNN_RNN_DOUBLE_BIAS, + is_bidirec_ ? 
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, + CUDNN_LINEAR_INPUT, + cudnn_type, + cudnn_type, + CUDNN_DEFAULT_MATH, + input_size_, + hidden_size_, + hidden_size_, + num_layers_, + dropout_desc_, + CUDNN_RNN_PADDED_IO_ENABLED)); +#else PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6( handle, rnn_desc_, @@ -197,15 +251,19 @@ struct CudnnRNNCache { CUDNN_LSTM, CUDNN_RNN_ALGO_STANDARD, cudnn_type)); - +#endif PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cudnnCreateFilterDescriptor(&w_desc_)); PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_)); +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNWeightSpaceSize( + handle, rnn_desc_, &weights_size_)); +#else PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNParamsSize( handle, rnn_desc_, x_desc_[0], &weights_size_, cudnn_type)); - +#endif PADDLE_ENFORCE_EQ( weights_size_, cudnn_size * weight_numel, @@ -220,18 +278,32 @@ struct CudnnRNNCache { w_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w)); PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor( dw_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w)); - +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnGetRNNTempSpaceSizes(handle, + rnn_desc_, + CUDNN_FWD_MODE_TRAINING, + x_desc_, + &workspace_size_, + reserve_size_)); +#else PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize( handle, rnn_desc_, seq_length_, x_desc_, &workspace_size_)); PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cudnnGetRNNTrainingReserveSize( handle, rnn_desc_, seq_length_, x_desc_, reserve_size_)); - +#endif workspace_data_.Resize({static_cast(workspace_size_)}); workspace_data_.mutable_data(place); } void release() { +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnDestroyRNNDataDescriptor(x_desc_)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cudnnDestroyRNNDataDescriptor(y_desc_)); +#else for (size_t i = 0; i < seq_length_; ++i) { PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i])); @@ -241,6 +313,7 @@ struct CudnnRNNCache { delete[] x_desc_; delete[] y_desc_; +#endif PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_)); diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index 9573809d6c7cc..d63197af754f2 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -120,7 +120,7 @@ class CConcatOpCustomDeviceKernel : public framework::OpKernel { reinterpret_cast(const_cast(send_buff)), recv_buff, send_numel, - phi::ccl::ToCCLDataType(x->dtype()), + x->dtype(), comm->GetXcclComm(), stream); } @@ -465,10 +465,10 @@ class CSoftmaxWithCrossEntropyGradCustomDeviceKernel framework::TensorCopy( *softmax, context.GetPlace(), context.device_context(), logit_grad); } - const auto sofrmax_dims = softmax->dims(); - const int axis = sofrmax_dims.size() - 1; - const int N = phi::funcs::SizeToAxis(axis, sofrmax_dims); - const int D = phi::funcs::SizeFromAxis(axis, sofrmax_dims); + const auto softmax_dims = softmax->dims(); + const int axis = softmax_dims.size() - 1; + const int N = phi::funcs::SizeToAxis(axis, softmax_dims); + const int D = phi::funcs::SizeFromAxis(axis, softmax_dims); const auto& label_type = labels->dtype(); if (label_type == phi::DataType::INT32 || @@ -514,7 +514,7 @@ class 
CSoftmaxWithCrossEntropyGradCustomDeviceKernel logit_grad ->ShareDataWith(*reinterpret_cast( logits_grad_out_tensor2.impl().get())) - .Resize(sofrmax_dims); + .Resize(softmax_dims); } else { PADDLE_THROW(phi::errors::Unavailable( "CustomDevice c_softmax_with_cross_entropy_grad " @@ -560,7 +560,7 @@ class CAllReduceOpCustomDeviceKernel : public framework::OpKernel { int rid = ctx.Attr("ring_id"); auto place = ctx.GetPlace(); - auto dtype = phi::ccl::ToCCLDataType(in->dtype()); + auto dtype = in->dtype(); int64_t numel = in->numel(); const void* sendbuff = in->data(); out->Resize(in->dims()); @@ -651,7 +651,7 @@ class CBroadcastOpCustomDeviceKernel : public framework::OpKernel { } int numel = x->numel(); - auto dtype = phi::ccl::ToCCLDataType(x->dtype()); + auto dtype = x->dtype(); if (root == comm->GetRank()) { phi::DeviceManager::CCLBroadcast(place.GetDeviceType(), const_cast(x->data()), @@ -712,7 +712,7 @@ class BarrierOpCustomDeviceKernel : public framework::OpKernel { const_cast(sendbuff), recvbuff, numel, - phi::ccl::ToCCLDataType(in->dtype()), + in->dtype(), phi::ccl::CCLReduceOp::SUM, comm->GetXcclComm(), *stream); @@ -853,7 +853,7 @@ class AssignPosCustomDeviceKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { // assign pos decides which tokens should be fetched belong to specially - // counter orderingly. + // counter orderly. auto cum_count = context.Input( "cum_count"); // (counter number) int32 | int64 auto numbers = context.Input( @@ -1059,7 +1059,7 @@ class GlobalScatterOpCustomDeviceKernel : public framework::OpKernel { place.GetDeviceType(), reinterpret_cast(recv_buf + recv_ptr * in_feat), cpu_global_count_data[idx] * in_feat, - phi::ccl::ToCCLDataType(x->dtype()), + x->dtype(), j, comm->GetXcclComm(), *stream); @@ -1075,7 +1075,7 @@ class GlobalScatterOpCustomDeviceKernel : public framework::OpKernel { const_cast(reinterpret_cast( send_buf + expert_ptr[idx] * in_feat)), cpu_local_count_data[idx] * in_feat, - phi::ccl::ToCCLDataType(x->dtype()), + x->dtype(), j, comm->GetXcclComm(), *stream); @@ -1098,7 +1098,7 @@ class GlobalScatterOpCustomDeviceKernel : public framework::OpKernel { place.GetDeviceType(), reinterpret_cast(recv_buf + recv_ptr * in_feat), cpu_global_count_data[idx] * in_feat, - phi::ccl::ToCCLDataType(x->dtype()), + x->dtype(), j, comm->GetXcclComm(), *stream); @@ -1269,7 +1269,7 @@ class GlobalGatherOpCustomDeviceKernel : public framework::OpKernel { phi::DeviceManager::CCLRecv(place.GetDeviceType(), recv_buf + expert_ptr[idx] * in_feat, cpu_local_count_data[idx] * in_feat, - phi::ccl::ToCCLDataType(x->dtype()), + x->dtype(), j, comm->GetXcclComm(), *stream); @@ -1284,7 +1284,7 @@ class GlobalGatherOpCustomDeviceKernel : public framework::OpKernel { const_cast(reinterpret_cast( send_buf + send_ptr * in_feat)), cpu_global_count_data[idx] * in_feat, - phi::ccl::ToCCLDataType(x->dtype()), + x->dtype(), j, comm->GetXcclComm(), *stream); @@ -1305,7 +1305,7 @@ class GlobalGatherOpCustomDeviceKernel : public framework::OpKernel { phi::DeviceManager::CCLRecv(place.GetDeviceType(), recv_buf + expert_ptr[idx] * in_feat, cpu_local_count_data[idx] * in_feat, - phi::ccl::ToCCLDataType(x->dtype()), + x->dtype(), j, comm->GetXcclComm(), *stream); diff --git a/paddle/fluid/operators/cvm_op.cc b/paddle/fluid/operators/cvm_op.cc index 578a59130495a..1e414ff217c2f 100644 --- a/paddle/fluid/operators/cvm_op.cc +++ b/paddle/fluid/operators/cvm_op.cc @@ -127,7 +127,7 @@ class CVMOpMaker : public 
framework::OpProtoAndCheckerMaker { AddInput("X", "(LodTensor, default LodTensor), a 2-D tensor with shape " "[N x D]," - " where N is the batch size and D is the emebdding dim. "); + " where N is the batch size and D is the embedding dim. "); AddInput("CVM", "(Tensor), a 2-D Tensor with shape [N x 2], where N is the batch " "size, 2 is show and click."); diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 32cc8b49cd007..cc3a224a7e862 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -81,28 +81,28 @@ class DataNormOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize").size(), 1UL, platform::errors::InvalidArgument( - "The input dim of BatchSize shouold be 1")); + "The input dim of BatchSize should be 1")); PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum").size(), 1UL, platform::errors::InvalidArgument( - "The input dim of BatchSum shouold be 1")); + "The input dim of BatchSum should be 1")); PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum").size(), 1UL, platform::errors::InvalidArgument( - "The input dim of BatchSquareSum shouold be 1")); + "The input dim of BatchSquareSum should be 1")); if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSize")[0], C, platform::errors::InvalidArgument( - "The input dim[0] of BatchSize shouold be C")); + "The input dim[0] of BatchSize should be C")); PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSum")[0], C, platform::errors::InvalidArgument( - "The input dim[0] of BatchSum shouold be C")); + "The input dim[0] of BatchSum should be C")); PADDLE_ENFORCE_EQ(ctx->GetInputDim("BatchSquareSum")[0], C, platform::errors::InvalidArgument( - "The input dim[0] of BatchSqureSum shouold be C")); + "The input dim[0] of BatchSquareSum should be C")); } if (enable_scale_and_shift) { @@ -112,10 +112,10 @@ class DataNormOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ( scale_dim.size(), 1UL, - platform::errors::InvalidArgument("the dimensionof scale" + platform::errors::InvalidArgument("the dimension of scale" "must equal to 1. 
But received: " "the shape of scale is [%s], " - "the dimensionof scale is [%d]", + "the dimension of scale is [%d]", scale_dim, scale_dim.size())); PADDLE_ENFORCE_EQ( @@ -691,7 +691,7 @@ class DataNormGradKernel : public framework::OpKernel { } } } else { - // calculate data sum and squre sum + // calculate data sum and square sum Eigen::Array sample_sum(C); Eigen::Array sample_square_sum(C); // calculate data sample sum and square sum @@ -769,7 +769,7 @@ PD_REGISTER_STRUCT_KERNEL( REGISTER_OP_VERSION(data_norm).AddCheckpoint( R"ROC( - upgrad data_norm op by adding scale_w to support scale and shift.)ROC", + upgrade data_norm op by adding scale_w to support scale and shift.)ROC", paddle::framework::compatible::OpVersionDesc().NewInput( "scale_w", - "scale_w is used to do scale duirng data_norm like batchnorm ")); + "scale_w is used to do scale during data_norm like batchnorm ")); diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cc b/paddle/fluid/operators/deformable_psroi_pooling_op.cc index 1e3e52d34e41c..5b339cf96c2b1 100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.cc +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cc @@ -101,7 +101,7 @@ class DeformablePSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker { "The format is NCHW, where N is the number of ROIs, " "C is the number of output channels, " "H is the height of output, and " - "W is thewidth of output. "); + "W is the width of output. "); AddComment(R"DOC( **DeformablePSROIPooling Operator** DeformablePSROIPooling is a new method based Region of interest pooling diff --git a/paddle/fluid/operators/dgc_op.cc b/paddle/fluid/operators/dgc_op.cc index 06fb2874f2171..7325c4271f9c4 100644 --- a/paddle/fluid/operators/dgc_op.cc +++ b/paddle/fluid/operators/dgc_op.cc @@ -87,7 +87,7 @@ class DGCOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(true); AddAttr>("sparsity", - "(vecotr, float)" + "(vector, float)" "The period sparsity of k_select."); AddAttr("rampup_begin_step", diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 382a3f7ac920b..01df430f52161 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -108,7 +108,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Dropout Operator. -Dropout refers to randomly dropping out units in a nerual network. It is a +Dropout refers to randomly dropping out units in a neural network. It is a regularization technique for reducing overfitting by preventing neuron co-adaption during training. 
The dropout operator randomly set (according to the given dropout probability) the outputs of some units to zero, while others @@ -175,7 +175,7 @@ class DropoutCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { auto mode = this->Attr("dropout_implementation"); prim::dropout_grad( mask, out_grad, p, is_test, mode, x_grad_p); - VLOG(3) << "Runing dropout_grad composite func"; + VLOG(3) << "Running dropout_grad composite func"; this->RecoverOutputName(x_grad, x_grad_name); } }; diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 191890865fb89..4029be65a00d6 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -107,6 +107,7 @@ class ElementwiseDivDoubleGradMaker : public framework::SingleGradOpMaker { op->SetType("elementwise_div_grad_grad"); op->SetInput("Y", this->Input("Y")); op->SetInput("Out", this->Input("Out")); + op->SetInput("Out@GRAD", this->Input(framework::GradVarName("Out"))); op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); op->SetInput("DDY", this->OutputGrad(framework::GradVarName("Y"))); op->SetInput("DX", this->Output(framework::GradVarName("X"))); diff --git a/paddle/fluid/operators/expand_as_v2_op.h b/paddle/fluid/operators/expand_as_v2_op.h index 2c62dc570ff21..abc89ba75c671 100644 --- a/paddle/fluid/operators/expand_as_v2_op.h +++ b/paddle/fluid/operators/expand_as_v2_op.h @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 4c2dd99265781..bd558ee944359 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -44,10 +44,11 @@ class ExpandOp : public framework::OperatorWithKernel { static_cast(x_dims.size()))); PADDLE_ENFORCE_LE( x_dims.size(), - 6, + MAX_RANK_SUPPORTED, platform::errors::InvalidArgument( "The number of dimensions of the input for Op(expand) " - "must not be greater than 6, but the value received is %d.", + "must not be greater than %d, but the value received is %d.", + MAX_RANK_SUPPORTED, x_dims.size())); std::vector out_shape(x_dims.size()); @@ -98,7 +99,7 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor, default Tensor). A tensor with rank in [1, 6]." + "(Tensor, default Tensor). A tensor with rank in [1, 8]." "X is the input to be expanded."); AddInput("ExpandTimes", "(Tensor), optional). If provided, expand according to " @@ -106,13 +107,13 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker { "expand_times_tensor and expand_times.") .AsDispensable(); AddInput("expand_times_tensor", - "(Tensor Tensor), epxand times for X." + "(Tensor Tensor), expand times for X." "It has a higher priority than expand_times, but a lower priority " "than ExpandTimes") .AsDuplicable() .AsDispensable(); AddOutput("Out", - "(Tensor, default Tensor). A tensor with rank in [1, 6]." + "(Tensor, default Tensor). A tensor with rank in [1, 8]." "The rank of Output(Out) have the same with Input(X). 
" "After expanding, size of each dimension of Output(Out) is equal " "to size of the corresponding dimension of Input(X) multiplying " @@ -123,7 +124,7 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Expand operator tiles the input by given times number. You should set times number for each dimension by providing attribute 'expand_times'. The rank of X -should be in [1, 6]. Please note that size of 'expand_times' must be the same +should be in [1, 8]. Please note that size of 'expand_times' must be the same with X's rank. Following is a using case: Input(X) is a 3-D tensor with shape [2, 3, 1]: [ @@ -165,7 +166,7 @@ class ExpandGradOp : public framework::OperatorWithKernel { out_dims[0], platform::errors::InvalidArgument( "The first dimension size (%d) of Input(Out@GRAD) should be " - "equal to the crroresponding dimension size (%d) of Input(X)", + "equal to the corresponding dimension size (%d) of Input(X)", out_dims[0], x_dims[0])); start_pos = 1u; @@ -181,7 +182,7 @@ class ExpandGradOp : public framework::OperatorWithKernel { out_dims[i], platform::errors::InvalidArgument( "The %uth dimension size (%d) of Input(Out@GRAD) should be " - "equal to the multiplication of the crroresponding dimension " + "equal to the multiplication of the corresponding dimension " "sizes of Input(X) (%d) and expand_times (%d).", i, out_dims[i], diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h index 8ff69a537ff7f..3d9fbe883b31b 100644 --- a/paddle/fluid/operators/expand_op.h +++ b/paddle/fluid/operators/expand_op.h @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace paddle { namespace operators { @@ -43,36 +43,36 @@ inline std::vector get_expand_times( expand_data = cpu_expand_tensor.data(); } #endif - auto vec_epxand_times = + auto vec_expand_times = std::vector(expand_data, expand_data + expand_tensor->numel()); - return vec_epxand_times; + return vec_expand_times; } auto list_expand_times_tensor = ctx.MultiInput("expand_times_tensor"); if (list_expand_times_tensor.size() > 0) { // get tensor from - std::vector vec_epxand_times; + std::vector vec_expand_times; for (size_t i = 0; i < list_expand_times_tensor.size(); ++i) { auto tensor = list_expand_times_tensor[i]; if (platform::is_gpu_place(tensor->place())) { phi::DenseTensor temp; paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp); - vec_epxand_times.push_back(*temp.data()); + vec_expand_times.push_back(*temp.data()); } #ifdef PADDLE_WITH_XPU else if (platform::is_xpu_place(tensor->place())) { // NOLINT phi::DenseTensor temp; paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp); - vec_epxand_times.push_back(*temp.data()); + vec_expand_times.push_back(*temp.data()); } #endif else { // NOLINT - vec_epxand_times.push_back(*tensor->data()); + vec_expand_times.push_back(*tensor->data()); } } - return vec_epxand_times; + return vec_expand_times; } else { return ctx.Attr>("expand_times"); } @@ -128,6 +128,12 @@ class ExpandKernel : public framework::OpKernel { case 6: Expand<6>(context); break; + case 7: + Expand<7>(context); + break; + case 8: + Expand<8>(context); + break; } } @@ -249,10 +255,17 @@ class ExpandGradKernel : public framework::OpKernel { case 6: ExpandBackward<6>(context, reshape_dims_vec, reduce_dims_vec); break; + case 7: + ExpandBackward<7>(context, 
reshape_dims_vec, reduce_dims_vec); + break; + case 8: + ExpandBackward<8>(context, reshape_dims_vec, reduce_dims_vec); + break; default: PADDLE_THROW(platform::errors::InvalidArgument( - "Only support tensor with rank being between 1 and 6. But " + "Only support tensor with rank being between 1 and %d. But " "received tensor's rank = %d.", + MAX_RANK_SUPPORTED, dims)); } } diff --git a/paddle/fluid/operators/expand_v2_op.h b/paddle/fluid/operators/expand_v2_op.h index 474ae818617fa..b61cf2dc485e5 100644 --- a/paddle/fluid/operators/expand_v2_op.h +++ b/paddle/fluid/operators/expand_v2_op.h @@ -22,7 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace paddle { namespace operators { @@ -53,26 +53,26 @@ inline std::vector get_expand_shape( ctx.MultiInput("expand_shapes_tensor"); if (list_expand_shapes_tensor.size() > 0) { // get tensor from - std::vector vec_epxand_shape; + std::vector vec_expand_shape; for (size_t i = 0; i < list_expand_shapes_tensor.size(); ++i) { auto tensor = list_expand_shapes_tensor[i]; if (platform::is_gpu_place(tensor->place())) { phi::DenseTensor temp; paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp); - vec_epxand_shape.push_back(*temp.data()); + vec_expand_shape.push_back(*temp.data()); } #ifdef PADDLE_WITH_XPU else if (platform::is_xpu_place(tensor->place())) { // NOLINT phi::DenseTensor temp; paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp); - vec_epxand_shape.push_back(*temp.data()); + vec_expand_shape.push_back(*temp.data()); } #endif else { // NOLINT - vec_epxand_shape.push_back(*tensor->data()); + vec_expand_shape.push_back(*tensor->data()); } } - return vec_epxand_shape; + return vec_expand_shape; } else { return ctx.Attr>("shape"); } diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 0515a56d41d5b..a5169892187a2 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -825,7 +825,7 @@ And it will not quantize the input tensor. 
} }; -class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel { +class StraightThroughEstimatorGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -835,11 +835,11 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel { OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name, - "StrightThroughEstimatorGradOp"); + "StraightThroughEstimatorGradOp"); OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name, - "StrightThroughEstimatorGradOp"); + "StraightThroughEstimatorGradOp"); ctx->SetOutputDim(x_grad_name, ctx->GetInputDim(out_grad_name)); } @@ -853,13 +853,13 @@ class StrightThroughEstimatorGradOp : public framework::OperatorWithKernel { }; template -class StrightThroughEstimatorMaker : public framework::SingleGradOpMaker { +class StraightThroughEstimatorMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; protected: void Apply(GradOpPtr grad_op) const override { - grad_op->SetType("stright_throuth_estimator_grad"); + grad_op->SetType("straight_through_estimator_grad"); grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); grad_op->SetAttrMap(this->Attrs()); @@ -888,8 +888,8 @@ REGISTER_OPERATOR( fake_quantize_dequantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp, ops::FakeQuantOrWithDequantAbsMaxOpMaker, - ops::StrightThroughEstimatorMaker, - ops::StrightThroughEstimatorMaker); + ops::StraightThroughEstimatorMaker, + ops::StraightThroughEstimatorMaker); PD_REGISTER_STRUCT_KERNEL(fake_quantize_dequantize_abs_max, CPU, ALL_LAYOUT, @@ -924,8 +924,8 @@ REGISTER_OPERATOR( fake_quantize_dequantize_moving_average_abs_max, ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp, ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker, - ops::StrightThroughEstimatorMaker, - ops::StrightThroughEstimatorMaker); + ops::StraightThroughEstimatorMaker, + ops::StraightThroughEstimatorMaker); PD_REGISTER_STRUCT_KERNEL(fake_quantize_dequantize_moving_average_abs_max, CPU, ALL_LAYOUT, @@ -948,28 +948,28 @@ REGISTER_OPERATOR( moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp, ops::MovingAverageAbsMaxScaleOpMaker, - ops::StrightThroughEstimatorMaker, - ops::StrightThroughEstimatorMaker); + ops::StraightThroughEstimatorMaker, + ops::StraightThroughEstimatorMaker); PD_REGISTER_STRUCT_KERNEL(moving_average_abs_max_scale, CPU, ALL_LAYOUT, ops::MovingAverageAbsMaxScaleKernel, float) {} -REGISTER_OPERATOR(stright_throuth_estimator_grad, - ops::StrightThroughEstimatorGradOp); -PD_REGISTER_STRUCT_KERNEL(stright_throuth_estimator_grad, +REGISTER_OPERATOR(straight_through_estimator_grad, + ops::StraightThroughEstimatorGradOp); +PD_REGISTER_STRUCT_KERNEL(straight_through_estimator_grad, CPU, ALL_LAYOUT, - ops::StrightThroughEstimatorGradKernel, + ops::StraightThroughEstimatorGradKernel, float) {} REGISTER_OPERATOR( fake_channel_wise_quantize_dequantize_abs_max, ops::FakeChannelWiseQuantizeDequantizeAbsMaxOp, ops::FakeChannelWiseQuantizeDequantizeAbsMaxOpMaker, - ops::StrightThroughEstimatorMaker, - ops::StrightThroughEstimatorMaker); + ops::StraightThroughEstimatorMaker, + ops::StraightThroughEstimatorMaker); PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_quantize_dequantize_abs_max, CPU, ALL_LAYOUT, diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index bf990a451eb2d..68ceaca46d04f 100644 --- 
a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -60,10 +60,10 @@ PD_REGISTER_STRUCT_KERNEL(fake_quantize_dequantize_moving_average_abs_max, ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel, float, float16) {} -PD_REGISTER_STRUCT_KERNEL(stright_throuth_estimator_grad, +PD_REGISTER_STRUCT_KERNEL(straight_through_estimator_grad, GPU, ALL_LAYOUT, - ops::StrightThroughEstimatorGradKernel, + ops::StraightThroughEstimatorGradKernel, float, float16) {} PD_REGISTER_STRUCT_KERNEL(fake_channel_wise_quantize_dequantize_abs_max, diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index dd8675331fce6..6387018d1865e 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -446,7 +446,7 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel { }; template -class StrightThroughEstimatorGradKernel : public framework::OpKernel { +class StraightThroughEstimatorGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto *d_out = @@ -455,7 +455,7 @@ class StrightThroughEstimatorGradKernel : public framework::OpKernel { auto *d_x = context.Output(x_grad_name); PADDLE_ENFORCE_NOT_NULL(d_x, platform::errors::PreconditionNotMet( - "StrightThroughEstimatorGradKernel " + "StraightThroughEstimatorGradKernel " "doesn't have the output named %s.", x_grad_name)); diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index 1263d156ce220..8a27649af864b 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -152,7 +152,7 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { "device") .SetDefault(false); AddAttr("place_type", - "(int, default -1) allow mamually setting place where the " + "(int, default -1) allow manually setting place where the " "variable should be hold. " "-1: not set manually, determine the place by executor. " "0: CPUPlace. " diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu index 157a45c71c16e..a76e93f5cdcf5 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_int8_op.cu @@ -14,6 +14,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/attn_gemm_int8.h" #include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h" +#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h" namespace paddle { namespace operators { @@ -345,18 +346,18 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { if (time_step) { // generation decoder stage // [2, batch_size, num_head, max_seq_len, head_size] int max_seq_len = cache_kv->dims()[3]; - fmha(dev_ctx, - qkv_out, - *qkv_bias, - *src_mask, - cache_kv_out, - &fmha_out, - bsz, - max_seq_len, - num_head, - dim_head, - time_step->data()[0], - 1. / std::sqrt(dim_head)); + phi::fusion::fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask, + cache_kv_out, + &fmha_out, + bsz, + max_seq_len, + num_head, + dim_head, + time_step->data()[0], + 1. 
/ std::sqrt(dim_head)); } else if (cache_kv_out) { // generation context stage // TODO(wangxi): can remove dropout in inference fmha_compute.ComputeForward(qkv_out, @@ -387,16 +388,16 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { T *cache_k_ptr = cache_kv_data; T *cache_v_ptr = cache_kv_data + cache_k_size; - write_cache_kv(dev_ctx, - cache_k_ptr, - cache_v_ptr, - k_ptr, - v_ptr, - bsz, - num_head, - seq_len, - max_seq_len, - dim_head); + phi::fusion::write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + bsz, + num_head, + seq_len, + max_seq_len, + dim_head); } else { // not generation // TODO(wangxi): can remove dropout in inference fmha_compute.ComputeForward(qkv_out, @@ -427,10 +428,10 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { quant_round_type, quant_max_bound, quant_min_bound); - AllReduce(output_workspace, - ring_id, - bsz * seq_len * num_head * dim_head, - dev_ctx); + phi::fusion::AllReduce(output_workspace, + ring_id, + bsz * seq_len * num_head * dim_head, + dev_ctx); } else { out_linear_compute.ComputeForward(out_linear_weights[i], &fmha_out, @@ -444,7 +445,7 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { quant_round_type, quant_max_bound, quant_min_bound); - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + phi::fusion::AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step4"; @@ -583,12 +584,12 @@ class FusedMultiTransformerINT8OpKernel : public framework::OpKernel { #endif if (pre_layer_norm) { - AllReduce(output_workspace, - ring_id, - bsz * seq_len * num_head * dim_head, - dev_ctx); + phi::fusion::AllReduce(output_workspace, + ring_id, + bsz * seq_len * num_head * dim_head, + dev_ctx); } else { - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + phi::fusion::AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER VLOG(0) << "step8.1"; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index e3158d74df629..75a4c7b275a8a 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -14,1365 +14,1393 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h" -namespace paddle { -namespace operators { +#include "paddle/fluid/framework/op_registry.h" + +#include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h" +#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h" +#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h" +#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h" + +namespace phi { +namespace fusion { #if CUDA_VERSION >= 11060 // Use cublasLt to fuse FFN operation. -template -class FusedMultiTransformerOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using U = LayerNormParamType; - auto &dev_ctx = ctx.cuda_device_context(); - - auto *time_step = ctx.Input("TimeStep"); - // 0. 
input - auto *input_x = ctx.Input("X"); - const auto input_x_dims = input_x->dims(); - int bsz = input_x_dims[0]; - int seq_len = input_x_dims[1]; - int dim_embed = input_x_dims[2]; - int bsz_seq = bsz * seq_len; - const std::string act_method = ctx.Attr("act_method"); - bool remove_padding = false; - auto *sequence_lengths = ctx.Input("SeqLengths"); - if (sequence_lengths) { - remove_padding = true; - } - phi::DenseTensor d_token_tensor; - phi::DenseTensor padding_offset_tensor; - phi::DenseTensor x_remove_padding; - bool encoder_remove_padding = (remove_padding && !time_step); - int token_num = 0; - - // remove padding in encoder - if (encoder_remove_padding) { - // just for encoder - d_token_tensor.Resize({{1}}); - auto *d_token_num = dev_ctx.Alloc( - &d_token_tensor, d_token_tensor.numel() * sizeof(int)); - // alloc the max size of padding_offset_tensor - padding_offset_tensor.Resize({{bsz_seq}}); - dev_ctx.Alloc(&padding_offset_tensor, - padding_offset_tensor.numel() * sizeof(int)); - InvokeGetPaddingOffset(dev_ctx, - &token_num, - d_token_num, - padding_offset_tensor.data(), - sequence_lengths->data(), - bsz, - seq_len); - padding_offset_tensor.Resize({{token_num}}); - x_remove_padding.Resize({{token_num, dim_embed}}); - dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); - InvokeRemovePadding(dev_ctx, - x_remove_padding.data(), - input_x->data(), - padding_offset_tensor.data(), - token_num, - dim_embed); - } else { - token_num = bsz_seq; - } - auto *padding_offset_data = - encoder_remove_padding ? padding_offset_tensor.data() : nullptr; - - // 1. layer norm - const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); - const float epsilon = ctx.Attr("epsilon"); - auto ln_scales = ctx.MultiInput("LnScale"); - auto ln_biases = ctx.MultiInput("LnBias"); - - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); - phi::DenseTensor ln_mean, ln_var; - ln_mean.Resize({{token_num}}); - auto *ln_mean_data = - dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({{token_num}}); - auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); - - // 2. qkv - // x: qkv's input [batch_size, seq_len, dim_embed] - // y: qkv's weight: [3, num_head, dim_head, dim_embed] - auto qkv_weights = ctx.MultiInput("QKVW"); - auto qkv_biases = ctx.MultiInput("QKVBias"); - const bool trans_qkvw = ctx.Attr("trans_qkvw"); - const auto qkv_w_dims = qkv_weights[0]->dims(); - int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; - int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; - int hidden_size = num_head * dim_head; - int output_size = 3 * hidden_size; - int input_size = dim_embed; - - bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; - // (transA, transB, compute_bias) = (false, trans_qkvw, false) - // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set - // compute_bias as false. - auto qkv_compute = phi::fusion::AttnMatMul(dev_ctx, - false, - trans_qkvw, - token_num, - output_size, - input_size, - /*compute_bias=*/false); - - phi::DenseTensor qkv_out; - qkv_out.Resize({{token_num, 3, num_head, dim_head}}); - auto *qkv_out_data = - dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); - - // 2.1 rotary - auto *rotary_tensor = ctx.Input("RotaryPosEmb"); - const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); - - // 3. 
fmha - AttnDropoutParam attn_param( - true, "upscale_in_train", 0.0, true, true, 0, nullptr); - auto fmha_compute = - FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - auto *src_mask = ctx.Input("SrcMask"); - auto cache_kvs = ctx.MultiInput("CacheKV"); - auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); - // auto *time_step = ctx.Input("TimeStep"); - auto pre_caches = ctx.MultiInput("PreCaches"); - int cache_offset = 0; - if (pre_caches.size() > 0) { - cache_offset = pre_caches[0]->dims()[3]; +template +void FusedMultiTransformerKernel( + const Context &dev_ctx, + const DenseTensor &x, + const std::vector &ln_scales, + const std::vector &ln_biases, + const std::vector &qkv_weights, + const paddle::optional> &qkv_biases, + const paddle::optional> &cache_kvs, + const paddle::optional> &pre_caches, + const paddle::optional &rotary_tensor, + const paddle::optional &time_step, + const paddle::optional &seq_lengths, + const paddle::optional &src_mask, + const std::vector &out_linear_weights, + const paddle::optional> &out_linear_biases, + const std::vector &ffn_ln_scales, + const std::vector &ffn_ln_biases, + const std::vector &ffn1_weights, + const paddle::optional> &ffn1_biases, + const std::vector &ffn2_weights, + const paddle::optional> &ffn2_biases, + bool pre_layer_norm, + float epsilon, + float dropout_rate, + int rotary_emb_dims, + bool is_test, + const std::string &dropout_implementation, + const std::string &act_method, + bool trans_qkvw, + int ring_id, + std::vector cache_kv_outs, + DenseTensor *out) { + if (cache_kvs) { + for (size_t i = 0; i < cache_kv_outs.size(); i++) { + *(cache_kv_outs[i]) = *(cache_kvs.get()[i]); } + } + using U = phi::funcs::LayerNormParamType; + + auto *rotary_tensor_t = rotary_tensor.get_ptr(); + auto *seq_lengths_t = seq_lengths.get_ptr(); + auto *src_mask_t = src_mask.get_ptr(); + auto *time_step_t = time_step.get_ptr(); + + const auto input_x_dims = x.dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + bool remove_padding = false; + if (seq_lengths_t) { + remove_padding = true; + } + phi::DenseTensor d_token_tensor; + phi::DenseTensor padding_offset_tensor; + phi::DenseTensor x_remove_padding; + bool encoder_remove_padding = (remove_padding && !time_step_t); + int token_num = 0; + + // remove padding in encoder + if (encoder_remove_padding) { + // just for encoder + d_token_tensor.Resize({1}); + auto *d_token_num = dev_ctx.template Alloc( + &d_token_tensor, d_token_tensor.numel() * sizeof(int)); + // alloc the max size of padding_offset_tensor + padding_offset_tensor.Resize({bsz_seq}); + dev_ctx.template Alloc(&padding_offset_tensor, + padding_offset_tensor.numel() * sizeof(int)); + InvokeGetPaddingOffset(dev_ctx, + &token_num, + d_token_num, + padding_offset_tensor.data(), + seq_lengths_t->data(), + bsz, + seq_len); + padding_offset_tensor.Resize({token_num}); + x_remove_padding.Resize({token_num, dim_embed}); + dev_ctx.template Alloc(&x_remove_padding, + x_remove_padding.numel() * sizeof(T)); + InvokeRemovePadding(dev_ctx, + x_remove_padding.data(), + x.data(), + padding_offset_tensor.data(), + token_num, + dim_embed); + } else { + token_num = bsz_seq; + } + auto *padding_offset_data = + encoder_remove_padding ? 
padding_offset_tensor.data() : nullptr; + + auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + phi::DenseTensor ln_mean, ln_var; + ln_mean.Resize({token_num}); + auto *ln_mean_data = + dev_ctx.template Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_var.Resize({token_num}); + auto *ln_var_data = + dev_ctx.template Alloc(&ln_var, ln_var.numel() * sizeof(U)); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + const auto qkv_w_dims = qkv_weights[0]->dims(); + int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + bool compute_bias = + qkv_biases && !qkv_biases.get().empty() && time_step_t == nullptr; + // (transA, transB, compute_bias) = (false, trans_qkvw, false) + // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we set + // compute_bias as false. + auto qkv_compute = phi::fusion::AttnMatMul(dev_ctx, + false, + trans_qkvw, + token_num, + output_size, + input_size, + /*compute_bias=*/false); + + phi::DenseTensor qkv_out; + qkv_out.Resize({token_num, 3, num_head, dim_head}); + auto *qkv_out_data = + dev_ctx.template Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // 3. fmha + AttnDropoutParam attn_param( + true, "upscale_in_train", 0.0, true, true, 0, nullptr); + auto fmha_compute = + FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + int cache_offset = 0; + if (pre_caches && pre_caches.get().size() > 0) { + cache_offset = pre_caches.get()[0]->dims()[3]; + } - auto out_seq_len = seq_len; - if (time_step) { - PADDLE_ENFORCE_EQ(time_step->place(), - platform::CPUPlace(), - platform::errors::PreconditionNotMet( - "The place of input(TimeStep) must be CPUPlace.")); - // cache_seq_len - int time_step_value = time_step->data()[0]; - PADDLE_ENFORCE_GT(time_step_value, - 0, - platform::errors::PreconditionNotMet( - "The value of time_step must > 0, but now is %d", - time_step_value)); - PADDLE_ENFORCE_EQ( - seq_len, - 1, - platform::errors::PreconditionNotMet( - "In decode stage, the seq_len of input must be 1, but now is %d", - seq_len)); - out_seq_len += time_step_value; - } else { - out_seq_len += cache_offset; - } + auto out_seq_len = seq_len; + if (time_step_t) { + PADDLE_ENFORCE_EQ(time_step_t->place(), + phi::CPUPlace(), + phi::errors::PreconditionNotMet( + "The place of input(TimeStep) must be CPUPlace.")); + // cache_seq_len + int time_step_value = time_step_t->data()[0]; + PADDLE_ENFORCE_GT(time_step_value, + 0, + phi::errors::PreconditionNotMet( + "The value of time_step_t must > 0, but now is %d", + time_step_value)); + PADDLE_ENFORCE_EQ( + seq_len, + 1, + phi::errors::PreconditionNotMet( + "In decode stage, the seq_len of input must be 1, but now is %d", + seq_len)); + out_seq_len += time_step_value; + } else { + out_seq_len += cache_offset; + } - phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; - q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); - auto *q_transpose_out_data = - dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); + phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; + q_transpose_out.Resize({bsz, num_head, seq_len, dim_head}); + auto *q_transpose_out_data = dev_ctx.template Alloc( + &q_transpose_out, q_transpose_out.numel() * sizeof(T)); - kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); - auto 
*kv_transpose_out_data = dev_ctx.Alloc( - &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + kv_transpose_out.Resize({2, bsz, num_head, seq_len, dim_head}); + auto *kv_transpose_out_data = dev_ctx.template Alloc( + &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); - qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + qk_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *qk_out_data = + dev_ctx.template Alloc(&qk_out, qk_out.numel() * sizeof(T)); - phi::DenseTensor src_mask_out; - if (cache_offset > 0) { - src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *src_mask_out_data = - dev_ctx.Alloc(&src_mask_out, src_mask_out.numel() * sizeof(T)); - } + phi::DenseTensor src_mask_out; + if (cache_offset > 0) { + src_mask_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *src_mask_out_data = dev_ctx.template Alloc( + &src_mask_out, src_mask_out.numel() * sizeof(T)); + } - // [2, bs, num_head, cache_seq_len + seq_len, head_dim] - phi::DenseTensor pre_cache_kv_out; - if (cache_offset > 0) { - pre_cache_kv_out.Resize( - {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); - auto *pre_cache_kv_out_data = dev_ctx.Alloc( - &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); - } + // [2, bs, num_head, cache_seq_len + seq_len, head_dim] + phi::DenseTensor pre_cache_kv_out; + if (cache_offset > 0) { + pre_cache_kv_out.Resize( + {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); + auto *pre_cache_kv_out_data = dev_ctx.template Alloc( + &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); + } - phi::DenseTensor softmax_out; - phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; - phi::DenseTensor qktv_out, fmha_out; - softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *softmax_out_data = - dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - - attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_mask_out_data = dev_ctx.Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_data_data = dev_ctx.Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); - - qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); - auto *qktv_out_data = - dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); - auto *fmha_out_data = - dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); - - // 4. out_linear - auto out_linear_weights = ctx.MultiInput("OutLinearW"); - auto out_linear_biases = ctx.MultiInput("OutLinearBias"); - int ring_id = ctx.Attr("ring_id"); - // (transA, transB, compute_bias) = (false, false, false) - auto out_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, hidden_size, false); - - // 5. 
ln(residual + bias) - DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, token_num, dim_embed, dropout_param2, epsilon); - auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); - auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); - phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; - T *bias_dropout_residual_out_data = nullptr; + phi::DenseTensor softmax_out; + phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; + phi::DenseTensor qktv_out, fmha_out; + softmax_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *softmax_out_data = + dev_ctx.template Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + + attn_dropout_mask_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *attn_dropout_mask_out_data = dev_ctx.template Alloc( + &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); + attn_dropout_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *attn_dropout_data_data = dev_ctx.template Alloc( + &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); + + qktv_out.Resize({bsz, num_head, seq_len, dim_head}); + auto *qktv_out_data = + dev_ctx.template Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + fmha_out.Resize({bsz, seq_len, num_head, dim_head}); + auto *fmha_out_data = + dev_ctx.template Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // (transA, transB, compute_bias) = (false, false, false) + auto out_linear_compute = phi::fusion::AttnMatMul( + dev_ctx, false, false, token_num, dim_embed, hidden_size, false); + + // 5. ln(residual + bias) + DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); + phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; + T *bias_dropout_residual_out_data = nullptr; + if (pre_layer_norm) { + bias_dropout_residual_out.Resize({token_num, dim_embed}); + bias_dropout_residual_out_data = dev_ctx.template Alloc( + &bias_dropout_residual_out, + bias_dropout_residual_out.numel() * sizeof(T)); + } + dropout_mask_out.Resize({token_num, dim_embed}); + auto *dropout_mask_out_data = dev_ctx.template Alloc( + &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); + + // 6. ffn1 matmul + act + bias + auto ffn1_weight_dim = ffn1_weights[0]->dims(); + + int dim_ffn = ffn1_weight_dim[1]; + + auto ffn1_cublas_linear = CublasFusedMLP(dev_ctx); + const phi::DDim ffn1_input_shape({token_num, dim_embed}); + ffn1_cublas_linear.Setup(ffn1_input_shape, ffn1_weight_dim, false, false); + + phi::DenseTensor ffn1_out; + ffn1_out.Resize({token_num, dim_ffn}); + auto *ffn1_out_data = + dev_ctx.template Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); + + // 7. ffn2 matmul + bias + residual. + auto ffn2_linear_compute = phi::fusion::AttnMatMul( + dev_ctx, false, false, token_num, dim_embed, dim_ffn, false); + + // 8. 
ffn2 Layernorm residual bias + DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); + + // calc + auto *from_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + phi::DenseTensor *from_tensor = out; + phi::DenseTensor tmp_out, tmp_out_rm_padding; + tmp_out.Resize({token_num, dim_embed}); + if (encoder_remove_padding) { + tmp_out_rm_padding.Resize({token_num, dim_embed}); + auto *tmp_out_rm_padding_data = dev_ctx.template Alloc( + &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); + } + auto *tmp_out_data = + dev_ctx.template Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); + + const T *x_data; + if (encoder_remove_padding) { + x_data = x_remove_padding.data(); + } else { + x_data = x.data(); + } + phi::DenseTensor *buf0 = nullptr; + phi::DenseTensor *buf1 = nullptr; + + // step0: x --> buf1 + // step1: buf1 --> buf0 + // step2: buf0 --> buf1 + int layers = qkv_weights.size(); + if (encoder_remove_padding) { + // In the case of variable lengths, the padding needs to be rebuilt + // eventually. So buf0 and buf1 do not need to be changed according to the + // pre_layer_norm and the number of layers. + buf0 = &tmp_out; + buf1 = &tmp_out_rm_padding; + } else { if (pre_layer_norm) { - bias_dropout_residual_out.Resize({{token_num, dim_embed}}); - bias_dropout_residual_out_data = - dev_ctx.Alloc(&bias_dropout_residual_out, - bias_dropout_residual_out.numel() * sizeof(T)); - } - dropout_mask_out.Resize({{token_num, dim_embed}}); - auto *dropout_mask_out_data = dev_ctx.Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); - - // 6. ffn1 matmul + act + bias - auto ffn1_weights = ctx.MultiInput("FFN1Weight"); - auto ffn1_biases = ctx.MultiInput("FFN1Bias"); - auto ffn1_weight_dim = ffn1_weights[0]->dims(); - - int dim_ffn = ffn1_weight_dim[1]; - - auto ffn1_cublas_linear = CublasFusedMLP(dev_ctx); - const phi::DDim ffn1_input_shape({token_num, dim_embed}); - ffn1_cublas_linear.Setup(ffn1_input_shape, ffn1_weight_dim, false, false); - - phi::DenseTensor ffn1_out; - ffn1_out.Resize({{token_num, dim_ffn}}); - auto *ffn1_out_data = - dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); - - // 7. ffn2 matmul + bias + residual. - auto ffn2_weights = ctx.MultiInput("FFN2Weight"); - auto ffn2_biases = ctx.MultiInput("FFN2Bias"); - - auto ffn2_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, dim_ffn, false); - - // 8. 
ffn2 Layernorm residual bias - DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); - - // calc - auto *out = ctx.Output("Out"); - auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - phi::DenseTensor *from_tensor = out; - phi::DenseTensor tmp_out, tmp_out_rm_padding; - tmp_out.Resize({{token_num, dim_embed}}); - if (encoder_remove_padding) { - tmp_out_rm_padding.Resize({{token_num, dim_embed}}); - auto *tmp_out_rm_padding_data = dev_ctx.Alloc( - &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); - } - auto *tmp_out_data = - dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - - const T *x_data; - if (encoder_remove_padding) { - x_data = x_remove_padding.data(); - } else { - x_data = input_x->data(); - } - phi::DenseTensor *buf0 = nullptr; - phi::DenseTensor *buf1 = nullptr; - - // step0: x --> buf1 - // step1: buf1 --> buf0 - // step2: buf0 --> buf1 - int layers = qkv_weights.size(); - if (encoder_remove_padding) { - // In the case of variable lengths, the padding needs to be rebuilt - // eventually. So buf0 and buf1 do not need to be changed according to the - // pre_layer_norm and the number of layers. - buf0 = &tmp_out; - buf1 = &tmp_out_rm_padding; - } else { - if (pre_layer_norm) { - if (layers & 1) { - // odd, set buf1 as out - buf0 = &tmp_out; - buf1 = out; - } else { - // even, set buf0 as out - buf0 = out; - buf1 = &tmp_out; - } - } else { + if (layers & 1) { + // odd, set buf1 as out buf0 = &tmp_out; buf1 = out; + } else { + // even, set buf0 as out + buf0 = out; + buf1 = &tmp_out; } + } else { + buf0 = &tmp_out; + buf1 = out; } + } - for (int i = 0; i < layers; ++i) { - // step1. layer_norm - if (i == 0 && pre_layer_norm) { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - // TODO(wangxi): can remove mean var in inference - ln_compute.ComputeForward(x_data, - ln_scale_data, - ln_bias_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } + for (int i = 0; i < layers; ++i) { + // step1. layer_norm + if (i == 0 && pre_layer_norm) { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + // TODO(wangxi): can remove mean var in inference + ln_compute.ComputeForward(x_data, + ln_scale_data, + ln_bias_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; + VLOG(0) << "step1"; #endif - // step2. qkv - const phi::DenseTensor *qkv_bias = - qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; - // NOTE: in decoder stage, bias is fused in fmha - const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; - if (!pre_layer_norm && i == 0) { - const phi::DenseTensor *tmp_input_x = - (encoder_remove_padding) ? &x_remove_padding : input_x; - qkv_compute.ComputeForward( - qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); - } else { - qkv_compute.ComputeForward( - qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); - } + // step2. qkv + const phi::DenseTensor *qkv_bias = + qkv_biases && !qkv_biases.get().empty() ? qkv_biases.get()[i] : nullptr; + // NOTE: in decoder stage, bias is fused in fmha + const phi::DenseTensor *bias = time_step_t ? nullptr : qkv_bias; + if (!pre_layer_norm && i == 0) { + const phi::DenseTensor *tmp_input_x = + (encoder_remove_padding) ? 
&x_remove_padding : &x; + qkv_compute.ComputeForward( + qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); + } else { + qkv_compute.ComputeForward( + qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; + VLOG(0) << "step2"; #endif - // step3. fmha - const phi::DenseTensor *cache_kv = - cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; - phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; - - if (time_step) { // generation decoder stage - // [2, batch_size, num_head, max_seq_len, head_size] - int max_seq_len = cache_kv->dims()[3]; - fmha(dev_ctx, - qkv_out, - *qkv_bias, - *src_mask, - sequence_lengths, - rotary_tensor, - cache_kv_out, - &fmha_out, - bsz, - max_seq_len, - num_head, - dim_head, - time_step->data()[0], - rotary_emb_dims, - 1. / std::sqrt(dim_head)); - } else if (cache_kv_out) { // generation context stage - const phi::DenseTensor *pre_cache_kv_tensor = - pre_caches.size() > 0 ? pre_caches[i] : nullptr; - phi::DenseTensor *pre_cache_kv_out_tmp = - cache_offset > 0 ? &pre_cache_kv_out : nullptr; - phi::DenseTensor *src_mask_tmp = - cache_offset > 0 ? &src_mask_out : nullptr; - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? sequence_lengths->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); - } - - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? 
&padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, - src_mask, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - pre_cache_kv_out_tmp, - &qk_out, - src_mask_tmp, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); - const T *k_ptr = nullptr; - const T *v_ptr = nullptr; - - if (cache_offset > 0) { - // [2, bsz, num_head, cache_offset + seq_len, head_dim] - const T *kv_data = pre_cache_kv_out.data(); - k_ptr = kv_data; - int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; - v_ptr = k_ptr + k_size; - } else { - // [3, bsz, num_head, seq_len, head_dim] - int64_t k_size = bsz * seq_len * num_head * dim_head; - const T *q_ptr = q_transpose_out_data; - k_ptr = kv_transpose_out_data; - v_ptr = k_ptr + k_size; - } - - // [2, bsz, num_head, max_seq_len, head_dim] - int max_seq_len = cache_kv_out->dims()[3]; - T *cache_kv_data = cache_kv_out->data(); - int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; - - T *cache_k_ptr = cache_kv_data; - T *cache_v_ptr = cache_kv_data + cache_k_size; - - const int seq_len_tmp = seq_len + cache_offset; - write_cache_kv(dev_ctx, - cache_k_ptr, - cache_v_ptr, - k_ptr, - v_ptr, - bsz, - num_head, - seq_len_tmp, - max_seq_len, - dim_head); - } else { // not generation - // TODO(wangxi): can remove dropout in inference - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? sequence_lengths->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); - } - - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? &padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(cache_kv, - src_mask, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); + // step3. fmha + const phi::DenseTensor *cache_kv = + cache_kvs && cache_kvs.get().size() > 0 ? cache_kvs.get()[i] : nullptr; + phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + + if (time_step_t) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask_t, + seq_lengths_t, + rotary_tensor_t, + cache_kv_out, + &fmha_out, + bsz, + max_seq_len, + num_head, + dim_head, + time_step_t->data()[0], + rotary_emb_dims, + 1. / std::sqrt(dim_head)); + } else if (cache_kv_out) { // generation context stage + const phi::DenseTensor *pre_cache_kv_tensor = + pre_caches && pre_caches.get().size() > 0 ? pre_caches.get()[i] + : nullptr; + phi::DenseTensor *pre_cache_kv_out_tmp = + cache_offset > 0 ? &pre_cache_kv_out : nullptr; + phi::DenseTensor *src_mask_tmp = + cache_offset > 0 ? 
&src_mask_out : nullptr; + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor_t->data(); + const int *sequence_lengths_data = + encoder_remove_padding ? seq_lengths_t->data() : nullptr; + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step3"; -#endif - if (pre_layer_norm) { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, + src_mask_t, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + pre_cache_kv_out_tmp, + &qk_out, + src_mask_tmp, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + const T *k_ptr = nullptr; + const T *v_ptr = nullptr; + + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out.data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; } else { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + // [3, bsz, num_head, seq_len, head_dim] + int64_t k_size = bsz * seq_len * num_head * dim_head; + const T *q_ptr = q_transpose_out_data; + k_ptr = kv_transpose_out_data; + v_ptr = k_ptr + k_size; + } + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + + const int seq_len_tmp = seq_len + cache_offset; + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + bsz, + num_head, + seq_len_tmp, + max_seq_len, + dim_head); + } else { // not generation + // TODO(wangxi): can remove dropout in inference + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor_t->data(); + const int *sequence_lengths_data = + encoder_remove_padding ? seq_lengths_t->data() : nullptr; + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); } + + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? 
&padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(cache_kv, + src_mask_t, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step4"; + VLOG(0) << "step3"; #endif - // step5. ln(residual + dropout(input + bias)) - if (pre_layer_norm) { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); + if (pre_layer_norm) { + out_linear_compute.ComputeForward( + out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + } else { + out_linear_compute.ComputeForward( + out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step4"; +#endif - // inplace - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - x_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - bias_dropout_residual_out_data, - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } else { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); - auto *residual_data = (i == 0 ? x_data : buf1->data()); - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf0->data(), - residual_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } + // step5. ln(residual + dropout(input + bias)) + if (pre_layer_norm) { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases.get()[i]->data(); + + // inplace + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + buf1->data(), + x_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + bias_dropout_residual_out_data, + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } else { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases.get()[i]->data(); + auto *residual_data = (i == 0 ? x_data : buf1->data()); + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0->data(), + residual_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step5"; + VLOG(0) << "step5"; #endif - // step6. ffn matmul1 - ffn1_cublas_linear.ComputeForward(buf1, - ffn1_weights[i], - ffn1_biases[i], - nullptr, - &ffn1_out, - act_method); + // step6. ffn matmul1 + ffn1_cublas_linear.ComputeForward(buf1, + ffn1_weights[i], + ffn1_biases.get()[i], + nullptr, + &ffn1_out, + act_method); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step6"; + VLOG(0) << "step6"; #endif - // step7. ffn2 matmul - if (pre_layer_norm) { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_out, nullptr, buf1, nullptr); - } else { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_out, nullptr, buf0, nullptr); - } + // step7. 
ffn2 matmul + if (pre_layer_norm) { + ffn2_linear_compute.ComputeForward( + ffn2_weights[i], &ffn1_out, nullptr, buf1, nullptr); + } else { + ffn2_linear_compute.ComputeForward( + ffn2_weights[i], &ffn1_out, nullptr, buf0, nullptr); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; + VLOG(0) << "step7"; #endif - if (pre_layer_norm) { - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); - } else { - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); - } + if (pre_layer_norm) { + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + } else { + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7.1"; + VLOG(0) << "step7.1"; #endif - // step8. layer norm + bias_add + residual - if (pre_layer_norm) { - // TODO(wangxi): remove dropout mask in inference - if (i < layers - 1) { - auto *ln_scale_data = ln_scales[i + 1]->data(); - auto *ln_bias_data = ln_biases[i + 1]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf1->data(), - dropout_mask_out_data, - buf0->data(), - ln_mean_data, - ln_var_data); - } else { - ffn2_fused_dropout_helper.ResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - buf1->data(), - dropout_mask_out_data); - } - } else { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); + // step8. layer norm + bias_add + residual + if (pre_layer_norm) { + // TODO(wangxi): remove dropout mask in inference + if (i < layers - 1) { + auto *ln_scale_data = ln_scales[i + 1]->data(); + auto *ln_bias_data = ln_biases[i + 1]->data(); ffn2_fused_dropout_helper.LayernormResidualDropoutBias( dev_ctx, - buf0->data(), buf1->data(), - ffn2_biases[i]->data(), + bias_dropout_residual_out_data, + ffn2_biases.get()[i]->data(), ln_scale_data, ln_bias_data, - buf0->data(), - dropout_mask_out_data, buf1->data(), + dropout_mask_out_data, + buf0->data(), ln_mean_data, ln_var_data); + } else { + ffn2_fused_dropout_helper.ResidualDropoutBias( + dev_ctx, + buf1->data(), + bias_dropout_residual_out_data, + ffn2_biases.get()[i]->data(), + buf1->data(), + dropout_mask_out_data); } + } else { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0->data(), + buf1->data(), + ffn2_biases.get()[i]->data(), + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8"; + VLOG(0) << "step8"; #endif - if (pre_layer_norm) { - x_data = buf1->data(); - std::swap(buf0, buf1); - } + if (pre_layer_norm) { + x_data = buf1->data(); + std::swap(buf0, buf1); } - if (encoder_remove_padding) { - if (pre_layer_norm) { - InvokeRebuildPadding(dev_ctx, - from_data, - buf0->data(), - padding_offset_data, - token_num, - dim_embed); - } else { - InvokeRebuildPadding(dev_ctx, - from_data, - buf1->data(), - padding_offset_data, - token_num, - dim_embed); - } + } + if (encoder_remove_padding) { + if (pre_layer_norm) { + InvokeRebuildPadding(dev_ctx, + from_data, + buf0->data(), + padding_offset_data, + token_num, + dim_embed); + } else { + InvokeRebuildPadding(dev_ctx, + from_data, + buf1->data(), + padding_offset_data, + token_num, + dim_embed); } } -}; +} #else -template -class 
FusedMultiTransformerOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - using U = LayerNormParamType; - auto &dev_ctx = ctx.cuda_device_context(); - - auto *time_step = ctx.Input("TimeStep"); - // 0. input - auto *input_x = ctx.Input("X"); - const auto input_x_dims = input_x->dims(); - int bsz = input_x_dims[0]; - int seq_len = input_x_dims[1]; - int dim_embed = input_x_dims[2]; - int bsz_seq = bsz * seq_len; - const std::string act_method = ctx.Attr("act_method"); - bool remove_padding = false; - auto *sequence_lengths = ctx.Input("SeqLengths"); - if (sequence_lengths) { - remove_padding = true; - } - phi::DenseTensor d_token_tensor; - phi::DenseTensor padding_offset_tensor; - phi::DenseTensor x_remove_padding; - bool encoder_remove_padding = (remove_padding && !time_step); - int token_num = 0; - - // remove padding in encoder - if (encoder_remove_padding) { - // just for encoder - d_token_tensor.Resize({{1}}); - auto *d_token_num = dev_ctx.Alloc( - &d_token_tensor, d_token_tensor.numel() * sizeof(int)); - // alloc the max size of padding_offset_tensor - padding_offset_tensor.Resize({{bsz_seq}}); - dev_ctx.Alloc(&padding_offset_tensor, - padding_offset_tensor.numel() * sizeof(int)); - InvokeGetPaddingOffset(dev_ctx, - &token_num, - d_token_num, - padding_offset_tensor.data(), - sequence_lengths->data(), - bsz, - seq_len); - padding_offset_tensor.Resize({{token_num}}); - x_remove_padding.Resize({{token_num, dim_embed}}); - dev_ctx.Alloc(&x_remove_padding, x_remove_padding.numel() * sizeof(T)); - InvokeRemovePadding(dev_ctx, - x_remove_padding.data(), - input_x->data(), - padding_offset_tensor.data(), - token_num, - dim_embed); - } else { - token_num = bsz_seq; - } - auto *padding_offset_data = - encoder_remove_padding ? padding_offset_tensor.data() : nullptr; - - // 1. layer norm - const auto pre_layer_norm = ctx.Attr("pre_layer_norm"); - const float epsilon = ctx.Attr("epsilon"); - auto ln_scales = ctx.MultiInput("LnScale"); - auto ln_biases = ctx.MultiInput("LnBias"); - - auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); - phi::DenseTensor ln_mean, ln_var; - ln_mean.Resize({{token_num}}); - auto *ln_mean_data = - dev_ctx.Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); - ln_var.Resize({{token_num}}); - auto *ln_var_data = dev_ctx.Alloc(&ln_var, ln_var.numel() * sizeof(U)); - - // 2. qkv - // x: qkv's input [batch_size, seq_len, dim_embed] - // y: qkv's weight: [3, num_head, dim_head, dim_embed] - auto qkv_weights = ctx.MultiInput("QKVW"); - auto qkv_biases = ctx.MultiInput("QKVBias"); - const bool trans_qkvw = ctx.Attr("trans_qkvw"); - const auto qkv_w_dims = qkv_weights[0]->dims(); - int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; - int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; - int hidden_size = num_head * dim_head; - int output_size = 3 * hidden_size; - int input_size = dim_embed; - - bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr; - // (transA, transB, compute_bias) = (false, trans_qkvw, false) - // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we - // set compute_bias as false. 
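// NOTE(illustrative sketch): the head bookkeeping above depends on the QKV
// weight layout selected by trans_qkvw. Assuming the weight is
// [3, num_head, dim_head, dim_embed] when trans_qkvw is true and
// [dim_embed, 3, num_head, dim_head] otherwise, a standalone helper with
// hypothetical names (not taken from this patch) would read:
struct QkvDims {
  int num_head, dim_head, hidden_size, output_size;
};
inline QkvDims InferQkvDims(const int w_dims[4], bool trans_qkvw) {
  QkvDims d;
  d.num_head = trans_qkvw ? w_dims[1] : w_dims[2];
  d.dim_head = trans_qkvw ? w_dims[2] : w_dims[3];
  d.hidden_size = d.num_head * d.dim_head;  // per-token attention width
  d.output_size = 3 * d.hidden_size;        // fused Q, K and V projections
  return d;
}
// compute_bias stays false here because the QKV bias is applied later inside
// the qkv_bias_add_transpose_split kernel rather than in the matmul itself.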
- auto qkv_compute = phi::fusion::AttnMatMul(dev_ctx, - false, - trans_qkvw, - token_num, - output_size, - input_size, - /*compute_bias=*/false); - - phi::DenseTensor qkv_out; - qkv_out.Resize({{token_num, 3, num_head, dim_head}}); - auto *qkv_out_data = - dev_ctx.Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); - - // 2.1 rotary - auto *rotary_tensor = ctx.Input("RotaryPosEmb"); - const int rotary_emb_dims = ctx.Attr("rotary_emb_dims"); - - // 3. fmha - AttnDropoutParam attn_param( - true, "upscale_in_train", 0.0, true, true, 0, nullptr); - auto fmha_compute = - FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); - auto *src_mask = ctx.Input("SrcMask"); - auto cache_kvs = ctx.MultiInput("CacheKV"); - auto cache_kv_outs = ctx.MultiOutput("CacheKVOut"); - // auto *time_step = ctx.Input("TimeStep"); - auto pre_caches = ctx.MultiInput("PreCaches"); - int cache_offset = 0; - if (pre_caches.size() > 0) { - cache_offset = pre_caches[0]->dims()[3]; +template +void FusedMultiTransformerKernel( + const Context &dev_ctx, + const DenseTensor &x, + const std::vector &ln_scales, + const std::vector &ln_biases, + const std::vector &qkv_weights, + const paddle::optional> &qkv_biases, + const paddle::optional> &cache_kvs, + const paddle::optional> &pre_caches, + const paddle::optional &rotary_tensor, + const paddle::optional &time_step, + const paddle::optional &seq_lengths, + const paddle::optional &src_mask, + const std::vector &out_linear_weights, + const paddle::optional> &out_linear_biases, + const std::vector &ffn_ln_scales, + const std::vector &ffn_ln_biases, + const std::vector &ffn1_weights, + const paddle::optional> &ffn1_biases, + const std::vector &ffn2_weights, + const paddle::optional> &ffn2_biases, + bool pre_layer_norm, + float epsilon, + float dropout_rate, + int rotary_emb_dims, + bool is_test, + const std::string &dropout_implementation, + const std::string &act_method, + bool trans_qkvw, + int ring_id, + std::vector cache_kv_outs, + DenseTensor *out) { + if (cache_kvs) { + for (size_t i = 0; i < cache_kv_outs.size(); i++) { + *(cache_kv_outs[i]) = *(cache_kvs.get()[i]); } + } + using U = phi::funcs::LayerNormParamType; + auto *rotary_tensor_t = rotary_tensor.get_ptr(); + auto *seq_lengths_t = seq_lengths.get_ptr(); + auto *src_mask_t = src_mask.get_ptr(); + auto *time_step_t = time_step.get_ptr(); + + // 0. 
input + const auto input_x_dims = x.dims(); + int bsz = input_x_dims[0]; + int seq_len = input_x_dims[1]; + int dim_embed = input_x_dims[2]; + int bsz_seq = bsz * seq_len; + bool remove_padding = false; + if (seq_lengths_t) { + remove_padding = true; + } + phi::DenseTensor d_token_tensor; + phi::DenseTensor padding_offset_tensor; + phi::DenseTensor x_remove_padding; + bool encoder_remove_padding = (remove_padding && !time_step_t); + int token_num = 0; + + // remove padding in encoder + if (encoder_remove_padding) { + // just for encoder + d_token_tensor.Resize({1}); + auto *d_token_num = dev_ctx.template Alloc( + &d_token_tensor, d_token_tensor.numel() * sizeof(int)); + // alloc the max size of padding_offset_tensor + padding_offset_tensor.Resize({bsz_seq}); + dev_ctx.template Alloc(&padding_offset_tensor, + padding_offset_tensor.numel() * sizeof(int)); + InvokeGetPaddingOffset(dev_ctx, + &token_num, + d_token_num, + padding_offset_tensor.data(), + seq_lengths_t->data(), + bsz, + seq_len); + padding_offset_tensor.Resize({token_num}); + x_remove_padding.Resize({token_num, dim_embed}); + dev_ctx.template Alloc(&x_remove_padding, + x_remove_padding.numel() * sizeof(T)); + InvokeRemovePadding(dev_ctx, + x_remove_padding.data(), + x.data(), + padding_offset_tensor.data(), + token_num, + dim_embed); + } else { + token_num = bsz_seq; + } + auto *padding_offset_data = + encoder_remove_padding ? padding_offset_tensor.data() : nullptr; + + // 1. layer norm + + auto ln_compute = AttnLayerNorm(dev_ctx, epsilon, token_num, dim_embed); + phi::DenseTensor ln_mean, ln_var; + ln_mean.Resize({token_num}); + auto *ln_mean_data = + dev_ctx.template Alloc(&ln_mean, ln_mean.numel() * sizeof(U)); + ln_var.Resize({token_num}); + auto *ln_var_data = + dev_ctx.template Alloc(&ln_var, ln_var.numel() * sizeof(U)); + + // 2. qkv + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + const auto qkv_w_dims = qkv_weights[0]->dims(); + int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2]; + int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3]; + int hidden_size = num_head * dim_head; + int output_size = 3 * hidden_size; + int input_size = dim_embed; + + bool compute_bias = + qkv_biases && !qkv_biases.get().empty() && time_step_t == nullptr; + // (transA, transB, compute_bias) = (false, trans_qkvw, false) + // Since we fused QKVBias into QKVBiasAddTransposeSplit kernel, here we + // set compute_bias as false. + auto qkv_compute = phi::fusion::AttnMatMul(dev_ctx, + false, + trans_qkvw, + token_num, + output_size, + input_size, + /*compute_bias=*/false); + + phi::DenseTensor qkv_out; + qkv_out.Resize({token_num, 3, num_head, dim_head}); + auto *qkv_out_data = + dev_ctx.template Alloc(&qkv_out, qkv_out.numel() * sizeof(T)); + + // 3. 
fmha + AttnDropoutParam attn_param( + true, "upscale_in_train", 0.0, true, true, 0, nullptr); + auto fmha_compute = + FMHARef(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param); + int cache_offset = 0; + if (pre_caches && pre_caches.get().size() > 0) { + cache_offset = pre_caches.get()[0]->dims()[3]; + } - auto out_seq_len = seq_len; - if (time_step) { - PADDLE_ENFORCE_EQ(time_step->place(), - platform::CPUPlace(), - platform::errors::PreconditionNotMet( - "The place of input(TimeStep) must be CPUPlace.")); - // cache_seq_len - int time_step_value = time_step->data()[0]; - PADDLE_ENFORCE_GT(time_step_value, - 0, - platform::errors::PreconditionNotMet( - "The value of time_step must > 0, but now is %d", - time_step_value)); - PADDLE_ENFORCE_EQ( - seq_len, - 1, - platform::errors::PreconditionNotMet( - "In decode stage, the seq_len of input must be 1, but now is %d", - seq_len)); - out_seq_len += time_step_value; - } else { - out_seq_len += cache_offset; - } + auto out_seq_len = seq_len; + if (time_step_t) { + PADDLE_ENFORCE_EQ(time_step_t->place(), + phi::CPUPlace(), + phi::errors::PreconditionNotMet( + "The place of input(TimeStep) must be CPUPlace.")); + // cache_seq_len + int time_step_value = time_step_t->data()[0]; + PADDLE_ENFORCE_GT(time_step_value, + 0, + phi::errors::PreconditionNotMet( + "The value of time_step_t must > 0, but now is %d", + time_step_value)); + PADDLE_ENFORCE_EQ( + seq_len, + 1, + phi::errors::PreconditionNotMet( + "In decode stage, the seq_len of input must be 1, but now is %d", + seq_len)); + out_seq_len += time_step_value; + } else { + out_seq_len += cache_offset; + } - phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; - q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}}); - auto *q_transpose_out_data = - dev_ctx.Alloc(&q_transpose_out, q_transpose_out.numel() * sizeof(T)); + phi::DenseTensor q_transpose_out, kv_transpose_out, qk_out; + q_transpose_out.Resize({bsz, num_head, seq_len, dim_head}); + auto *q_transpose_out_data = dev_ctx.template Alloc( + &q_transpose_out, q_transpose_out.numel() * sizeof(T)); - kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}}); - auto *kv_transpose_out_data = dev_ctx.Alloc( - &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); + kv_transpose_out.Resize({2, bsz, num_head, seq_len, dim_head}); + auto *kv_transpose_out_data = dev_ctx.template Alloc( + &kv_transpose_out, kv_transpose_out.numel() * sizeof(T)); - qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *qk_out_data = dev_ctx.Alloc(&qk_out, qk_out.numel() * sizeof(T)); + qk_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *qk_out_data = + dev_ctx.template Alloc(&qk_out, qk_out.numel() * sizeof(T)); - phi::DenseTensor src_mask_out; - if (cache_offset > 0) { - src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *src_mask_out_data = - dev_ctx.Alloc(&src_mask_out, src_mask_out.numel() * sizeof(T)); - } + phi::DenseTensor src_mask_out; + if (cache_offset > 0) { + src_mask_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *src_mask_out_data = dev_ctx.template Alloc( + &src_mask_out, src_mask_out.numel() * sizeof(T)); + } - // [2, bs, num_head, cache_seq_len + seq_len, head_dim] - phi::DenseTensor pre_cache_kv_out; - if (cache_offset > 0) { - pre_cache_kv_out.Resize( - {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); - auto *pre_cache_kv_out_data = dev_ctx.Alloc( - &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); - } + // [2, bs, num_head, cache_seq_len + seq_len, head_dim] + 
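// NOTE(illustrative sketch): the buffers above are sized from out_seq_len,
// which folds the KV history into the attended length, and the pre-cache KV
// workspace follows the shape commented above. A standalone restatement of
// both rules with hypothetical helper names (assumptions, not patch code):
inline int ComputeOutSeqLen(int seq_len, bool has_time_step,
                            int time_step_value, int cache_offset) {
  // decoder stage: this step attends over all previously cached tokens;
  // context stage: it attends over the optional pre-cache prefix as well.
  return has_time_step ? seq_len + time_step_value : seq_len + cache_offset;
}
inline int64_t PreCacheKvNumel(int bsz, int num_head, int seq_len,
                               int cache_offset, int dim_head) {
  // 2 (K and V) x batch x heads x (prefix + current tokens) x head_dim
  return 2LL * bsz * num_head * (seq_len + cache_offset) * dim_head;
}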
phi::DenseTensor pre_cache_kv_out; + if (cache_offset > 0) { + pre_cache_kv_out.Resize( + {{2, bsz, num_head, seq_len + cache_offset, dim_head}}); + auto *pre_cache_kv_out_data = dev_ctx.template Alloc( + &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T)); + } - phi::DenseTensor softmax_out; - phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; - phi::DenseTensor qktv_out, fmha_out; - softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *softmax_out_data = - dev_ctx.Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); - - attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_mask_out_data = dev_ctx.Alloc( - &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); - attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}}); - auto *attn_dropout_data_data = dev_ctx.Alloc( - &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); - - qktv_out.Resize({{bsz, num_head, seq_len, dim_head}}); - auto *qktv_out_data = - dev_ctx.Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); - fmha_out.Resize({{bsz, seq_len, num_head, dim_head}}); - auto *fmha_out_data = - dev_ctx.Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); - - // 4. out_linear - auto out_linear_weights = ctx.MultiInput("OutLinearW"); - auto out_linear_biases = ctx.MultiInput("OutLinearBias"); - int ring_id = ctx.Attr("ring_id"); - // (transA, transB, compute_bias) = (false, false, false) - auto out_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, hidden_size, false); - - // 5. ln(residual + bias) - DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( - dev_ctx, token_num, dim_embed, dropout_param2, epsilon); - auto ffn_ln_scales = ctx.MultiInput("FFNLnScale"); - auto ffn_ln_biases = ctx.MultiInput("FFNLnBias"); - phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; - T *bias_dropout_residual_out_data = nullptr; + phi::DenseTensor softmax_out; + phi::DenseTensor attn_dropout_mask_out, attn_dropout_out; + phi::DenseTensor qktv_out, fmha_out; + softmax_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *softmax_out_data = + dev_ctx.template Alloc(&softmax_out, softmax_out.numel() * sizeof(T)); + + attn_dropout_mask_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *attn_dropout_mask_out_data = dev_ctx.template Alloc( + &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T)); + attn_dropout_out.Resize({bsz, num_head, seq_len, out_seq_len}); + auto *attn_dropout_data_data = dev_ctx.template Alloc( + &attn_dropout_out, attn_dropout_out.numel() * sizeof(T)); + + qktv_out.Resize({bsz, num_head, seq_len, dim_head}); + auto *qktv_out_data = + dev_ctx.template Alloc(&qktv_out, qktv_out.numel() * sizeof(T)); + fmha_out.Resize({bsz, seq_len, num_head, dim_head}); + auto *fmha_out_data = + dev_ctx.template Alloc(&fmha_out, fmha_out.numel() * sizeof(T)); + + // 4. out_linear + // (transA, transB, compute_bias) = (false, false, false) + auto out_linear_compute = phi::fusion::AttnMatMul( + dev_ctx, false, false, token_num, dim_embed, hidden_size, false); + + // 5. 
ln(residual + bias) + DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper fused_dropout_layernorm_helper( + dev_ctx, token_num, dim_embed, dropout_param2, epsilon); + phi::DenseTensor bias_dropout_residual_out, dropout_mask_out; + T *bias_dropout_residual_out_data = nullptr; + if (pre_layer_norm) { + bias_dropout_residual_out.Resize({token_num, dim_embed}); + bias_dropout_residual_out_data = dev_ctx.template Alloc( + &bias_dropout_residual_out, + bias_dropout_residual_out.numel() * sizeof(T)); + } + dropout_mask_out.Resize({token_num, dim_embed}); + auto *dropout_mask_out_data = dev_ctx.template Alloc( + &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); + + // 6. ffn matmul1 + auto ffn1_weight_dim = ffn1_weights[0]->dims(); + + int dim_ffn = ffn1_weight_dim[1]; + auto ffn1_linear_compute = phi::fusion::AttnMatMul( + dev_ctx, false, false, token_num, dim_ffn, dim_embed, false); + phi::DenseTensor ffn1_out; + ffn1_out.Resize({token_num, dim_ffn}); + auto *ffn1_out_data = + dev_ctx.template Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); + + // 7. ffn act + bias + DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutHelper fused_act_dropout_helper( + dev_ctx, token_num, dim_ffn, ffn1_dropout_param); + phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; + ffn1_dropout_out.Resize({token_num, dim_ffn}); + auto *ffn1_dropout_out_data = dev_ctx.template Alloc( + &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); + ffn1_dropout_mask.Resize({token_num, dim_ffn}); + auto *ffn1_dropout_mask_data = dev_ctx.template Alloc( + &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); + + // 8. ffn2 matmul + auto ffn2_linear_compute = phi::fusion::AttnMatMul( + dev_ctx, false, false, token_num, dim_embed, dim_ffn, false); + + // 9. ffn2 residual bias + DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); + FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( + dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); + + // calc + auto *from_data = dev_ctx.template Alloc(out, out->numel() * sizeof(T)); + phi::DenseTensor *from_tensor = out; + phi::DenseTensor tmp_out, tmp_out_rm_padding; + tmp_out.Resize({token_num, dim_embed}); + if (encoder_remove_padding) { + tmp_out_rm_padding.Resize({token_num, dim_embed}); + auto *tmp_out_rm_padding_data = dev_ctx.template Alloc( + &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); + } + auto *tmp_out_data = + dev_ctx.template Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); + + const T *x_data; + if (encoder_remove_padding) { + x_data = x_remove_padding.data(); + } else { + x_data = x.data(); + } + phi::DenseTensor *buf0 = nullptr; + phi::DenseTensor *buf1 = nullptr; + + // step0: x --> buf1 + // step1: buf1 --> buf0 + // step2: buf0 --> buf1 + int layers = qkv_weights.size(); + if (encoder_remove_padding) { + // In the case of variable lengths, the padding needs to be rebuilt + // eventually. So buf0 and buf1 do not need to be changed according to the + // pre_layer_norm and the number of layers. 
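// NOTE(illustrative sketch): in the fixed-length pre_layer_norm path below,
// each layer ends with std::swap(buf0, buf1), so after `layers` swaps the
// final activations sit in the tensor initially bound to buf1 when `layers`
// is odd and to buf0 when it is even; the odd/even branch pre-binds `out`
// accordingly so no trailing copy is needed. A standalone restatement of that
// selection, using a hypothetical helper name:
inline void BindPingPongBuffers(int layers, phi::DenseTensor *out,
                                phi::DenseTensor *tmp,
                                phi::DenseTensor **buf0,
                                phi::DenseTensor **buf1) {
  if (layers & 1) {
    *buf0 = tmp;  // odd layer count: the result ends in the initial buf1
    *buf1 = out;
  } else {
    *buf0 = out;  // even layer count: the result ends in the initial buf0
    *buf1 = tmp;
  }
}
// The variable-length case skips this, since InvokeRebuildPadding copies the
// un-padded result back into `out` at the end anyway.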
+ buf0 = &tmp_out; + buf1 = &tmp_out_rm_padding; + } else { if (pre_layer_norm) { - bias_dropout_residual_out.Resize({{token_num, dim_embed}}); - bias_dropout_residual_out_data = - dev_ctx.Alloc(&bias_dropout_residual_out, - bias_dropout_residual_out.numel() * sizeof(T)); - } - dropout_mask_out.Resize({{token_num, dim_embed}}); - auto *dropout_mask_out_data = dev_ctx.Alloc( - &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t)); - - // 6. ffn matmul1 - auto ffn1_weights = ctx.MultiInput("FFN1Weight"); - auto ffn1_biases = ctx.MultiInput("FFN1Bias"); - auto ffn1_weight_dim = ffn1_weights[0]->dims(); - - int dim_ffn = ffn1_weight_dim[1]; - auto ffn1_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_ffn, dim_embed, false); - phi::DenseTensor ffn1_out; - ffn1_out.Resize({{token_num, dim_ffn}}); - auto *ffn1_out_data = - dev_ctx.Alloc(&ffn1_out, ffn1_out.numel() * sizeof(T)); - - // 7. ffn act + bias - DropoutParam ffn1_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutHelper fused_act_dropout_helper( - dev_ctx, token_num, dim_ffn, ffn1_dropout_param); - phi::DenseTensor ffn1_dropout_out, ffn1_dropout_mask; - ffn1_dropout_out.Resize({{token_num, dim_ffn}}); - auto *ffn1_dropout_out_data = dev_ctx.Alloc( - &ffn1_dropout_out, ffn1_dropout_out.numel() * sizeof(T)); - ffn1_dropout_mask.Resize({{token_num, dim_ffn}}); - auto *ffn1_dropout_mask_data = dev_ctx.Alloc( - &ffn1_dropout_mask, ffn1_dropout_mask.numel() * sizeof(uint8_t)); - - // 8. ffn2 matmul - auto ffn2_weights = ctx.MultiInput("FFN2Weight"); - auto ffn2_biases = ctx.MultiInput("FFN2Bias"); - auto ffn2_linear_compute = phi::fusion::AttnMatMul( - dev_ctx, false, false, token_num, dim_embed, dim_ffn, false); - - // 9. ffn2 residual bias - DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0); - FusedDropoutLayerNormHelper ffn2_fused_dropout_helper( - dev_ctx, token_num, dim_embed, ffn2_dropout_param, epsilon); - - // calc - auto *out = ctx.Output("Out"); - auto *from_data = dev_ctx.Alloc(out, out->numel() * sizeof(T)); - phi::DenseTensor *from_tensor = out; - phi::DenseTensor tmp_out, tmp_out_rm_padding; - tmp_out.Resize({{token_num, dim_embed}}); - if (encoder_remove_padding) { - tmp_out_rm_padding.Resize({{token_num, dim_embed}}); - auto *tmp_out_rm_padding_data = dev_ctx.Alloc( - &tmp_out_rm_padding, tmp_out_rm_padding.numel() * sizeof(T)); - } - auto *tmp_out_data = - dev_ctx.Alloc(&tmp_out, tmp_out.numel() * sizeof(T)); - - const T *x_data; - if (encoder_remove_padding) { - x_data = x_remove_padding.data(); - } else { - x_data = input_x->data(); - } - phi::DenseTensor *buf0 = nullptr; - phi::DenseTensor *buf1 = nullptr; - - // step0: x --> buf1 - // step1: buf1 --> buf0 - // step2: buf0 --> buf1 - int layers = qkv_weights.size(); - if (encoder_remove_padding) { - // In the case of variable lengths, the padding needs to be rebuilt - // eventually. So buf0 and buf1 do not need to be changed according to the - // pre_layer_norm and the number of layers. - buf0 = &tmp_out; - buf1 = &tmp_out_rm_padding; - } else { - if (pre_layer_norm) { - if (layers & 1) { - // odd, set buf1 as out - buf0 = &tmp_out; - buf1 = out; - } else { - // even, set buf0 as out - buf0 = out; - buf1 = &tmp_out; - } - } else { + if (layers & 1) { + // odd, set buf1 as out buf0 = &tmp_out; buf1 = out; + } else { + // even, set buf0 as out + buf0 = out; + buf1 = &tmp_out; } + } else { + buf0 = &tmp_out; + buf1 = out; } + } - for (int i = 0; i < layers; ++i) { - // step1. 
layer_norm - if (i == 0 && pre_layer_norm) { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - // TODO(wangxi): can remove mean var in inference - ln_compute.ComputeForward(x_data, - ln_scale_data, - ln_bias_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } + for (int i = 0; i < layers; ++i) { + // step1. layer_norm + if (i == 0 && pre_layer_norm) { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + // TODO(wangxi): can remove mean var in inference + ln_compute.ComputeForward(x_data, + ln_scale_data, + ln_bias_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step1"; + VLOG(0) << "step1"; #endif - // step2. qkv - const phi::DenseTensor *qkv_bias = - qkv_biases.size() > 0 ? qkv_biases[i] : nullptr; - // NOTE: in decoder stage, bias is fused in fmha - const phi::DenseTensor *bias = time_step ? nullptr : qkv_bias; - if (!pre_layer_norm && i == 0) { - const phi::DenseTensor *tmp_input_x = - (encoder_remove_padding) ? &x_remove_padding : input_x; - qkv_compute.ComputeForward( - qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); - } else { - qkv_compute.ComputeForward( - qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); - } + // step2. qkv + const phi::DenseTensor *qkv_bias = + qkv_biases && !qkv_biases.get().empty() ? qkv_biases.get()[i] : nullptr; + // NOTE: in decoder stage, bias is fused in fmha + const phi::DenseTensor *bias = time_step_t ? nullptr : qkv_bias; + if (!pre_layer_norm && i == 0) { + const phi::DenseTensor *tmp_input_x = + (encoder_remove_padding) ? &x_remove_padding : &x; + qkv_compute.ComputeForward( + qkv_weights[i], tmp_input_x, bias, &qkv_out, &qkv_out); + } else { + qkv_compute.ComputeForward( + qkv_weights[i], buf1, bias, &qkv_out, &qkv_out); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step2"; + VLOG(0) << "step2"; #endif - // step3. fmha - const phi::DenseTensor *cache_kv = - cache_kvs.size() > 0 ? cache_kvs[i] : nullptr; - phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; - - if (time_step) { // generation decoder stage - // [2, batch_size, num_head, max_seq_len, head_size] - int max_seq_len = cache_kv->dims()[3]; - fmha(dev_ctx, - qkv_out, - *qkv_bias, - *src_mask, - sequence_lengths, - rotary_tensor, - cache_kv_out, - &fmha_out, - bsz, - max_seq_len, - num_head, - dim_head, - time_step->data()[0], - rotary_emb_dims, - 1. / std::sqrt(dim_head)); - } else if (cache_kv_out) { // generation context stage - const phi::DenseTensor *pre_cache_kv_tensor = - pre_caches.size() > 0 ? pre_caches[i] : nullptr; - phi::DenseTensor *pre_cache_kv_out_tmp = - cache_offset > 0 ? &pre_cache_kv_out : nullptr; - phi::DenseTensor *src_mask_tmp = - cache_offset > 0 ? &src_mask_out : nullptr; - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? 
sequence_lengths->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); - } - - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? &padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, - src_mask, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - pre_cache_kv_out_tmp, - &qk_out, - src_mask_tmp, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); - const T *k_ptr = nullptr; - const T *v_ptr = nullptr; - - if (cache_offset > 0) { - // [2, bsz, num_head, cache_offset + seq_len, head_dim] - const T *kv_data = pre_cache_kv_out.data(); - k_ptr = kv_data; - int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; - v_ptr = k_ptr + k_size; - } else { - // [3, bsz, num_head, seq_len, head_dim] - int64_t k_size = bsz * seq_len * num_head * dim_head; - const T *q_ptr = q_transpose_out_data; - k_ptr = kv_transpose_out_data; - v_ptr = k_ptr + k_size; - } - - // [2, bsz, num_head, max_seq_len, head_dim] - int max_seq_len = cache_kv_out->dims()[3]; - T *cache_kv_data = cache_kv_out->data(); - int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; - - T *cache_k_ptr = cache_kv_data; - T *cache_v_ptr = cache_kv_data + cache_k_size; - - const int seq_len_tmp = seq_len + cache_offset; - write_cache_kv(dev_ctx, - cache_k_ptr, - cache_v_ptr, - k_ptr, - v_ptr, - bsz, - num_head, - seq_len_tmp, - max_seq_len, - dim_head); - } else { // not generation - // TODO(wangxi): can remove dropout in inference - qkv_bias_add_transpose_split(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - qkv_out_data, - qkv_bias->data(), - padding_offset_data, - token_num, - bsz, - num_head, - seq_len, - dim_head, - compute_bias); - - // q_transpose_out_data [bs, head_num, seq_len, dim_head] - // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] - if (rotary_emb_dims != 0) { - auto *rotary_emb_data = rotary_tensor->data(); - const int *sequence_lengths_data = - encoder_remove_padding ? sequence_lengths->data() : nullptr; - rotary_qk(dev_ctx, - q_transpose_out_data, - kv_transpose_out_data, - q_transpose_out_data, - kv_transpose_out_data, - rotary_emb_data, - sequence_lengths_data, - rotary_emb_dims, - bsz, - num_head, - seq_len, - dim_head); - } - - phi::DenseTensor *tmp_padding_offset_tensor = - encoder_remove_padding ? &padding_offset_tensor : nullptr; - fmha_compute.ComputeForwardWithoutTranspose(cache_kv, - src_mask, - tmp_padding_offset_tensor, - &q_transpose_out, - &kv_transpose_out, - cache_kv_out, - &qk_out, - nullptr, - &softmax_out, - &attn_dropout_mask_out, - &attn_dropout_out, - &qktv_out, - &fmha_out, - token_num); + // step3. fmha + const phi::DenseTensor *cache_kv = + cache_kvs && cache_kvs.get().size() > 0 ? cache_kvs.get()[i] : nullptr; + phi::DenseTensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr; + + if (time_step_t) { // generation decoder stage + // [2, batch_size, num_head, max_seq_len, head_size] + int max_seq_len = cache_kv->dims()[3]; + fmha(dev_ctx, + qkv_out, + *qkv_bias, + *src_mask_t, + seq_lengths_t, + rotary_tensor_t, + cache_kv_out, + &fmha_out, + bsz, + max_seq_len, + num_head, + dim_head, + time_step_t->data()[0], + rotary_emb_dims, + 1. 
/ std::sqrt(dim_head)); + } else if (cache_kv_out) { // generation context stage + const phi::DenseTensor *pre_cache_kv_tensor = + pre_caches && pre_caches.get().size() > 0 ? pre_caches.get()[i] + : nullptr; + phi::DenseTensor *pre_cache_kv_out_tmp = + cache_offset > 0 ? &pre_cache_kv_out : nullptr; + phi::DenseTensor *src_mask_tmp = + cache_offset > 0 ? &src_mask_out : nullptr; + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor_t->data(); + const int *sequence_lengths_data = + encoder_remove_padding ? seq_lengths_t->data() : nullptr; + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); } -#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step3"; -#endif - if (pre_layer_norm) { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(pre_cache_kv_tensor, + src_mask_t, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + pre_cache_kv_out_tmp, + &qk_out, + src_mask_tmp, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + const T *k_ptr = nullptr; + const T *v_ptr = nullptr; + if (cache_offset > 0) { + // [2, bsz, num_head, cache_offset + seq_len, head_dim] + const T *kv_data = pre_cache_kv_out.data(); + k_ptr = kv_data; + int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head; + v_ptr = k_ptr + k_size; } else { - out_linear_compute.ComputeForward( - out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + // [3, bsz, num_head, seq_len, head_dim] + int64_t k_size = bsz * seq_len * num_head * dim_head; + const T *q_ptr = q_transpose_out_data; + k_ptr = kv_transpose_out_data; + v_ptr = k_ptr + k_size; } + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv_out->dims()[3]; + T *cache_kv_data = cache_kv_out->data(); + int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head; + + T *cache_k_ptr = cache_kv_data; + T *cache_v_ptr = cache_kv_data + cache_k_size; + const int seq_len_tmp = seq_len + cache_offset; + write_cache_kv(dev_ctx, + cache_k_ptr, + cache_v_ptr, + k_ptr, + v_ptr, + bsz, + num_head, + seq_len_tmp, + max_seq_len, + dim_head); + } else { // not generation + // TODO(wangxi): can remove dropout in inference + qkv_bias_add_transpose_split(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + qkv_out_data, + qkv_bias->data(), + padding_offset_data, + token_num, + bsz, + num_head, + seq_len, + dim_head, + compute_bias); + + // q_transpose_out_data [bs, head_num, seq_len, dim_head] + // kv_transpose_out_data [2, bs, head_num, seq_len, dim_head] + if (rotary_emb_dims != 0) { + auto *rotary_emb_data = rotary_tensor_t->data(); + const int *sequence_lengths_data = + encoder_remove_padding ? 
seq_lengths_t->data() : nullptr; + rotary_qk(dev_ctx, + q_transpose_out_data, + kv_transpose_out_data, + q_transpose_out_data, + kv_transpose_out_data, + rotary_emb_data, + sequence_lengths_data, + rotary_emb_dims, + bsz, + num_head, + seq_len, + dim_head); + } + + phi::DenseTensor *tmp_padding_offset_tensor = + encoder_remove_padding ? &padding_offset_tensor : nullptr; + fmha_compute.ComputeForwardWithoutTranspose(cache_kv, + src_mask_t, + tmp_padding_offset_tensor, + &q_transpose_out, + &kv_transpose_out, + cache_kv_out, + &qk_out, + nullptr, + &softmax_out, + &attn_dropout_mask_out, + &attn_dropout_out, + &qktv_out, + &fmha_out, + token_num); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step4"; + VLOG(0) << "step3"; #endif - // step5. ln(residual + dropout(input + bias)) - if (pre_layer_norm) { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); + if (pre_layer_norm) { + out_linear_compute.ComputeForward( + out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr); + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + } else { + out_linear_compute.ComputeForward( + out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr); + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } +#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER + VLOG(0) << "step4"; +#endif - // inplace - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - x_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - bias_dropout_residual_out_data, - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } else { - auto *ln_scale_data = ln_scales[i]->data(); - auto *ln_bias_data = ln_biases[i]->data(); - auto *out_linear_bias_data = out_linear_biases[i]->data(); - auto *residual_data = (i == 0 ? x_data : buf1->data()); - fused_dropout_layernorm_helper.LayernormResidualDropoutBias( - dev_ctx, - buf0->data(), - residual_data, - out_linear_bias_data, - ln_scale_data, - ln_bias_data, - buf0->data(), - dropout_mask_out_data, - buf1->data(), - ln_mean_data, - ln_var_data); - } + // step5. ln(residual + dropout(input + bias)) + if (pre_layer_norm) { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases.get()[i]->data(); + + // inplace + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + buf1->data(), + x_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + bias_dropout_residual_out_data, + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } else { + auto *ln_scale_data = ln_scales[i]->data(); + auto *ln_bias_data = ln_biases[i]->data(); + auto *out_linear_bias_data = out_linear_biases.get()[i]->data(); + auto *residual_data = (i == 0 ? x_data : buf1->data()); + fused_dropout_layernorm_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0->data(), + residual_data, + out_linear_bias_data, + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step5"; + VLOG(0) << "step5"; #endif - // step6. ffn matmul1 - ffn1_linear_compute.ComputeForward( - ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); + // step6. 
ffn matmul1 + ffn1_linear_compute.ComputeForward( + ffn1_weights[i], buf1, nullptr, &ffn1_out, nullptr); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step6"; + VLOG(0) << "step6"; #endif - // step7. act bias - // TODO(wangxi): remove dropout mask in inference - fused_act_dropout_helper.DropoutActBias(dev_ctx, - ffn1_out_data, - ffn1_biases[i]->data(), - act_method, - ffn1_dropout_out_data, - ffn1_dropout_mask_data); + // step7. act bias + // TODO(wangxi): remove dropout mask in inference + fused_act_dropout_helper.DropoutActBias(dev_ctx, + ffn1_out_data, + ffn1_biases.get()[i]->data(), + act_method, + ffn1_dropout_out_data, + ffn1_dropout_mask_data); #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step7"; + VLOG(0) << "step7"; #endif - // step8. ffn matmul2 - if (pre_layer_norm) { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr); - } else { - ffn2_linear_compute.ComputeForward( - ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr); - } + // step8. ffn matmul2 + if (pre_layer_norm) { + ffn2_linear_compute.ComputeForward( + ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr); + } else { + ffn2_linear_compute.ComputeForward( + ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.0"; + VLOG(0) << "step8.0"; #endif - if (pre_layer_norm) { - AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); - } else { - AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); - } + if (pre_layer_norm) { + AllReduce(*buf1, ring_id, buf1->numel(), dev_ctx); + } else { + AllReduce(*buf0, ring_id, buf0->numel(), dev_ctx); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step8.1"; + VLOG(0) << "step8.1"; #endif - // step9. residual bias - if (pre_layer_norm) { - // TODO(wangxi): remove dropout mask in inference - if (i < layers - 1) { - auto *ln_scale_data = ln_scales[i + 1]->data(); - auto *ln_bias_data = ln_biases[i + 1]->data(); - ffn2_fused_dropout_helper.LayernormResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - ln_scale_data, - ln_bias_data, - buf1->data(), - dropout_mask_out_data, - buf0->data(), - ln_mean_data, - ln_var_data); - } else { - ffn2_fused_dropout_helper.ResidualDropoutBias( - dev_ctx, - buf1->data(), - bias_dropout_residual_out_data, - ffn2_biases[i]->data(), - buf1->data(), - dropout_mask_out_data); - } - } else { - auto *ln_scale_data = ffn_ln_scales[i]->data(); - auto *ln_bias_data = ffn_ln_biases[i]->data(); + // step9. 
residual bias + if (pre_layer_norm) { + // TODO(wangxi): remove dropout mask in inference + if (i < layers - 1) { + auto *ln_scale_data = ln_scales[i + 1]->data(); + auto *ln_bias_data = ln_biases[i + 1]->data(); ffn2_fused_dropout_helper.LayernormResidualDropoutBias( dev_ctx, - buf0->data(), buf1->data(), - ffn2_biases[i]->data(), + bias_dropout_residual_out_data, + ffn2_biases.get()[i]->data(), ln_scale_data, ln_bias_data, - buf0->data(), - dropout_mask_out_data, buf1->data(), + dropout_mask_out_data, + buf0->data(), ln_mean_data, ln_var_data); + } else { + ffn2_fused_dropout_helper.ResidualDropoutBias( + dev_ctx, + buf1->data(), + bias_dropout_residual_out_data, + ffn2_biases.get()[i]->data(), + buf1->data(), + dropout_mask_out_data); } + } else { + auto *ln_scale_data = ffn_ln_scales[i]->data(); + auto *ln_bias_data = ffn_ln_biases[i]->data(); + ffn2_fused_dropout_helper.LayernormResidualDropoutBias( + dev_ctx, + buf0->data(), + buf1->data(), + ffn2_biases.get()[i]->data(), + ln_scale_data, + ln_bias_data, + buf0->data(), + dropout_mask_out_data, + buf1->data(), + ln_mean_data, + ln_var_data); + } #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER - VLOG(0) << "step9"; + VLOG(0) << "step9"; #endif - if (pre_layer_norm) { - x_data = buf1->data(); - std::swap(buf0, buf1); - } + if (pre_layer_norm) { + x_data = buf1->data(); + std::swap(buf0, buf1); } - if (encoder_remove_padding) { - if (pre_layer_norm) { - InvokeRebuildPadding(dev_ctx, - from_data, - buf0->data(), - padding_offset_data, - token_num, - dim_embed); - } else { - InvokeRebuildPadding(dev_ctx, - from_data, - buf1->data(), - padding_offset_data, - token_num, - dim_embed); - } + } + if (encoder_remove_padding) { + if (pre_layer_norm) { + InvokeRebuildPadding(dev_ctx, + from_data, + buf0->data(), + padding_offset_data, + token_num, + dim_embed); + } else { + InvokeRebuildPadding(dev_ctx, + from_data, + buf1->data(), + padding_offset_data, + token_num, + dim_embed); } } -}; - +} #endif // CUDA_VERSION >= 11060 -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -PD_REGISTER_STRUCT_KERNEL(fused_multi_transformer, - GPU, - ALL_LAYOUT, - ops::FusedMultiTransformerOpKernel, - float, - plat::float16) {} +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_multi_transformer, + GPU, + ALL_LAYOUT, + phi::fusion::FusedMultiTransformerKernel, + float, + phi::dtype::float16) { + kernel->InputAt(8).SetBackend(phi::Backend::CPU); +} diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index 0aff1cb5365fc..415a6ba1ffdf3 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -31,8 +31,8 @@ limitations under the License. 
*/ #include "paddle/fluid/operators/fused/fmha_ref.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/phi/api/include/tensor.h" +#include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/core/distributed/comm_context_manager.h" #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" @@ -49,8 +49,8 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm); COMMON_DECLARE_bool(gemm_use_half_precision_compute_type); -namespace paddle { -namespace operators { +namespace phi { +namespace fusion { // for debug // #define _DEBUG_FUSED_MULTI_TRANSFORMER @@ -75,14 +75,13 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT auto task = pg->AllReduce(in_tensor, out_tensor, opts); task->Wait(); } else { - auto dtype = platform::ToNCCLDataType( - framework::TransToProtoVarType(tensor.dtype())); + auto dtype = phi::ToNCCLDataType(tensor.dtype()); int64_t numel = tensor.numel(); const void *sendbuff = tensor.data(); auto place = ctx.GetPlace(); void *recvbuff = tensor.mutable_data(place); gpuStream_t stream = nullptr; - platform::NCCLComm *comm = nullptr; + paddle::platform::NCCLComm *comm = nullptr; phi::distributed::NCCLCommContext *comm_ctx = nullptr; const auto &comm_context_manager = @@ -92,7 +91,7 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT // Use New Communication Library PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), true, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "You choose to use new communication library by " "setting environment " "variable FLAGS_dynamic_static_unified_comm True. 
" @@ -103,7 +102,7 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT comm_context_manager.Get(std::to_string(ring_id))); PADDLE_ENFORCE_NE(comm_ctx, nullptr, - platform::errors::Unavailable( + phi::errors::Unavailable( "NCCLCommContext is nullptr, collective op should " "has ring_id attr.")); @@ -111,20 +110,19 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT VLOG(3) << "new comm_context_manager has ring_id" << ring_id; } else { - comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - + comm = paddle::platform::NCCLCommContext::Instance().Get(ring_id, place); stream = ctx.stream(); VLOG(3) << "old NCCLCommContext has ring_id " << ring_id; } if (comm_ctx) { comm_ctx->AllReduce(&tensor, tensor, ncclSum, stream); } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce( sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream)); } } #else - PADDLE_THROW(platform::errors::Unimplemented( + PADDLE_THROW(phi::errors::Unimplemented( "PaddlePaddle should compile with NCCL or RCCL when used tensor model " "parallel op.")); #endif @@ -1310,8 +1308,8 @@ void fmha(const phi::GPUContext &dev_ctx, fmha_launch_kernel(params, dev_ctx.stream()); break; default: - PADDLE_THROW(platform::errors::Unimplemented( - "Dim_head = %d is unsupport!", dim_head)); + PADDLE_THROW( + phi::errors::Unimplemented("Dim_head = %d is unsupport!", dim_head)); } } @@ -1431,7 +1429,7 @@ void write_cache_kv(const phi::GPUContext &dev_ctx, PADDLE_ENFORCE_EQ( dim_head % x, 0, - platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "dim_head=%d must be divisible by vec_size=%d", dim_head, x)); int max_size = max_seq_len * dim_head / x; @@ -1548,7 +1546,7 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx, constexpr int PackSize = VEC_16B / sizeof(T); PADDLE_ENFORCE_EQ(size_per_head % PackSize, 0, - platform::errors::PreconditionNotMet( + phi::errors::PreconditionNotMet( "dim_head=%d must be divisible by vec_size=%d", size_per_head, PackSize)); @@ -1711,12 +1709,12 @@ void InvokeGetPaddingOffset(const phi::GPUContext &dev_ctx, const int max_seq_len) { GetPaddingOffset<<<1, 1, 0, dev_ctx.stream()>>>( d_token_num, padding_offset, sequence_lengths, batch_size, max_seq_len); - memory::Copy(platform::CPUPlace(), - h_token_num, - dev_ctx.GetPlace(), - d_token_num, - sizeof(int), - dev_ctx.stream()); + phi::memory_utils::Copy(phi::CPUPlace(), + h_token_num, + dev_ctx.GetPlace(), + d_token_num, + sizeof(int), + dev_ctx.stream()); } template @@ -1785,7 +1783,7 @@ class CublasFusedMLP { cudaDataType_t mat_type = CUDA_R_32F; cudaDataType_t scale_type = CUDA_R_32F; cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; - if (std::is_same::value) { + if (std::is_same::value) { mat_type = CUDA_R_16F; if (FLAGS_gemm_use_half_precision_compute_type) { // This option default value is true, it tends to result NaN, but get @@ -1795,7 +1793,7 @@ class CublasFusedMLP { scale_type = CUDA_R_16F; } } - if (std::is_same::value) { + if (std::is_same::value) { mat_type = CUDA_R_16BF; } if (std::is_same::value) { @@ -1804,24 +1802,24 @@ class CublasFusedMLP { compute_type = CUBLAS_COMPUTE_64F; } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate( &operation_desc_, compute_type, scale_type)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &x_desc_, mat_type, 1, 1, 1)); - 
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &w_desc_, mat_type, 1, 1, 1)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasLtMatrixLayoutCreate(&x_desc_, mat_type, 1, 1, 1)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasLtMatrixLayoutCreate(&w_desc_, mat_type, 1, 1, 1)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( &out_desc_, mat_type, 1, 1, 1)); } ~CublasFusedMLP() { PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(operation_desc_)); + phi::dynload::cublasLtMatmulDescDestroy(operation_desc_)); PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(x_desc_)); + phi::dynload::cublasLtMatrixLayoutDestroy(x_desc_)); PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(w_desc_)); + phi::dynload::cublasLtMatrixLayoutDestroy(w_desc_)); PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(out_desc_)); + phi::dynload::cublasLtMatrixLayoutDestroy(out_desc_)); } void Setup(const phi::DDim &x_shape, @@ -1834,18 +1832,16 @@ class CublasFusedMLP { cublasOperation_t cublas_transA = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t cublas_transB = trans_w ? CUBLAS_OP_T : CUBLAS_OP_N; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_TRANSB, - &cublas_transA, - sizeof(cublas_transA))); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_TRANSA, - &cublas_transB, - sizeof(cublas_transB))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + operation_desc_, + CUBLASLT_MATMUL_DESC_TRANSB, + &cublas_transA, + sizeof(cublas_transA))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + operation_desc_, + CUBLASLT_MATMUL_DESC_TRANSA, + &cublas_transB, + sizeof(cublas_transB))); SetCublasMatrixLayout(x_desc_, trans_x, M, K); SetCublasMatrixLayout(w_desc_, trans_w, K, N); @@ -1867,27 +1863,25 @@ class CublasFusedMLP { if (add_bias) { bias_data = bias->data(); } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_data, - sizeof(bias_data))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + operation_desc_, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias_data, + sizeof(bias_data))); cublasLtEpilogue_t epiloque_func = GetEpilogueType(activation, add_bias); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc_, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epiloque_func, - sizeof(epiloque_func))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + operation_desc_, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epiloque_func, + sizeof(epiloque_func))); T *residual_data = add_residual ? 
residual->data() : out_data; cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle(); size_t workspace_size = static_cast(4) * 1024 * 1024; cudaStream_t stream = dev_ctx_.stream(); - memory::allocation::AllocationPtr workspace = memory::Alloc( + phi::Allocator::AllocationPtr workspace = phi::memory_utils::Alloc( dev_ctx_.GetPlace(), workspace_size, phi::Stream(reinterpret_cast(dev_ctx_.stream()))); @@ -1930,23 +1924,22 @@ class CublasFusedMLP { workspace->ptr(), workspace_size); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmul(lt_handle, - operation_desc_, - alpha, - w_data, - w_desc_, - x_data, - x_desc_, - beta, - residual_data, - out_desc_, - out_data, - out_desc_, - algo, - workspace->ptr(), - workspace_size, - stream)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle, + operation_desc_, + alpha, + w_data, + w_desc_, + x_data, + x_desc_, + beta, + residual_data, + out_desc_, + out_data, + out_desc_, + algo, + workspace->ptr(), + workspace_size, + stream)); } private: @@ -1974,7 +1967,7 @@ class CublasFusedMLP { PADDLE_ENFORCE_EQ( true, false, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The activation attribute of fused_gemm_epilogue op should be" " one of {\"none\", \"relu\", \"gelu\"}. But received %s." "But received activation=%s.", @@ -1987,42 +1980,32 @@ class CublasFusedMLP { const uint64_t cublas_row, const uint64_t cublas_col) { cudaDataType_t mat_type = CUDA_R_32F; - if (std::is_same::value) { + if (std::is_same::value) { mat_type = CUDA_R_16F; } - if (std::is_same::value) { + if (std::is_same::value) { mat_type = CUDA_R_16BF; } if (std::is_same::value) { mat_type = CUDA_R_64F; } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_TYPE, - &mat_type, - sizeof(mat_type))); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_ROWS, - transpose ? &cublas_row : &cublas_col, - sizeof(cublas_row))); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_COLS, - transpose ? &cublas_col : &cublas_row, - sizeof(cublas_col))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutSetAttribute( + layout_desc, CUBLASLT_MATRIX_LAYOUT_TYPE, &mat_type, sizeof(mat_type))); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutSetAttribute( + layout_desc, + CUBLASLT_MATRIX_LAYOUT_ROWS, + transpose ? &cublas_row : &cublas_col, + sizeof(cublas_row))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutSetAttribute( + layout_desc, + CUBLASLT_MATRIX_LAYOUT_COLS, + transpose ? &cublas_col : &cublas_row, + sizeof(cublas_col))); int64_t cublas_ld = transpose ? 
cublas_row : cublas_col; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutSetAttribute( - layout_desc, - CUBLASLT_MATRIX_LAYOUT_LD, - &cublas_ld, - sizeof(cublas_ld))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutSetAttribute( + layout_desc, CUBLASLT_MATRIX_LAYOUT_LD, &cublas_ld, sizeof(cublas_ld))); } const phi::GPUContext &dev_ctx_; @@ -2036,5 +2019,5 @@ class CublasFusedMLP { } // namespace -} // namespace operators -} // namespace paddle +} // namespace fusion +} // namespace phi diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc b/paddle/fluid/operators/fused/onednn/fusion_lstm_onednn_op.cc similarity index 97% rename from paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc rename to paddle/fluid/operators/fused/onednn/fusion_lstm_onednn_op.cc index ada14e280a0f3..c85022e08bcc7 100644 --- a/paddle/fluid/operators/fused/mkldnn/fusion_lstm_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/onednn/fusion_lstm_onednn_op.cc @@ -14,7 +14,7 @@ limitations under the License. */ #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/operators/fused/fusion_lstm_op.h" -#include "paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h" +#include "paddle/fluid/operators/fused/onednn/fusion_rnn_onednn.h" #include "paddle/phi/core/expect.h" namespace paddle { @@ -321,7 +321,7 @@ class LSTMMKLDNNHandler } }; -template +template class FusionLSTMMKLDNNKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { @@ -349,8 +349,6 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel { const auto* weight_h = ctx.Input("WeightH"); const auto* bias = ctx.Input("Bias"); auto* hidden = ctx.Output("Hidden"); - auto* cell = ctx.Output("Cell"); - cell = cell; auto x_dims = input->dims(); auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) ? 
common::flatten_to_2d(x_dims, 1) @@ -473,9 +471,11 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_KERNEL(fusion_lstm, - MKLDNN, - phi::CPUPlace, - ops::FusionLSTMMKLDNNKernel, - ops::FusionLSTMMKLDNNKernel, - ops::FusionLSTMMKLDNNKernel); + +PD_REGISTER_STRUCT_KERNEL(fusion_lstm, + OneDNN, + ONEDNN, + ops::FusionLSTMMKLDNNKernel, + float, + uint8_t, + paddle::platform::bfloat16) {} diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h b/paddle/fluid/operators/fused/onednn/fusion_rnn_onednn.h similarity index 100% rename from paddle/fluid/operators/fused/mkldnn/fusion_rnn_mkldnn.h rename to paddle/fluid/operators/fused/onednn/fusion_rnn_onednn.h diff --git a/paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/onednn/multi_gru_onednn_op.cc similarity index 100% rename from paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc rename to paddle/fluid/operators/fused/onednn/multi_gru_onednn_op.cc diff --git a/paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc b/paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc index bd918924cdf09..16e2261f1afb5 100644 --- a/paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc +++ b/paddle/fluid/operators/fused/resnet_basic_block_op_xpu.cc @@ -295,7 +295,7 @@ static inline void xpu_conv2d_grad(xpu::Context* ctx, template class ResNetBasicBlockXPUKernel : public framework::OpKernel { public: - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE_EQ( @@ -319,20 +319,23 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel { phi::DenseTensor* output = ctx.Output("Y"); auto place = ctx.GetPlace(); - auto x_data = reinterpret_cast(x->data()); - auto conv1_filter_data = reinterpret_cast(filter1->data()); - auto conv2_filter_data = reinterpret_cast(filter2->data()); + auto x_data = reinterpret_cast(x->data()); + auto conv1_filter_data = + reinterpret_cast(filter1->data()); + auto conv2_filter_data = + reinterpret_cast(filter2->data()); auto conv1_output_data = - reinterpret_cast(conv1_output->mutable_data(place)); + reinterpret_cast(conv1_output->mutable_data(place)); auto conv2_input_data = - reinterpret_cast(conv2_input->mutable_data(place)); + reinterpret_cast(conv2_input->mutable_data(place)); auto conv2_output_data = - reinterpret_cast(conv2_output->mutable_data(place)); + reinterpret_cast(conv2_output->mutable_data(place)); auto scale1_data = scale1->data(); auto scale2_data = scale2->data(); auto bias1_data = bias1->data(); auto bias2_data = bias2->data(); - auto output_data = reinterpret_cast(output->mutable_data(place)); + auto output_data = + reinterpret_cast(output->mutable_data(place)); float* conv1_input_max_data = nullptr; float* conv1_filter_max_data = nullptr; @@ -372,18 +375,18 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel { int r = XPU_SUCCESS; // 1. 
short - const XPUT* z_out_data = nullptr; + const XPUType* z_out_data = nullptr; if (attr.has_shortcut) { phi::DenseTensor* conv3_out = ctx.Output("Conv3"); const phi::DenseTensor* filter3 = ctx.Input("Filter3"); auto conv3_filter_data = - reinterpret_cast(filter3->data()); + reinterpret_cast(filter3->data()); auto conv3_output_data = - reinterpret_cast(conv3_out->mutable_data(place)); + reinterpret_cast(conv3_out->mutable_data(place)); - XPUT* conv3_input_l3_data = nullptr; - XPUT* conv3_filter_l3_data = - RAII_GUARD.alloc_l3(attr.conv3_filter_numel); + XPUType* conv3_input_l3_data = nullptr; + XPUType* conv3_filter_l3_data = + RAII_GUARD.alloc_l3_or_gm(attr.conv3_filter_numel); if (attr.find_max) { r = xpu::findmax_copy_fusion(dev_ctx.x_context(), @@ -420,7 +423,7 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel { auto bias3_data = bias3->data(); auto scale3_data = scale3->data(); - auto bn3_output_data = RAII_GUARD.alloc(attr.conv3_output_numel); + auto bn3_output_data = RAII_GUARD.alloc(attr.conv3_output_numel); PADDLE_ENFORCE_XDNN_NOT_NULL(bn3_output_data); if (!attr.global_stats) { @@ -438,56 +441,56 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel { auto running_mean3_data = running_mean3->mutable_data(place); auto running_var3_data = running_var3->mutable_data(place); - r = xpu::batch_norm_fusion(dev_ctx.x_context(), - conv3_output_data, - bn3_output_data, - attr.conv3_output_shape[0], - attr.conv3_output_shape[1], - attr.conv3_output_shape[3], - attr.conv3_output_shape[3], - attr.eps, - attr.momentum, - scale3_data, - bias3_data, - saved_mean3_data, - saved_invstd3_data, - running_mean3_data, - running_var3_data, - true, - nullptr, - xpu::Activation_t::LINEAR, - nullptr, - 0); + r = xpu::batch_norm_fusion(dev_ctx.x_context(), + conv3_output_data, + bn3_output_data, + attr.conv3_output_shape[0], + attr.conv3_output_shape[1], + attr.conv3_output_shape[3], + attr.conv3_output_shape[3], + attr.eps, + attr.momentum, + scale3_data, + bias3_data, + saved_mean3_data, + saved_invstd3_data, + running_mean3_data, + running_var3_data, + true, + nullptr, + xpu::Activation_t::LINEAR, + nullptr, + 0); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_fusion"); } else { const auto* mean3 = ctx.Input("Mean3"); const auto* var3 = ctx.Input("Var3"); const auto* mean3_data = mean3->data(); const auto* variance3_data = var3->data(); - r = xpu::batch_norm_infer(dev_ctx.x_context(), - conv3_output_data, - bn3_output_data, - attr.conv3_output_shape[0], - attr.conv3_output_shape[1], - attr.conv3_output_shape[2], - attr.conv3_output_shape[3], - attr.eps, - scale3_data, - bias3_data, - mean3_data, - variance3_data, - true); + r = xpu::batch_norm_infer(dev_ctx.x_context(), + conv3_output_data, + bn3_output_data, + attr.conv3_output_shape[0], + attr.conv3_output_shape[1], + attr.conv3_output_shape[2], + attr.conv3_output_shape[3], + attr.eps, + scale3_data, + bias3_data, + mean3_data, + variance3_data, + true); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_infer"); } - z_out_data = reinterpret_cast(bn3_output_data); + z_out_data = reinterpret_cast(bn3_output_data); } else { z_out_data = x_data; } // 2. 
conv1 - XPUT* conv1_input_l3_data = nullptr; - XPUT* conv1_filter_l3_data = - RAII_GUARD.alloc_l3(attr.conv1_filter_numel); + XPUType* conv1_input_l3_data = nullptr; + XPUType* conv1_filter_l3_data = + RAII_GUARD.alloc_l3_or_gm(attr.conv1_filter_numel); if (attr.find_max) { r = xpu::findmax_copy_fusion(dev_ctx.x_context(), x_data, @@ -531,49 +534,49 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel { auto running_mean1_data = running_mean1->mutable_data(place); auto running_var1_data = running_var1->mutable_data(place); - r = xpu::batch_norm_fusion(dev_ctx.x_context(), - conv1_output_data, - conv2_input_data, - attr.conv1_output_shape[0], - attr.conv1_output_shape[1], - attr.conv1_output_shape[2], - attr.conv1_output_shape[3], - attr.eps, - attr.momentum, - scale1_data, - bias1_data, - saved_mean1_data, - saved_invstd1_data, - running_mean1_data, - running_var1_data, - true, - nullptr, - xpu::Activation_t::RELU, - nullptr, - 0); + r = xpu::batch_norm_fusion(dev_ctx.x_context(), + conv1_output_data, + conv2_input_data, + attr.conv1_output_shape[0], + attr.conv1_output_shape[1], + attr.conv1_output_shape[2], + attr.conv1_output_shape[3], + attr.eps, + attr.momentum, + scale1_data, + bias1_data, + saved_mean1_data, + saved_invstd1_data, + running_mean1_data, + running_var1_data, + true, + nullptr, + xpu::Activation_t::RELU, + nullptr, + 0); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_fusion"); } else { // bn --> relu - auto bn1_output_data = RAII_GUARD.alloc(attr.conv1_output_numel); + auto bn1_output_data = RAII_GUARD.alloc(attr.conv1_output_numel); PADDLE_ENFORCE_XDNN_NOT_NULL(bn1_output_data); const auto* mean1 = ctx.Input("Mean1"); const auto* var1 = ctx.Input("Var1"); const auto* mean_data = mean1->data(); const auto* variance_data = var1->data(); - r = xpu::batch_norm_infer(dev_ctx.x_context(), - conv1_output_data, - bn1_output_data, - attr.conv1_output_shape[0], - attr.conv1_output_shape[1], - attr.conv1_output_shape[2], - attr.conv1_output_shape[3], - attr.eps, - scale1_data, - bias1_data, - mean_data, - variance_data, - true); + r = xpu::batch_norm_infer(dev_ctx.x_context(), + conv1_output_data, + bn1_output_data, + attr.conv1_output_shape[0], + attr.conv1_output_shape[1], + attr.conv1_output_shape[2], + attr.conv1_output_shape[3], + attr.eps, + scale1_data, + bias1_data, + mean_data, + variance_data, + true); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_infer"); r = xpu::relu(dev_ctx.x_context(), @@ -584,9 +587,9 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel { } // 4. 
conv2 - XPUT* conv2_input_l3_data = nullptr; - XPUT* conv2_filter_l3_data = - RAII_GUARD.alloc_l3(attr.conv2_filter_numel); + XPUType* conv2_input_l3_data = nullptr; + XPUType* conv2_filter_l3_data = + RAII_GUARD.alloc_l3_or_gm(attr.conv2_filter_numel); if (attr.find_max) { phi::DenseTensor* max_input2 = ctx.Output("MaxInput2"); phi::DenseTensor* max_filter2 = @@ -637,59 +640,59 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel { auto running_mean2_data = running_mean2->mutable_data(place); auto running_var2_data = running_var2->mutable_data(place); - r = xpu::batch_norm_fusion(dev_ctx.x_context(), - conv2_output_data, - output_data, - attr.conv2_output_shape[0], - attr.conv2_output_shape[1], - attr.conv2_output_shape[2], - attr.conv2_output_shape[3], - attr.eps, - attr.momentum, - scale2_data, - bias2_data, - saved_mean2_data, - saved_var2_data, - running_mean2_data, - running_var2_data, - true, - z_out_data, - xpu::Activation_t::RELU, - nullptr, - 0); + r = xpu::batch_norm_fusion(dev_ctx.x_context(), + conv2_output_data, + output_data, + attr.conv2_output_shape[0], + attr.conv2_output_shape[1], + attr.conv2_output_shape[2], + attr.conv2_output_shape[3], + attr.eps, + attr.momentum, + scale2_data, + bias2_data, + saved_mean2_data, + saved_var2_data, + running_mean2_data, + running_var2_data, + true, + z_out_data, + xpu::Activation_t::RELU, + nullptr, + 0); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_fusion"); } else { - auto bn2_out_data = RAII_GUARD.alloc(attr.conv2_output_numel); + auto bn2_out_data = RAII_GUARD.alloc(attr.conv2_output_numel); PADDLE_ENFORCE_XDNN_NOT_NULL(bn2_out_data); const auto* mean2 = ctx.Input("Mean2"); const auto* var2 = ctx.Input("Var2"); const auto* mean_data = mean2->data(); const auto* variance_data = var2->data(); - r = xpu::batch_norm_infer(dev_ctx.x_context(), - conv2_output_data, - bn2_out_data, - attr.conv2_output_shape[0], - attr.conv2_output_shape[1], - attr.conv2_output_shape[2], - attr.conv2_output_shape[3], - attr.eps, - scale2_data, - bias2_data, - mean_data, - variance_data, - true); + r = xpu::batch_norm_infer(dev_ctx.x_context(), + conv2_output_data, + bn2_out_data, + attr.conv2_output_shape[0], + attr.conv2_output_shape[1], + attr.conv2_output_shape[2], + attr.conv2_output_shape[3], + attr.eps, + scale2_data, + bias2_data, + mean_data, + variance_data, + true); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_infer"); - r = xpu::add_activation_fusion(dev_ctx.x_context(), - bn2_out_data, - z_out_data, - output_data, - output->numel(), - nullptr, - nullptr, - nullptr, - xpu::Activation_t::RELU); + r = xpu::add_activation_fusion(dev_ctx.x_context(), + bn2_out_data, + z_out_data, + output_data, + output->numel(), + nullptr, + nullptr, + nullptr, + xpu::Activation_t::RELU); PADDLE_ENFORCE_XDNN_SUCCESS(r, "add_activation_fusion"); } } @@ -698,7 +701,7 @@ class ResNetBasicBlockXPUKernel : public framework::OpKernel { template class ResNetBasicBlockGradXPUKernel : public framework::OpKernel { public: - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; void Compute(const framework::ExecutionContext& ctx) const override { PADDLE_ENFORCE_EQ( @@ -774,19 +777,20 @@ class ResNetBasicBlockGradXPUKernel : public framework::OpKernel { ResnetBasicBlockGradAttr attr(ctx); auto place = ctx.GetPlace(); - const auto* y_grad_data = reinterpret_cast(y_grad->data()); - const auto* y_data = reinterpret_cast(y->data()); - const auto* x_data = reinterpret_cast(x->data()); + const auto* y_grad_data = + 
reinterpret_cast(y_grad->data()); + const auto* y_data = reinterpret_cast(y->data()); + const auto* x_data = reinterpret_cast(x->data()); const auto* conv1_output_data = - reinterpret_cast(conv1_out->data()); + reinterpret_cast(conv1_out->data()); const auto* conv1_filter_data = - reinterpret_cast(filter1->data()); + reinterpret_cast(filter1->data()); const auto* conv2_input_data = - reinterpret_cast(conv2_input->data()); + reinterpret_cast(conv2_input->data()); const auto* conv2_output_data = - reinterpret_cast(conv2_out->data()); + reinterpret_cast(conv2_out->data()); const auto* conv2_filter_data = - reinterpret_cast(filter2->data()); + reinterpret_cast(filter2->data()); const auto* scale2_data = scale2->data(); const auto* saved_mean2_data = saved_mean2->data(); @@ -826,77 +830,77 @@ class ResNetBasicBlockGradXPUKernel : public framework::OpKernel { // 0. bn2, bn2_fusion grad auto conv2_output_grad_data = - RAII_GUARD.alloc(attr.conv2_output_numel); + RAII_GUARD.alloc(attr.conv2_output_numel); PADDLE_ENFORCE_XDNN_NOT_NULL(conv2_output_grad_data); - XPUT* z_output_grad_data = nullptr; - XPUT* z_grad_data = nullptr; + XPUType* z_output_grad_data = nullptr; + XPUType* z_grad_data = nullptr; if (!attr.has_shortcut) { - z_output_grad_data = RAII_GUARD.alloc(attr.conv1_input_numel); + z_output_grad_data = RAII_GUARD.alloc(attr.conv1_input_numel); PADDLE_ENFORCE_XDNN_NOT_NULL(z_output_grad_data); z_grad_data = z_output_grad_data; } else { - z_output_grad_data = RAII_GUARD.alloc(attr.conv3_output_numel); + z_output_grad_data = RAII_GUARD.alloc(attr.conv3_output_numel); PADDLE_ENFORCE_XDNN_NOT_NULL(z_output_grad_data); - z_grad_data = RAII_GUARD.alloc(attr.conv1_input_numel); + z_grad_data = RAII_GUARD.alloc(attr.conv1_input_numel); PADDLE_ENFORCE_XDNN_NOT_NULL(z_grad_data); } - r = xpu::batch_norm_grad_fusion(dev_ctx.x_context(), - conv2_output_data, - y_data, - y_grad_data, - conv2_output_grad_data, - attr.conv2_output_shape[0], - attr.conv2_output_shape[1], - attr.conv2_output_shape[2], - attr.conv2_output_shape[3], - scale2_data, - saved_mean2_data, - saved_invstd2_data, - scale2_grad_data, - bias2_grad_data, - true, - z_output_grad_data, - xpu::Activation_t::RELU, - nullptr, - 0); + r = xpu::batch_norm_grad_fusion(dev_ctx.x_context(), + conv2_output_data, + y_data, + y_grad_data, + conv2_output_grad_data, + attr.conv2_output_shape[0], + attr.conv2_output_shape[1], + attr.conv2_output_shape[2], + attr.conv2_output_shape[3], + scale2_data, + saved_mean2_data, + saved_invstd2_data, + scale2_grad_data, + bias2_grad_data, + true, + z_output_grad_data, + xpu::Activation_t::RELU, + nullptr, + 0); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad_fusion"); if (attr.has_shortcut) { // bn3 grad const auto* conv3_output_data = - reinterpret_cast(conv3_out->data()); + reinterpret_cast(conv3_out->data()); const auto* scale3_data = scale3->data(); const auto* saved_mean3_data = saved_mean3->data(); const auto* saved_invstd3_data = saved_invstd3->data(); auto* scale3_grad_data = scale3_grad->mutable_data(place); auto* bias3_grad_data = bias3_grad->mutable_data(place); auto* conv3_output_grad_data = - RAII_GUARD.alloc(attr.conv3_output_numel); - - r = xpu::batch_norm_grad(dev_ctx.x_context(), - conv3_output_data, - z_output_grad_data, - conv3_output_grad_data, - attr.conv3_output_shape[0], - attr.conv3_output_shape[1], - attr.conv3_output_shape[2], - attr.conv3_output_shape[3], - scale3_data, - saved_mean3_data, - saved_invstd3_data, - scale3_grad_data, - bias3_grad_data, - true); + 
RAII_GUARD.alloc(attr.conv3_output_numel); + + r = xpu::batch_norm_grad(dev_ctx.x_context(), + conv3_output_data, + z_output_grad_data, + conv3_output_grad_data, + attr.conv3_output_shape[0], + attr.conv3_output_shape[1], + attr.conv3_output_shape[2], + attr.conv3_output_shape[3], + scale3_data, + saved_mean3_data, + saved_invstd3_data, + scale3_grad_data, + bias3_grad_data, + true); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad"); // conv3 grad auto* conv3_filter_grad_data = - reinterpret_cast(filter3_grad->mutable_data(place)); + reinterpret_cast(filter3_grad->mutable_data(place)); auto* conv3_filter_data = - reinterpret_cast(filter3->data()); + reinterpret_cast(filter3->data()); xpu_conv2d_grad(dev_ctx.x_context(), x_data, conv3_filter_data, @@ -915,9 +919,9 @@ class ResNetBasicBlockGradXPUKernel : public framework::OpKernel { // 2. conv2_grad auto* conv2_filter_grad_data = - reinterpret_cast(filter2_grad->mutable_data(place)); + reinterpret_cast(filter2_grad->mutable_data(place)); auto* conv2_input_grad_data = - RAII_GUARD.alloc(attr.conv2_input_numel); + RAII_GUARD.alloc(attr.conv2_input_numel); xpu_conv2d_grad(dev_ctx.x_context(), conv2_input_data, conv2_filter_data, @@ -935,35 +939,36 @@ class ResNetBasicBlockGradXPUKernel : public framework::OpKernel { // 3. b1 grad auto* conv1_output_grad_data = - RAII_GUARD.alloc(attr.conv1_output_numel); + RAII_GUARD.alloc(attr.conv1_output_numel); PADDLE_ENFORCE_XDNN_NOT_NULL(conv1_output_grad_data); auto* scale1_grad_data = scale1_grad->mutable_data(ctx.GetPlace()); auto* bias1_grad_data = bias1_grad->mutable_data(ctx.GetPlace()); - r = xpu::batch_norm_grad_fusion(dev_ctx.x_context(), - conv1_output_data, - conv2_input_data, - conv2_input_grad_data, - conv1_output_grad_data, - attr.conv1_output_shape[0], - attr.conv1_output_shape[1], - attr.conv1_output_shape[2], - attr.conv1_output_shape[3], - scale1_data, - saved_mean1_data, - saved_invstd1_data, - scale1_grad_data, - bias1_grad_data, - true, - nullptr, - xpu::Activation_t::RELU, - nullptr, - 0); + r = xpu::batch_norm_grad_fusion(dev_ctx.x_context(), + conv1_output_data, + conv2_input_data, + conv2_input_grad_data, + conv1_output_grad_data, + attr.conv1_output_shape[0], + attr.conv1_output_shape[1], + attr.conv1_output_shape[2], + attr.conv1_output_shape[3], + scale1_data, + saved_mean1_data, + saved_invstd1_data, + scale1_grad_data, + bias1_grad_data, + true, + nullptr, + xpu::Activation_t::RELU, + nullptr, + 0); PADDLE_ENFORCE_XDNN_SUCCESS(r, "batch_norm_grad_fusion"); // 4. 
conv1_grad - auto* x_grad_data = reinterpret_cast(x_grad->mutable_data(place)); + auto* x_grad_data = + reinterpret_cast(x_grad->mutable_data(place)); auto* conv1_filter_grad_data = - reinterpret_cast(filter1_grad->mutable_data(place)); + reinterpret_cast(filter1_grad->mutable_data(place)); xpu_conv2d_grad(dev_ctx.x_context(), x_data, conv1_filter_data, @@ -980,7 +985,7 @@ class ResNetBasicBlockGradXPUKernel : public framework::OpKernel { attr.group); // add z_grad to x_grad - r = xpu::add( + r = xpu::add( dev_ctx.x_context(), x_grad_data, z_grad_data, x_grad_data, x->numel()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "add"); } diff --git a/paddle/fluid/operators/fused/resnet_unit_op.cc b/paddle/fluid/operators/fused/resnet_unit_op.cc index f1f2628119c15..5827cd3427dee 100644 --- a/paddle/fluid/operators/fused/resnet_unit_op.cc +++ b/paddle/fluid/operators/fused/resnet_unit_op.cc @@ -27,7 +27,7 @@ static framework::DDim GetBitmaskDims(std::vector out_shape) { std::multiplies()) / // NOLINT c; int32_t c_int32_elems = ((c + 63) & ~63) / 32; - int32_t nhw_int32_elems = ((nhw + 31) & ~31); + int32_t nhw_int32_elems = static_cast(((nhw + 31) & ~31)); std::vector bitmask_shape = {nhw_int32_elems, c_int32_elems, 1}; return common::make_ddim(bitmask_shape); } diff --git a/paddle/fluid/operators/fused/unity_build_rule.cmake b/paddle/fluid/operators/fused/unity_build_rule.cmake index 8605cd3cdae85..b7405f93c3585 100644 --- a/paddle/fluid/operators/fused/unity_build_rule.cmake +++ b/paddle/fluid/operators/fused/unity_build_rule.cmake @@ -10,11 +10,7 @@ register_unity_group( fused_embedding_fc_lstm_op.cc fused_embedding_seq_pool_op.cc fusion_lstm_op.cc - fusion_repeated_fc_relu_op.cc - fusion_seqconv_eltadd_relu_op.cc - fusion_seqexpand_concat_fc_op.cc fusion_seqpool_concat_op.cc - fusion_squared_mat_sub_op.cc multi_gru_op.cc mkldnn/multi_gru_mkldnn_op.cc fusion_seqpool_cvm_concat_op.cc) diff --git a/paddle/fluid/operators/fused_token_prune_op.cc b/paddle/fluid/operators/fused_token_prune_op.cc index 021aa95b1fe2c..9fab5c8e7c48d 100644 --- a/paddle/fluid/operators/fused_token_prune_op.cc +++ b/paddle/fluid/operators/fused_token_prune_op.cc @@ -39,7 +39,7 @@ class FusedTokenPruneOpMaker : public framework::OpProtoAndCheckerMaker { "The input of fused_token_prune op, whose shape should be [bsz, " "num_head, " "max_seq_len, max_seq_len] and dtype should be float32/float64." - "Mask is corresponding to Attn's elemnts one by one. Elements of Attn " + "Mask is corresponding to Attn's elements one by one. Elements of Attn " "will be set to zero if their corresponding mask is smaller than 0." "This process happens before sorting X by attn."); @@ -56,7 +56,7 @@ class FusedTokenPruneOpMaker : public framework::OpProtoAndCheckerMaker { "slimmed_seq_len, C]." "The tokens of X will be sorted by Attn firstly and then the " "last (max_seq_len - slimmed_seq_len)" - "tokens will be deleted. SlimmedX is the remainning part of X. " + "tokens will be deleted. SlimmedX is the remaining part of X. " ""); AddOutput( @@ -82,7 +82,7 @@ class FusedTokenPruneOpMaker : public framework::OpProtoAndCheckerMaker { 1. Elements of Attn will be set to zero if their corresponding mask is smaller than 0. 2. The second dimension of X will be sorted by Attn. 3. The last (max_seq_len - slimmed_seq_len) lines of X will be pruned. - 4. The remainning part of sorted X will output. + 4. The remaining part of sorted X will output. 
)DOC"); } }; diff --git a/paddle/fluid/operators/generator/generate_op.py b/paddle/fluid/operators/generator/generate_op.py index 2f75051d68236..c3d66dbf39a29 100644 --- a/paddle/fluid/operators/generator/generate_op.py +++ b/paddle/fluid/operators/generator/generate_op.py @@ -125,7 +125,8 @@ def process_scalar(op_item, scalar_configs): '"' + attr_item['default_value'] + '"' ) if attr_item['is_support_tensor'] is False: - attr_item['tensor_name'] = scalar_config['tensor_name'] + if 'tensor_name' in scalar_config: + attr_item['tensor_name'] = scalar_config['tensor_name'] def process_int_array(op_item, int_array_configs): diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index 0370d6cfba4b3..38a87efec0415 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -369,7 +369,7 @@ def check_op_config(op_entry, op_name): 'traits', 'interfaces', ) - infer_meta_key_set = ('func', 'param', 'spmd_rule') + infer_meta_key_set = ('func', 'param', 'spmd_rule', 'local_shape') kernel_key_set = ( 'func', 'param', diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h index 9309ca0417f62..933176433e2d7 100644 --- a/paddle/fluid/operators/gru_unit_op.h +++ b/paddle/fluid/operators/gru_unit_op.h @@ -105,7 +105,7 @@ class GRUUnitKernel : public framework::OpKernel { gate_data, frame_size * 3); - // calculate activited gate + // calculate activated gate Eigen::array extents{{batch_size, frame_size}}; Eigen::array u_offsets{{0, 0}}; ActCompute(context.Attr("gate_activation"), diff --git a/paddle/fluid/operators/im2sequence_op.h b/paddle/fluid/operators/im2sequence_op.h index 18e6d429f1b16..5fb689d5b1be0 100644 --- a/paddle/fluid/operators/im2sequence_op.h +++ b/paddle/fluid/operators/im2sequence_op.h @@ -48,13 +48,13 @@ class Im2SequenceKernel : public framework::OpKernel { auto strides = ctx.Attr>("strides"); auto paddings = ctx.Attr>("paddings"); if (ctx.HasInput("Y") && batch_size > 1) { - const phi::DenseTensor* imgrealsize = ctx.Input("Y"); + const phi::DenseTensor* img_real_size = ctx.Input("Y"); auto out_stride = ctx.Attr>("out_stride"); phi::DenseTensor cpu_shape_tensor; paddle::framework::TensorCopySync( - *imgrealsize, platform::CPUPlace(), &cpu_shape_tensor); - std::vector imgreal_h; - std::vector imgreal_w; + *img_real_size, platform::CPUPlace(), &cpu_shape_tensor); + std::vector img_real_h; + std::vector img_real_w; std::vector output_height; std::vector output_width; int result = 0; @@ -72,12 +72,12 @@ class Im2SequenceKernel : public framework::OpKernel { } else { tmp_real_w = tmp_real_w / out_stride[1] + 1; } - imgreal_h.push_back(tmp_real_h); - imgreal_w.push_back(tmp_real_w); + img_real_h.push_back(tmp_real_h); + img_real_w.push_back(tmp_real_w); output_height.push_back(Im2SeqOutputSize( - imgreal_h[i], kernels[0], paddings[0], paddings[2], strides[0])); + img_real_h[i], kernels[0], paddings[0], paddings[2], strides[0])); output_width.push_back(Im2SeqOutputSize( - imgreal_w[i], kernels[1], paddings[1], paddings[3], strides[1])); + img_real_w[i], kernels[1], paddings[1], paddings[3], strides[1])); result += output_height[i] * output_width[i]; } diff --git a/paddle/fluid/operators/is_empty_op.h b/paddle/fluid/operators/is_empty_op.h index 3c9dfbf58fae5..7c78c33621314 100644 --- a/paddle/fluid/operators/is_empty_op.h +++ b/paddle/fluid/operators/is_empty_op.h @@ -29,7 +29,7 @@ class IsEmptyOpKernel : public framework::OpKernel { auto* 
output_tensor = context.Output("Out"); // Note: is_empty is always executed on CPU and the output data should - // always be allocated for CPUPlace. We reigister CUDA kernel for this op to + // always be allocated for CPUPlace. We register CUDA kernel for this op to // avoid the unnecessary data transform. output_tensor->mutable_data(platform::CPUPlace())[0] = common::product(input_tensor->dims()) == 0; diff --git a/paddle/fluid/operators/isfinite_op.cc b/paddle/fluid/operators/isfinite_op.cc index 0d80a1c36b071..710cdaeb707b6 100644 --- a/paddle/fluid/operators/isfinite_op.cc +++ b/paddle/fluid/operators/isfinite_op.cc @@ -86,7 +86,7 @@ If any X contains Inf or Nan, the Out will generate a indicator. Out = Inf if any X contains Inf, Out = Nan if any X contains Nan, Out = 0 if no Inf/Nan detected. -If X contains both Inf/Nan, it will return the first indicator it meeted. +If X contains both Inf/Nan, it will return the first indicator it met. %s )DOC", diff --git a/paddle/fluid/operators/limit_by_capacity_op.cc b/paddle/fluid/operators/limit_by_capacity_op.cc index 569d1d025f79e..387e30ae647c9 100644 --- a/paddle/fluid/operators/limit_by_capacity_op.cc +++ b/paddle/fluid/operators/limit_by_capacity_op.cc @@ -71,7 +71,7 @@ class LimitByCapacityOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("capacity", "(Tensor) The input capacity."); AddOutput("Out", "(Tensor) The output tensor expert count limit by capacity."); - AddAttr("n_worker", "(int), The number of works."); + AddAttr("n_worker", "(int), The number of works."); AddComment( R"DOC(limit_by_capacity Operator.limit expert count by capacity.)DOC"); } diff --git a/paddle/fluid/operators/linear_chain_crf_op.cc b/paddle/fluid/operators/linear_chain_crf_op.cc index 46ff4c2e94a94..e017e43d7db2d 100644 --- a/paddle/fluid/operators/linear_chain_crf_op.cc +++ b/paddle/fluid/operators/linear_chain_crf_op.cc @@ -55,7 +55,7 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker { "probabilities of all possible unfinished sequences of tags that end " "at position $k$ with tag $v$. For each $k$, " "$\alpha[k, v]$ is a vector of length $D$ with a component for " - "each tag value $v$. This vector is called a forward vecotr and " + "each tag value $v$. This vector is called a forward vector and " "will also be used in backward computations.") .AsIntermediate(); AddOutput( @@ -105,7 +105,7 @@ CRF. Please refer to http://www.cs.columbia.edu/~mcollins/fb.pdf and weights, denoted as $a$ here. 3. The next D values of Input(Transition) of this operator are for ending weights, denoted as $b$ here. -4. The remaning values of Input(Transition) are for transition weights, +4. The remaining values of Input(Transition) are for transition weights, denoted as $w$ here. 5. Denote Input(Label) as $s$ here. diff --git a/paddle/fluid/operators/linear_chain_crf_op.h b/paddle/fluid/operators/linear_chain_crf_op.h index ad2fbefdfd71f..2891320506391 100644 --- a/paddle/fluid/operators/linear_chain_crf_op.h +++ b/paddle/fluid/operators/linear_chain_crf_op.h @@ -234,7 +234,7 @@ class LinearChainCRFOpKernel : public framework::OpKernel { static_cast(*std::max_element(lbl, lbl + seq_length)), tag_num, platform::errors::InvalidArgument( - "An invalid tag label that execesses the largest tag number.")); + "An invalid tag label that excesses the largest tag number.")); // Calculate the nominator part, which depends on the label sequence. 
ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] + @@ -308,7 +308,7 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel { // Now, all the inputs and outputs should be on the CPU memory. auto emission_dims = emission_exps->dims(); // Beta is the memo table used in dynamic programming to calculate the - // backwark vectors. For a backward vector i (the i-th row of beta), it + // backward vectors. For a backward vector i (the i-th row of beta), it // captures the unnormalized probabilities of partial sequences starting // at position i. phi::DenseTensor beta; @@ -372,7 +372,7 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel { const size_t state_trans_base_idx = 2; // Calculate the backward vectors: beta. - // First, calculate the initialition state. + // First, calculate the initial state. for (size_t i = 0; i < tag_num; ++i) { beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i]; } @@ -411,7 +411,7 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel { T* trans_grad = transition_grad->data(); for (size_t k = 0; k < tag_num; ++k) { // Do not multiply by the output gradient here, because x_grad_mat has - // alrealy done this. + // already done this. trans_grad[k] += x_grad_mat(/*from start state*/ 0, k); trans_grad[tag_num + k] += x_grad_mat(/*to end state*/ seq_length - 1, k); diff --git a/paddle/fluid/operators/load_combine_op.h b/paddle/fluid/operators/load_combine_op.h index 9f15523ce0129..4641c39111fad 100644 --- a/paddle/fluid/operators/load_combine_op.h +++ b/paddle/fluid/operators/load_combine_op.h @@ -101,7 +101,7 @@ class LoadCombineOpKernel : public framework::OpKernel { framework::NFD(it->first, &tmp); if (tmp.empty()) { VLOG(0) << "The string " << it->first - << " was converted to unicode failedly! " + << " was converted to unicode unsuccessfully! 
" << "Then dropped to load it."; continue; } diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc index dd85ccff87f2d..326746eb1e286 100644 --- a/paddle/fluid/operators/load_op.cc +++ b/paddle/fluid/operators/load_op.cc @@ -47,7 +47,7 @@ void LoadKernel(const Context& dev_ctx, PADDLE_ENFORCE_GE(seek, 0, phi::errors::InvalidArgument( - "seek witn tensor must great than or equal to 0")); + "seek with tensor must great than or equal to 0")); framework::DeserializeFromStream(fin, out, dev_ctx, seek, shape); } else { framework::DeserializeFromStream(fin, out, dev_ctx); diff --git a/paddle/fluid/operators/math/concat_and_split.cc b/paddle/fluid/operators/math/concat_and_split.cc index ec156954ca354..87b3695553356 100644 --- a/paddle/fluid/operators/math/concat_and_split.cc +++ b/paddle/fluid/operators/math/concat_and_split.cc @@ -191,6 +191,7 @@ FOR_ALL_TYPES(DEFINE_FUNCTOR); DEFINE_XPU_FUNCTOR(float) DEFINE_XPU_FUNCTOR(platform::float16) +DEFINE_XPU_FUNCTOR(platform::bfloat16) #endif } // namespace math } // namespace operators diff --git a/paddle/fluid/operators/max_sequence_len_op.cc b/paddle/fluid/operators/max_sequence_len_op.cc index 813b1901760b9..1863787db3d3b 100644 --- a/paddle/fluid/operators/max_sequence_len_op.cc +++ b/paddle/fluid/operators/max_sequence_len_op.cc @@ -31,12 +31,12 @@ class OpBase; namespace paddle { namespace operators { -class MaxSeqenceLenOp : public framework::OperatorBase { +class MaxSequenceLenOp : public framework::OperatorBase { public: - MaxSeqenceLenOp(const std::string &type, - const framework::VariableNameMap &inputs, - const framework::VariableNameMap &outputs, - const framework::AttributeMap &attrs) + MaxSequenceLenOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) : OperatorBase(type, inputs, outputs, attrs) {} private: @@ -50,7 +50,7 @@ class MaxSeqenceLenOp : public framework::OperatorBase { } }; -class MaxSeqenceLenOpProtoMaker : public framework::OpProtoAndCheckerMaker { +class MaxSequenceLenOpProtoMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("RankTable", "Input variable which is a LoDRankTable object"); @@ -65,11 +65,11 @@ class MaxSeqenceLenOpProtoMaker : public framework::OpProtoAndCheckerMaker { } }; -class MaxSeqenceLenInferShape : public framework::InferShapeBase { +class MaxSequenceLenInferShape : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *context) const override { OP_INOUT_CHECK( - context->HasInput("RankTable"), "Input", "RankTable", "MaxSeqenceLen"); + context->HasInput("RankTable"), "Input", "RankTable", "MaxSequenceLen"); context->SetOutputDim("Out", {1}); } }; @@ -78,8 +78,8 @@ class MaxSeqenceLenInferShape : public framework::InferShapeBase { REGISTER_OPERATOR( max_sequence_len, - paddle::operators::MaxSeqenceLenOp, - paddle::operators::MaxSeqenceLenOpProtoMaker, - paddle::operators::MaxSeqenceLenInferShape, + paddle::operators::MaxSequenceLenOp, + paddle::operators::MaxSequenceLenOpProtoMaker, + paddle::operators::MaxSequenceLenInferShape, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/metrics/unity_build_rule.cmake b/paddle/fluid/operators/metrics/unity_build_rule.cmake index 58acbc3b1e62f..dee8680cc93d3 100644 --- a/paddle/fluid/operators/metrics/unity_build_rule.cmake +++ b/paddle/fluid/operators/metrics/unity_build_rule.cmake @@ -4,5 +4,4 @@ # 
Generally, the combination rules in this file do not need to be modified. # If there are some redefined error in compiling with the source file which # in combination rule, you can remove the source file from the following rules. -register_unity_group(cc accuracy_op.cc auc_op.cc precision_recall_op.cc) -register_unity_group(cu accuracy_op.cu auc_op.cu) +register_unity_group(cc precision_recall_op.cc) diff --git a/paddle/fluid/operators/nccl/nccl_gpu_common.h b/paddle/fluid/operators/nccl/nccl_gpu_common.h index 01905d8ca84b3..8d1478c123383 100644 --- a/paddle/fluid/operators/nccl/nccl_gpu_common.h +++ b/paddle/fluid/operators/nccl/nccl_gpu_common.h @@ -35,7 +35,8 @@ namespace paddle { namespace platform { constexpr int kInvalidGPUId = -1; -struct Communicator { +class Communicator { + public: Communicator() {} int GetCommId(int device_id) const; diff --git a/paddle/fluid/operators/nccl/nccl_op.cc b/paddle/fluid/operators/nccl/nccl_op.cc index 8b06aa653c070..c5a1097e2f157 100644 --- a/paddle/fluid/operators/nccl/nccl_op.cc +++ b/paddle/fluid/operators/nccl/nccl_op.cc @@ -18,7 +18,7 @@ limitations under the License. */ namespace paddle { namespace operators { -static constexpr char kParallelScopes[] = "parallel_scopes"; +static constexpr char kParallelScopes[] = "parallel_scopes"; // NOLINT // NCCLinitOp class NCCLInitOp : public framework::OperatorBase { diff --git a/paddle/fluid/operators/nce_op.cc b/paddle/fluid/operators/nce_op.cc index f4320cd0b6796..1b622b7571667 100644 --- a/paddle/fluid/operators/nce_op.cc +++ b/paddle/fluid/operators/nce_op.cc @@ -149,19 +149,19 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { AddInput( "CustomDistProbs", - "(Tensor) It is used in 'CostumDist' sampler. " + "(Tensor) It is used in 'CustomDist' sampler. " "It is a tensor with shape [num_total_classes]." "The i-th element is the probability of the i-th class being sampled.") .AsDispensable(); AddInput( "CustomDistAlias", - "(Tensor) It is used in 'CostumDist' sampler. " + "(Tensor) It is used in 'CustomDist' sampler. " "It is a tensor with shape [num_total_classes]." "The i-th element is the probability of the i-th class being sampled.") .AsDispensable(); AddInput( "CustomDistAliasProbs", - "(Tensor) It is used in 'CostumDist' sampler. " + "(Tensor) It is used in 'CustomDist' sampler. " "It is a tensor with shape [num_total_classes]." "The i-th element is the probability of the i-th class being sampled.") .AsDispensable(); @@ -194,7 +194,7 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(10); AddAttr("sampler", "(int) Which sampler to be used to sample negative class." - "0: Uniform; 1: LogUniform; 2: CostumDist.") + "0: Uniform; 1: LogUniform; 2: CustomDist.") .SetDefault(0); AddAttr("seed", "(int) The seed used in sampler. If it is 0, " diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index a21c7c816e191..41262dca6e53c 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -146,7 +146,7 @@ class NCEKernel : public framework::OpKernel { default: { PADDLE_THROW(platform::errors::InvalidArgument( "Unsupported SamplerType. SamplerType should be 0: Uniform, " - "1: LogUniform or 2: CostumDist. Received SamplerType: %d", + "1: LogUniform or 2: CustomDist. Received SamplerType: %d", sampler_type)); } } @@ -332,7 +332,7 @@ class NCEGradKernel : public framework::OpKernel { default: { PADDLE_THROW(platform::errors::InvalidArgument( "Unsupported SamplerType. 
SamplerType should be 0: Uniform, " - "1: LogUniform or 2: CostumDist. Received SamplerType: %d", + "1: LogUniform or 2: CustomDist. Received SamplerType: %d", sampler_type)); } } diff --git a/paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc b/paddle/fluid/operators/onednn/interpolate_onednn_op.cc similarity index 100% rename from paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc rename to paddle/fluid/operators/onednn/interpolate_onednn_op.cc diff --git a/paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc b/paddle/fluid/operators/onednn/lrn_onednn_op.cc similarity index 100% rename from paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc rename to paddle/fluid/operators/onednn/lrn_onednn_op.cc diff --git a/paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc b/paddle/fluid/operators/onednn/matmul_onednn_op.cc similarity index 100% rename from paddle/fluid/operators/mkldnn/matmul_mkldnn_op.cc rename to paddle/fluid/operators/onednn/matmul_onednn_op.cc diff --git a/paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc b/paddle/fluid/operators/onednn/quantize_onednn_op.cc similarity index 100% rename from paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc rename to paddle/fluid/operators/onednn/quantize_onednn_op.cc diff --git a/paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc b/paddle/fluid/operators/onednn/requantize_onednn_op.cc similarity index 100% rename from paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc rename to paddle/fluid/operators/onednn/requantize_onednn_op.cc diff --git a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc b/paddle/fluid/operators/onednn/reshape_onednn_op.cc similarity index 99% rename from paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc rename to paddle/fluid/operators/onednn/reshape_onednn_op.cc index 1e3b29da11e5b..8632160b04ae0 100644 --- a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc +++ b/paddle/fluid/operators/onednn/reshape_onednn_op.cc @@ -185,7 +185,7 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { "be -1. But received shape = [%s], shape[%d] is also -1.", common::make_ddim(shape), i)); - unk_dim_idx = i; + unk_dim_idx = static_cast(i); } else if (shape[i] == copy_dim_val) { PADDLE_ENFORCE_LT( static_cast(i), @@ -212,9 +212,9 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { shape[i])); } - capacity *= (shape[i] ? shape[i] : in_dims[i]); + capacity *= (shape[i] ? shape[i] : in_dims[i]); // NOLINT output_shape[i] = - (shape[i] ? static_cast(shape[i]) : in_dims[i]); + (shape[i] ? 
static_cast(shape[i]) : in_dims[i]); // NOLINT } if (unk_dim_idx != -1) { diff --git a/paddle/fluid/operators/mkldnn/shuffle_channel_mkldnn_op.cc b/paddle/fluid/operators/onednn/shuffle_channel_onednn_op.cc similarity index 100% rename from paddle/fluid/operators/mkldnn/shuffle_channel_mkldnn_op.cc rename to paddle/fluid/operators/onednn/shuffle_channel_onednn_op.cc diff --git a/paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc b/paddle/fluid/operators/onednn/transpose_onednn_op.cc similarity index 100% rename from paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc rename to paddle/fluid/operators/onednn/transpose_onednn_op.cc diff --git a/paddle/fluid/operators/ops_signature/elementwise_sig.cc b/paddle/fluid/operators/ops_signature/elementwise_sig.cc index b1150268fbad1..82f891bb48a00 100644 --- a/paddle/fluid/operators/ops_signature/elementwise_sig.cc +++ b/paddle/fluid/operators/ops_signature/elementwise_sig.cc @@ -168,7 +168,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping( KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( const ArgumentMappingContext& ctx UNUSED) { return KernelSignature("divide_double_grad", - {"Y", "Out", "DX", "DDX", "DDY"}, + {"Y", "Out", "Out@GRAD", "DX", "DDX", "DDY"}, {"axis"}, {"Y@GRAD", "DOut", "DDOut"}); } diff --git a/paddle/fluid/operators/ops_signature/fused_multi_transformer_sig.cc b/paddle/fluid/operators/ops_signature/fused_multi_transformer_sig.cc new file mode 100644 index 0000000000000..184df326b79e8 --- /dev/null +++ b/paddle/fluid/operators/ops_signature/fused_multi_transformer_sig.cc @@ -0,0 +1,58 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature FusedMultiTransformerOpArgumentMapping( + const ArgumentMappingContext& ctx UNUSED) { + return KernelSignature("fused_multi_transformer", + { + "X", + "LnScale", + "LnBias", + "QKVW", + "QKVBias", + "CacheKV", + "PreCaches", + "RotaryPosEmb", + "TimeStep", + "SeqLengths", + "SrcMask", + "OutLinearW", + "OutLinearBias", + "FFNLnScale", + "FFNLnBias", + "FFN1Weight", + "FFN1Bias", + "FFN2Weight", + "FFN2Bias", + }, + {"pre_layer_norm", + "epsilon", + "dropout_rate", + "rotary_emb_dims", + "is_test", + "dropout_implementation", + "act_method", + "trans_qkvw", + "ring_id"}, + {"CacheKVOut", "Out"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(fused_multi_transformer, + phi::FusedMultiTransformerOpArgumentMapping); diff --git a/paddle/fluid/operators/pad_op.cc b/paddle/fluid/operators/pad_op.cc index e2a0b3e025381..1a0f7b317d288 100644 --- a/paddle/fluid/operators/pad_op.cc +++ b/paddle/fluid/operators/pad_op.cc @@ -146,7 +146,7 @@ class PadCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { std::vector paddings = static_cast>(this->Attr>("paddings")); float pad_value = static_cast(this->Attr("pad_value")); - VLOG(6) << "Runing add_grad composite func"; + VLOG(6) << "Running add_grad composite func"; prim::pad_grad(x, out_grad, paddings, pad_value, dx_ptr); this->RecoverOutputName(x_grad, dx_name); diff --git a/paddle/fluid/operators/prim_ops/unity_build_rule.cmake b/paddle/fluid/operators/prim_ops/unity_build_rule.cmake index 74b04d234fcde..73340d33c1091 100644 --- a/paddle/fluid/operators/prim_ops/unity_build_rule.cmake +++ b/paddle/fluid/operators/prim_ops/unity_build_rule.cmake @@ -2,7 +2,6 @@ register_unity_group( cc reshape_p_op.cc broadcast_p_op.cc - reduce_p_op.cc transpose_p_op.cc split_p_op.cc concat_p_op.cc diff --git a/paddle/fluid/operators/pull_box_extended_sparse_op.h b/paddle/fluid/operators/pull_box_extended_sparse_op.h index b9508a279505e..76e570f10fb64 100644 --- a/paddle/fluid/operators/pull_box_extended_sparse_op.h +++ b/paddle/fluid/operators/pull_box_extended_sparse_op.h @@ -86,7 +86,7 @@ static void PushBoxExtendedSparseFunctor( cur_batch_size, platform::errors::PreconditionNotMet( "The batch size of all input slots should be same," - "please cheack")); + "please check")); } const float *grad_value = d_output[i]->data(); const float *grad_value_extend = d_output_extend[i]->data(); diff --git a/paddle/fluid/operators/pull_gpups_sparse_op.h b/paddle/fluid/operators/pull_gpups_sparse_op.h index d8fdadd99cbd4..e5e08cfdde685 100644 --- a/paddle/fluid/operators/pull_gpups_sparse_op.h +++ b/paddle/fluid/operators/pull_gpups_sparse_op.h @@ -30,7 +30,7 @@ static void PullGpuPSSparseFunctor(const framework::ExecutionContext &ctx) { auto embedding_size_vec = ctx.Attr>("size"); const auto slot_size = inputs.size(); std::vector all_keys(slot_size); - // GpuPSPS only supports float now + // GpuPS only supports float now std::vector all_values(slot_size); std::vector slot_lengths(slot_size); for (size_t i = 0; i < slot_size; i++) { @@ -80,7 +80,7 @@ static void PushGpuPSSparseFunctor(const framework::ExecutionContext &ctx) { cur_batch_size, platform::errors::PreconditionNotMet( "The batch size of all input slots should be same, " - "please cheack")); + "please check")); } const float *grad_value = d_output[i]->data(); all_grad_values[i] = grad_value; diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index ecdded21bb3e6..7d9c8ceca4943 
100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -119,7 +119,7 @@ static void CallPythonFunc(py::object *callable, out->ShareDataWith(*py_out_tensor); } catch (py::cast_error &) { PADDLE_THROW(platform::errors::InvalidArgument( - "py::cast to phi::DenseTensor error. The %d-th output expection is " + "py::cast to phi::DenseTensor error. The %d-th output exception is " "phi::DenseTensor", i)); } diff --git a/paddle/fluid/operators/pyramid_hash_op.cc b/paddle/fluid/operators/pyramid_hash_op.cc index 45373070d95f9..f5a8fcaa9de0c 100644 --- a/paddle/fluid/operators/pyramid_hash_op.cc +++ b/paddle/fluid/operators/pyramid_hash_op.cc @@ -354,8 +354,7 @@ class CPUPyramidHashOPKernel : public framework::OpKernel { ilayer + 1)) { if (_is_training != 0) { unsigned int rand_val = rand_r(&_seed); - float rate = - static_cast(rand_val) / (RAND_MAX); // NOLINT + double rate = static_cast(rand_val) / (RAND_MAX); *(iter_end++) = (rate < _drop_out_percent ? 0 : 1); } else { *(iter_end++) = 1; diff --git a/paddle/fluid/operators/random_routing_op.cc b/paddle/fluid/operators/random_routing_op.cc index 9eaa3a664877c..dffcc9c361a66 100644 --- a/paddle/fluid/operators/random_routing_op.cc +++ b/paddle/fluid/operators/random_routing_op.cc @@ -22,7 +22,7 @@ class RandomRoutingOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("Prob"), "Input", "Porb", "RandomRouting"); + OP_INOUT_CHECK(ctx->HasInput("Prob"), "Input", "Prob", "RandomRouting"); OP_INOUT_CHECK( ctx->HasInput("TopK_Value"), "Input", "TopKValue", "RandomRouting"); OP_INOUT_CHECK( diff --git a/paddle/fluid/operators/randperm_op.h b/paddle/fluid/operators/randperm_op.h index 96981a4728402..560fdeb42eaa3 100644 --- a/paddle/fluid/operators/randperm_op.h +++ b/paddle/fluid/operators/randperm_op.h @@ -29,7 +29,7 @@ namespace paddle { namespace operators { template -static inline void random_permate(T* data_ptr, int num, unsigned int seed) { +static inline void random_permute(T* data_ptr, int num, unsigned int seed) { auto engine = phi::GetCPURandomEngine(seed); for (int i = 0; i < num; ++i) { data_ptr[i] = static_cast(i); @@ -50,13 +50,13 @@ class RandpermKernel : public framework::OpKernel { if (platform::is_cpu_place(ctx.GetPlace())) { T* out_data = out_tensor->mutable_data(platform::CPUPlace()); - random_permate(out_data, n, seed); + random_permute(out_data, n, seed); } else { phi::DenseTensor tmp_tensor; tmp_tensor.Resize(common::make_ddim({n})); T* tmp_data = tmp_tensor.mutable_data(platform::CPUPlace()); - random_permate(tmp_data, n, seed); + random_permute(tmp_data, n, seed); framework::TensorCopy(tmp_tensor, ctx.GetPlace(), out_tensor); } } diff --git a/paddle/fluid/operators/read_file_op.cc b/paddle/fluid/operators/read_file_op.cc index c19d0a6344ce5..a65b51d24e245 100644 --- a/paddle/fluid/operators/read_file_op.cc +++ b/paddle/fluid/operators/read_file_op.cc @@ -46,7 +46,7 @@ class ReadFileOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( This operator read a file. 
)DOC"); - AddAttr("filename", "Path of the file to be readed.") + AddAttr("filename", "Path of the file to be read.") .SetDefault({}); } }; diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc index b73ffe4319be7..cc5034c86f90f 100644 --- a/paddle/fluid/operators/reader/buffered_reader.cc +++ b/paddle/fluid/operators/reader/buffered_reader.cc @@ -380,7 +380,7 @@ void BufferedReader::ReadNextImpl(paddle::framework::LoDTensorArray *out) { return; } - if (platform::is_gpu_place(place_)) { + if (platform::is_gpu_place(place_)) { // NOLINT *out = std::move(cuda_buffer_[i]); } else if (platform::is_xpu_place(place_)) { *out = std::move(xpu_buffer_[i]); diff --git a/paddle/fluid/operators/reduce_ops/unity_build_rule.cmake b/paddle/fluid/operators/reduce_ops/unity_build_rule.cmake index 839bb1ac7306c..da67c2c8d8b01 100644 --- a/paddle/fluid/operators/reduce_ops/unity_build_rule.cmake +++ b/paddle/fluid/operators/reduce_ops/unity_build_rule.cmake @@ -4,8 +4,7 @@ # Generally, the combination rules in this file do not need to be modified. # If there are some redefined error in compiling with the source file which # in combination rule, you can remove the source file from the following rules. -register_unity_group(cc reduce_all_op.cc reduce_any_op.cc) -register_unity_group(cu reduce_all_op.cu reduce_any_op.cu) + # The following groups are to make better use of `/MP` which MSVC's parallel # compilation instruction when compiling in Unity Build. register_unity_group(cu frobenius_norm_op.cu) diff --git a/paddle/fluid/operators/repeat_interleave_op.cc b/paddle/fluid/operators/repeat_interleave_op.cc index 15b4b80cb739b..d0af82510bdc4 100644 --- a/paddle/fluid/operators/repeat_interleave_op.cc +++ b/paddle/fluid/operators/repeat_interleave_op.cc @@ -77,7 +77,7 @@ class RepeatInterleaveOp : public framework::OperatorWithKernel { } else if (repeats > 0) { output_dim[dim] = input_dim[dim] * repeats; } - VLOG(3) << "infershap out " << output_dim[dim]; + VLOG(3) << "infershape out " << output_dim[dim]; ctx->SetOutputDim("Out", common::make_ddim(output_dim)); auto type = ctx->GetInputsVarType("X")[0]; if (type == framework::proto::VarType::LOD_TENSOR) { @@ -124,7 +124,7 @@ class RepeatInterleaveOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "(Tensor) the input tensor."); AddInput("RepeatsTensor", - "the 1-D tensor containing the repeats alongsize the axis.") + "the 1-D tensor containing the repeats alongside the axis.") .AsDispensable(); AddOutput("Out", "the output tensor."); AddAttr("Repeats", "the number of repetitions for each element.") diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 822eaf514bac5..34d80604ae8b0 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -581,7 +581,7 @@ class Reshape2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase { auto *dx_ptr = this->GetOutputPtr(&dx); std::string dx_name = this->GetOutputName(dx); - VLOG(6) << "Runing reshape2_grad composite func"; + VLOG(6) << "Running reshape2_grad composite func"; prim::reshape_grad(x, out_grad, dx_ptr); this->RecoverOutputName(dx, dx_name); } diff --git a/paddle/fluid/operators/run_program_op.h b/paddle/fluid/operators/run_program_op.h index 9e2d1fc4c97fb..6006d7556423c 100644 --- a/paddle/fluid/operators/run_program_op.h +++ b/paddle/fluid/operators/run_program_op.h @@ -34,7 +34,7 @@ limitations under the License. 
*/ #ifdef PADDLE_WITH_DNNL #include "paddle/fluid/platform/mkldnn_helper.h" #endif -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/operators/cuda_graph_with_in_out.h" #endif #include "paddle/common/flags.h" @@ -196,6 +196,20 @@ static cudaStreamCaptureMode StringToCUDAGraphCaptureMode( "Unsupported CUDA Graph capture mode %s", mode)); } } +#elif defined(PADDLE_WITH_HIP) +static hipStreamCaptureMode StringToCUDAGraphCaptureMode( + const std::string &mode) { + if (mode == "global") { + return hipStreamCaptureModeGlobal; + } else if (mode == "thread_local") { + return hipStreamCaptureModeThreadLocal; + } else if (mode == "relaxed") { + return hipStreamCaptureModeRelaxed; + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Unsupported CUDA Graph capture mode %s", mode)); + } +} #endif } // namespace details @@ -211,7 +225,7 @@ class RunProgramOpKernel : public framework::OpKernel { return; } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto mode = details::StringToCUDAGraphCaptureMode(capture_mode); PADDLE_ENFORCE_EQ( platform::is_gpu_place(ctx.GetPlace()), @@ -408,7 +422,7 @@ class RunProgramGradOpKernel : public framework::OpKernel { return; } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto mode = details::StringToCUDAGraphCaptureMode(capture_mode); PADDLE_ENFORCE_EQ( platform::is_gpu_place(ctx.GetPlace()), diff --git a/paddle/fluid/operators/save_combine_op.h b/paddle/fluid/operators/save_combine_op.h index 1888ce5b57493..f5c3fb9969f1e 100644 --- a/paddle/fluid/operators/save_combine_op.h +++ b/paddle/fluid/operators/save_combine_op.h @@ -30,7 +30,7 @@ limitations under the License. */ #include "paddle/fluid/framework/raw_tensor.h" #include "paddle/fluid/framework/string_array.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" #include "paddle/phi/core/dense_tensor.h" namespace paddle { diff --git a/paddle/fluid/operators/save_op.h b/paddle/fluid/operators/save_op.h index a27a2fe74c1dd..67f71f6e58559 100644 --- a/paddle/fluid/operators/save_op.h +++ b/paddle/fluid/operators/save_op.h @@ -106,14 +106,14 @@ class SaveOpKernel : public framework::OpKernel { auto place = ctx.GetPlace(); auto* input_var = ctx.InputVar("X"); - auto iname = ctx.InputNames("X").data(); + std::vector _iname = ctx.InputNames("X"); + auto iname = _iname.data(); PADDLE_ENFORCE_NOT_NULL( input_var, phi::errors::InvalidArgument( "The variable %s to be saved cannot be found.", iname)); auto filename = ctx.Attr("file_path"); - auto overwrite = ctx.Attr("overwrite"); auto save_as_fp16 = ctx.Attr("save_as_fp16"); VLOG(4) << "save output file_path: " << filename; diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index 1842ed34a5c67..ddda1131f5cc7 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -86,13 +86,13 @@ class SplitOp : public framework::OperatorWithKernel { Variable *var = PADDLE_GET_CONST(Variable *, section_varptr); sections_from_tensor.emplace_back(var->Get()); } - sections_final = std::move(phi::IntArray(sections_from_tensor)); + sections_final = phi::IntArray(sections_from_tensor); } else if (!ctx->IsRuntime() && ctx->HasInputs("SectionsTensorList")) { - sections_final = std::move(phi::IntArray(std::vector( - ctx->GetInputVarPtrs("SectionsTensorList").size(), -1))); + sections_final = 
phi::IntArray(std::vector( + ctx->GetInputVarPtrs("SectionsTensorList").size(), -1)); sections_final.SetFromTensor(true); } else { - sections_final = std::move(phi::IntArray(sections)); + sections_final = phi::IntArray(sections); } if (!sections.empty()) { if (ctx->IsRuntime()) { @@ -222,7 +222,7 @@ class SplitCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { "We don't support dynamic index or sections from tensor for split " "composite grad for now. ")); } else { - VLOG(6) << "Runing split_grad composite func"; + VLOG(6) << "Running split_grad composite func"; prim::split_grad(out_grad, axis, dx_ptr); this->RecoverOutputName(input_grad, dx_name); } diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 718f4876406af..d8b7e35d6d3a1 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -127,7 +127,7 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput( "X", - "A Varaible list. The shape and data type of the list elements" + "A Variable list. The shape and data type of the list elements" "should be consistent. Variable can be multi-dimensional Tensor" "or phi::DenseTensor, and data types can be: float32, float64, int32, " "int64.") diff --git a/paddle/fluid/operators/svd_helper.h b/paddle/fluid/operators/svd_helper.h index caa31565d4cf3..273e2c7b65100 100644 --- a/paddle/fluid/operators/svd_helper.h +++ b/paddle/fluid/operators/svd_helper.h @@ -271,7 +271,7 @@ struct DiagAndFillFunctor { template struct DeviceIndependenceTensorOperations { - // 1. Device indenpendence, for kernel reuse. + // 1. Device independence, for kernel reuse. // 2. Input and output is always tensor type. // 3. output phi::DenseTensor is alway allocated // 4. 
Basic phi::DenseTensor operator is supported @@ -315,7 +315,7 @@ struct DeviceIndependenceTensorOperations { } phi::DenseTensor Transpose(const phi::DenseTensor& x) { - // transpose the last two dimision + // transpose the last two dimension phi::DenseTensor ret; auto x_dim = x.dims(); auto x_vec = common::vectorize(x_dim); @@ -745,7 +745,7 @@ struct DeviceIndependenceTensorOperations { const framework::AttributeMap& attrs, std::vector out_shape, NameOutTensor out_str = {"Out"}) { - // varialble set dims must be phi::DenseTensor / SelectedRowTensor + // variable set dims must be phi::DenseTensor / SelectedRowTensor framework::Scope& local_scope = context.scope().NewScope(); framework::VariableNameMap op_outputs; for (auto out_name : out_str) { @@ -753,7 +753,7 @@ struct DeviceIndependenceTensorOperations { op_outputs[out_name].emplace_back("tmp_" + out_name); } auto out_var = local_scope.Var("tmp_Out"); // return the Out - // create Out phi::DenseTensor and allocat memory + // create Out phi::DenseTensor and allocate memory out_var->GetMutable()->mutable_data( common::make_ddim(out_shape), context.GetPlace()); // common::make_ddim(out_shape) diff --git a/paddle/fluid/operators/tdm_sampler_op.h b/paddle/fluid/operators/tdm_sampler_op.h index ec5587c330fc7..52f86d633307b 100644 --- a/paddle/fluid/operators/tdm_sampler_op.h +++ b/paddle/fluid/operators/tdm_sampler_op.h @@ -214,9 +214,9 @@ void TDMSamplerInner(const framework::ExecutionContext &context, label_vec[i * sample_res_length + offset] = 0; mask_vec[i * sample_res_length + offset] = 1; VLOG(3) << "TDM: node id: " << travel_data[start_offset + layer_idx] - << " Res append negitive " + << " Res append negative " << output_vec[i * sample_res_length + offset] - << " Label append negitive " + << " Label append negative " << label_vec[i * sample_res_length + offset] << " Mask append value " << mask_vec[i * sample_res_length + offset]; diff --git a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc index ad54a49f820f9..332008894d5b9 100644 --- a/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc +++ b/paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc @@ -173,7 +173,7 @@ class TeacherStudentSigmoidLossGradientOp platform::errors::InvalidArgument( "When Attr(soft_label) == false, the 2nd dimension of " "Input(Label) should be 1. But received Input(Label)'s 2nd " - "dimemsion " + "dimension " "is [%d]", label_dims[1])); } diff --git a/paddle/fluid/operators/tile_op.cc b/paddle/fluid/operators/tile_op.cc index 26657ce42f303..9d961bbd57122 100644 --- a/paddle/fluid/operators/tile_op.cc +++ b/paddle/fluid/operators/tile_op.cc @@ -185,7 +185,7 @@ class TileCompositeGradOpMaker : public prim::CompositeGradOpMakerBase { "We don't support RepeatTimes from tensor or repeat_times_tensor for " "tile composite grad for now. 
")); } else { - VLOG(6) << "Runing tile_grad composite func"; + VLOG(6) << "Running tile_grad composite func"; prim::tile_grad( x, out_grad, paddle::experimental::IntArray(repeat_times), dx_ptr); this->RecoverOutputName(x_grad, dx_name); diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index ef6172b6965f2..003f670133e45 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -93,7 +93,7 @@ class TopkOpCUDAKernel : public framework::OpKernel { if ((input_width <= 1024 || k >= 128 || k == input_width)) { if (phi::funcs::SortTopk( dev_ctx, input, input_width, input_height, k, output, indices)) { - // Successed, return. + // Succeed, return. return; } else { LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use " diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h index f8fa53e2ad505..b0d30f1d22d3b 100644 --- a/paddle/fluid/operators/top_k_op.h +++ b/paddle/fluid/operators/top_k_op.h @@ -46,7 +46,7 @@ class TopkKernel : public framework::OpKernel { T* output_data = output->mutable_data(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); - // reshape input to a flattern matrix(like flat_inner_dims) + // reshape input to a flatten matrix(like flat_inner_dims) framework::DDim inputdims = input->dims(); const size_t row = common::product(common::slice_ddim(inputdims, 0, inputdims.size() - 1)); diff --git a/paddle/fluid/operators/top_k_op_xpu.cc b/paddle/fluid/operators/top_k_op_xpu.cc index 55d3fa8624a8c..fff713236e9a6 100644 --- a/paddle/fluid/operators/top_k_op_xpu.cc +++ b/paddle/fluid/operators/top_k_op_xpu.cc @@ -60,7 +60,7 @@ class TopkXPUKernel : public framework::OpKernel { int* indices_int_data = RAII_GUARD.alloc_l3_or_gm(indices->numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(indices_int_data); - // reshape input to a flattern matrix(like flat_inner_dims) + // reshape input to a flatten matrix(like flat_inner_dims) framework::DDim inputdims = input->dims(); const size_t row = common::product(common::slice_ddim(inputdims, 0, inputdims.size() - 1)); diff --git a/paddle/fluid/operators/transfer_layout_op.h b/paddle/fluid/operators/transfer_layout_op.h index 52633640fa95b..2736171626121 100644 --- a/paddle/fluid/operators/transfer_layout_op.h +++ b/paddle/fluid/operators/transfer_layout_op.h @@ -110,7 +110,7 @@ class TransferLayoutFunctor { } VLOG(4) << "TransDataLayoutFromOneDNN: " << in_layout << "->" << target_layout; - // Case2 - transfrom from ONEDNN OPKernel to Non-ONEDNN OPKernel + // Case2 - transform from ONEDNN OPKernel to Non-ONEDNN OPKernel // Do transform via ONEDNN lib phi::funcs::TransDataLayoutFromOneDNN(in_layout, target_layout, @@ -119,11 +119,11 @@ class TransferLayoutFunctor { dev_ctx_.GetPlace()); } } else { - // Case3 - transfrom between Non-ONEDNN OPKernels + // Case3 - transform between Non-ONEDNN OPKernels TransDataLayout(dev_ctx_, in_tensor, &out_tensor); } #else - // Case3 - transfrom between Non-ONEDNN OPKernels + // Case3 - transform between Non-ONEDNN OPKernels TransDataLayout(dev_ctx_, in_tensor, &out_tensor); #endif framework::SetTensorToVariable(*in_, out_tensor, out_); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 417299d24db07..340728a1b8d1e 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -202,7 +202,7 @@ class Transpose2CompositeGradOpMaker : public prim::CompositeGradOpMakerBase { std::string dx_name = 
this->GetOutputName(dx); std::vector axis = static_cast>(this->Attr>("axis")); - VLOG(6) << "Runing transpose2_grad composite func"; + VLOG(6) << "Running transpose2_grad composite func"; prim::transpose_grad(out_grad, axis, dx_ptr); this->RecoverOutputName(dx, dx_name); } diff --git a/paddle/fluid/operators/unity_build_rule.cmake b/paddle/fluid/operators/unity_build_rule.cmake index 07136f7bd4f31..4409056108e62 100644 --- a/paddle/fluid/operators/unity_build_rule.cmake +++ b/paddle/fluid/operators/unity_build_rule.cmake @@ -29,22 +29,22 @@ register_unity_group( bmm_op.cc bpr_loss_op.cc cast_op.cc - mkldnn/cast_mkldnn_op.cc + onednn/cast_onednn_op.cc cholesky_op.cc chunk_eval_op.cc clip_by_norm_op.cc clip_op.cc coalesce_tensor_op.cc - mkldnn/activation_mkldnn_op.cc - mkldnn/interpolate_mkldnn_op.cc - mkldnn/pool_mkldnn_op.cc - mkldnn/softmax_mkldnn_op.cc) + onednn/activation_onednn_op.cc + onednn/interpolate_onednn_op.cc + onednn/pool_onednn_op.cc + onednn/softmax_onednn_op.cc) register_unity_group( cc center_loss_op.cc - mkldnn/concat_mkldnn_op.cc - mkldnn/conv_mkldnn_op.cc - mkldnn/conv_transpose_mkldnn_op.cc + onednn/concat_onednn_op.cc + onednn/conv_onednn_op.cc + onednn/conv_transpose_onednn_op.cc correlation_op.cc cos_sim_op.cc crf_decoding_op.cc @@ -69,7 +69,7 @@ register_unity_group( delete_var_op.cc dequantize_abs_max_op.cc dequantize_op.cc - mkldnn/dequantize_mkldnn_op.cc) + onednn/dequantize_onednn_op.cc) register_unity_group( cc dequeue_op.cc @@ -92,7 +92,7 @@ register_unity_group( expand_v2_op.cc fake_dequantize_op.cc fc_op.cc - mkldnn/fc_mkldnn_op.cc + onednn/fc_onednn_op.cc fill_any_like_op.cc fill_constant_batch_size_like_op.cc fill_constant_op.cc @@ -105,7 +105,7 @@ register_unity_group( gather_nd_op.cc gather_tree_op.cc gaussian_random_batch_size_like_op.cc - mkldnn/gaussian_random_mkldnn_op.cc + onednn/gaussian_random_onednn_op.cc group_norm_op.cc gru_op.cc) register_unity_group( @@ -143,7 +143,7 @@ register_unity_group( log_softmax_op.cc lookup_table_dequant_op.cc lrn_op.cc - mkldnn/lrn_mkldnn_op.cc + onednn/lrn_onednn_op.cc lstm_unit_op.cc) register_unity_group( cc @@ -152,7 +152,7 @@ register_unity_group( masked_select_op.cc match_matrix_tensor_op.cc matmul_op.cc - mkldnn/matmul_mkldnn_op.cc + onednn/matmul_onednn_op.cc max_sequence_len_op.cc maxout_op.cc merge_lod_tensor_op.cc @@ -204,7 +204,7 @@ register_unity_group( cc push_dense_op.cc quantize_op.cc - mkldnn/quantize_mkldnn_op.cc + onednn/quantize_onednn_op.cc queue_generator_op.cc range_op.cc rank_attention_op.cc @@ -212,7 +212,7 @@ register_unity_group( recurrent_op.cc reorder_lod_tensor_by_rank_op.cc requantize_op.cc - mkldnn/requantize_mkldnn_op.cc + onednn/requantize_onednn_op.cc reshape_op.cc reverse_op.cc) register_unity_group( @@ -224,7 +224,7 @@ register_unity_group( save_combine_op.cc save_op.cc scale_op.cc - mkldnn/scale_mkldnn_op.cc + onednn/scale_onednn_op.cc scatter_nd_add_op.cc scatter_op.cc seed_op.cc @@ -256,7 +256,7 @@ register_unity_group( stack_op.cc strided_slice_op.cc sum_op.cc - mkldnn/sum_mkldnn_op.cc + onednn/sum_onednn_op.cc tdm_child_op.cc tdm_sampler_op.cc teacher_student_sigmoid_loss_op.cc @@ -269,7 +269,7 @@ register_unity_group( top_k_v2_op.cc trace_op.cc transpose_op.cc - mkldnn/transpose_mkldnn_op.cc + onednn/transpose_onednn_op.cc unbind_op.cc unfold_op.cc) register_unity_group( diff --git a/paddle/fluid/pir/CMakeLists.txt b/paddle/fluid/pir/CMakeLists.txt index 24f5e2892de8e..9e883ef21af9a 100644 --- a/paddle/fluid/pir/CMakeLists.txt +++ b/paddle/fluid/pir/CMakeLists.txt @@ -1,3 
+1,4 @@ add_subdirectory(dialect) add_subdirectory(transforms) add_subdirectory(drr) +add_subdirectory(utils) diff --git a/paddle/fluid/pir/dialect/CMakeLists.txt b/paddle/fluid/pir/dialect/CMakeLists.txt index 2955a6d57afb5..59db81550bb8b 100644 --- a/paddle/fluid/pir/dialect/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/CMakeLists.txt @@ -95,7 +95,8 @@ execute_process( --op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace} --dialect_name ${dialect_name} --op_def_h_file ${op_header_file_tmp} --op_info_file ${op_info_file_tmp} --op_def_cc_file ${op_src_files_tmp} - --op_vjp_cc_file ${op_vjp_src_file_tmp}) + --op_vjp_cc_file ${op_vjp_src_file_tmp} --with_distributed + ${WITH_DISTRIBUTE}) set(generated_files_pd_op "${op_header_file}" @@ -141,7 +142,7 @@ if(WITH_MKLDNN) --op_def_h_file ${onednn_op_header_file_tmp} --op_info_file ${op_onednn_info_file_tmp} --op_def_cc_file ${onednn_op_source_file_tmp} --onednn_yaml_file ${pir_op_onednn_yaml} --ops_onednn_extra_yaml_file - ${pd_ops_onednn_extra_yaml_file}) + ${pd_ops_onednn_extra_yaml_file} --with_distributed ${WITH_DISTRIBUTE}) set(generated_files_onednn_pd_op "${onednn_op_header_file}" "${onednn_op_source_file}" @@ -255,7 +256,17 @@ if(WITH_MKLDNN) ${CMAKE_CURRENT_SOURCE_DIR}/operator/ir/manual_onednn_op.cc) endif() +file(GLOB_RECURSE dist_dialect_srcs + "${CMAKE_CURRENT_SOURCE_DIR}/distributed/ir/*.cc" + "${CMAKE_CURRENT_SOURCE_DIR}/distributed/transforms/*.cc") + +# if(WITH_DISTRIBUTE) FIXME in next PR +set(op_dialect_srcs ${op_dialect_srcs} ${dist_dialect_srcs}) +# endif() set(op_dialect_deps phi common pir type_info string_helper) +if(WITH_ROCM) + set(op_dialect_deps ${op_dialect_deps} global_utils) +endif() cc_library( op_dialect diff --git a/paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h b/paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h new file mode 100644 index 0000000000000..66fd9fd5a9d26 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h @@ -0,0 +1,170 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/common/ddim.h" +#include "paddle/common/hash_funcs.h" +#include "paddle/common/layout.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/phi/common/reduce_type.h" +#include "paddle/pir/include/core/attribute_base.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/core/utils.h" +#include "paddle/utils/flat_hash_map.h" + +namespace paddle { +namespace dialect { + +class ProcessMeshAttrStorage : public pir::AttributeStorage { + public: + /// + /// \brief Declare ParamKey according to parameter type. 
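// (Editorial note) All three attribute storages in this header follow the same
// pir uniquing contract: StorageManager hashes a ParamKey via HashValue(),
// searches for an existing storage whose operator==(key) holds, and only calls
// Construct() on a miss. Two ProcessMeshAttribute::get() calls made with an
// identical ProcessMesh are therefore expected to yield the same interned
// attribute, which keeps attribute comparison in later passes a cheap identity
// check.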
+ /// + using ParamKey = phi::distributed::ProcessMesh; + + ProcessMeshAttrStorage(ParamKey&& process_mesh) // NOLINT + : process_mesh(std::move(process_mesh)) {} + + /// + /// \brief Each derived TypeStorage must define a Construct method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static ProcessMeshAttrStorage* Construct(ParamKey&& key) { + return new ProcessMeshAttrStorage(std::move(key)); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. + /// + static std::size_t HashValue(const ParamKey& key) { return key.hash(); } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey& key) const { + return process_mesh == key && process_mesh.dim_names() == key.dim_names(); + } + + ParamKey process_mesh; +}; + +class TensorDistAttrStorage : public pir::AttributeStorage { + public: + /// + /// \brief Declare ParamKey according to parameter type. + /// + using ParamKey = std::tuple, + flat_hash_map>; + + TensorDistAttrStorage(ParamKey&& param) // NOLINT + : mesh_attr(std::get<0>(param)), + dims_mapping(std::move(std::get<1>(param))), + partial_status(std::move(std::get<2>(param))) {} + /// + /// \brief Each derived TypeStorage must define a Construct method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static TensorDistAttrStorage* Construct(ParamKey&& key) { + return new TensorDistAttrStorage(std::move(key)); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. + /// + static std::size_t HashValue(const ParamKey& key) { + auto mesh_hash = std::get<0>(key).hash(); + auto dims_map_hash = std::hash>()(std::get<1>(key)); + std::string partial_status_str = "["; + for (auto& itr : std::get<2>(key)) { + partial_status_str += + "Partial(dims:" + std::to_string(itr.first) + ", " + + phi::ReduceTypeStrings[static_cast(itr.second)] + "), "; + } + partial_status_str += "]"; + auto combine_hash = pir::detail::hash_combine(mesh_hash, dims_map_hash); + return pir::detail::hash_combine( + combine_hash, std::hash()(partial_status_str)); + } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey& key) const { + return mesh_attr == std::get<0>(key) && dims_mapping == std::get<1>(key) && + partial_status == std::get<2>(key); + } + + ProcessMeshAttribute mesh_attr; + std::vector dims_mapping; + // partial map would less or equal than to mesh.size. + // iterate operation (copy and comparison) would more frequency than random + // element access. + flat_hash_map partial_status; +}; + +class OperationDistAttrStorage : public pir::AttributeStorage { + public: + /// + /// \brief Declare ParamKey according to parameter type. + /// + using ParamKey = std::tuple, + std::vector>; + OperationDistAttrStorage(ParamKey&& param) // NOLINT + : mesh_attr(std::get<0>(param)), + operand_dist_attrs(std::get<1>(param)), + result_dist_attrs(std::get<2>(param)) {} + + /// + /// \brief Each derived TypeStorage must define a Construct method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static OperationDistAttrStorage* Construct(ParamKey&& key) { + return new OperationDistAttrStorage(std::move(key)); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. 
+ /// + static std::size_t HashValue(const ParamKey& key) { + auto hash_value = std::hash()(std::get<0>(key)); + for (auto& iter : std::get<1>(key)) { + auto tmp_value = std::hash()(iter); + hash_value = pir::detail::hash_combine(hash_value, tmp_value); + } + for (auto& iter : std::get<2>(key)) { + auto tmp_value = std::hash()(iter); + hash_value = pir::detail::hash_combine(hash_value, tmp_value); + } + return hash_value; + } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey& key) const { + return mesh_attr == std::get<0>(key) && + operand_dist_attrs == std::get<1>(key) && + result_dist_attrs == std::get<2>(key); + } + + ProcessMeshAttribute mesh_attr; + std::vector operand_dist_attrs; + std::vector result_dist_attrs; +}; + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_api.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_api.cc new file mode 100644 index 0000000000000..3382fa18b9090 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_api.cc @@ -0,0 +1,70 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h" +#include +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/common/reduce_type.h" +#include "paddle/pir/include/core/builder.h" +#include "paddle/pir/include/core/operation_utils.h" +#include "paddle/pir/include/core/value.h" +#include "paddle/utils/flat_hash_map.h" + +namespace paddle { +namespace dialect { + +pir::Value shard_tensor(const pir::Value& x, + const phi::distributed::ProcessMesh& process_mesh, + const std::vector& dims_mapping) { + pir::IrContext* ctx = pir::IrContext::Instance(); + // support amp for shard_tensor in the future + paddle::flat_hash_map partial_status; + pir::AttributeMap attribute_map = { + {"tensor_dist_attr", + TensorDistAttribute::get( + ctx, process_mesh, dims_mapping, partial_status)}}; + + auto shard_tensor_op = + ApiBuilder::Instance().GetBuilder()->Build(x, + attribute_map); + return shard_tensor_op.out(); +} + +pir::Value reshard(const pir::Value& x, + const phi::distributed::ProcessMesh& process_mesh, + const std::vector& dims_mapping) { + pir::IrContext* ctx = pir::IrContext::Instance(); + // TODO(ywt01) get partial_status by func parameter + paddle::flat_hash_map partial_status; + TensorDistAttribute tensor_dist_attr = + TensorDistAttribute::get(ctx, process_mesh, dims_mapping, partial_status); + + auto reshard_op = ApiBuilder::Instance().GetBuilder()->Build( + x, tensor_dist_attr); + return reshard_op.result(0); +} + +pir::Value reshard(const pir::Value& x, + const TensorDistAttribute& tensor_dist_attr) { + auto reshard_op = 
ApiBuilder::Instance().GetBuilder()->Build( + x, tensor_dist_attr); + return reshard_op.result(0); +} + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_api.h b/paddle/fluid/pir/dialect/distributed/ir/dist_api.h new file mode 100644 index 0000000000000..18aa1bb32ca64 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_api.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" +#include "paddle/pir/include/core/value.h" + +namespace paddle { +namespace dialect { + +pir::Value shard_tensor(const pir::Value& x, + const phi::distributed::ProcessMesh& process_mesh, + const std::vector& dims_mapping); + +pir::Value reshard(const pir::Value& x, + const phi::distributed::ProcessMesh& process_mesh, + const std::vector& dims_mapping); + +pir::Value reshard(const pir::Value& x, + const TensorDistAttribute& tensor_dist_attr); + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc new file mode 100644 index 0000000000000..e36f678929dde --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.cc @@ -0,0 +1,129 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h" +#include "paddle/phi/core/enforce.h" +namespace paddle { +namespace dialect { +/// +/// \brief ProcessMeshAttribute interface. +/// +const phi::distributed::ProcessMesh& ProcessMeshAttribute::process_mesh() + const { + return storage()->process_mesh; +} +ProcessMeshAttribute ProcessMeshAttribute::get( + pir::IrContext* ctx, const phi::distributed::ProcessMesh& mesh) { + return Base::get(ctx, mesh); +} +ProcessMeshAttribute ProcessMeshAttribute::get( + pir::IrContext* ctx, + const std::vector& shape, + const std::vector& process_ids, + const std::vector& dim_names) { + return Base::get(ctx, shape, process_ids, dim_names); +} + +/// +/// \brief TensorDistAttribute interface. 
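// (Editorial note) Reading the TensorDistAttribute accessors below:
// dims_mapping() has one entry per tensor dimension; entry i names the mesh
// axis that dimension i is split across, or -1 when that dimension is
// replicated. partial_status() records the mesh axes on which the tensor holds
// a partial result together with the reduce type needed to finalize it. For
// example, on a [2, 3] mesh a dims_mapping of {0, -1} shards dim 0 over the
// mesh axis of size 2 and keeps dim 1 whole on every rank.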
+/// +ProcessMeshAttribute TensorDistAttribute::process_mesh_attr() const { + return storage()->mesh_attr; +} +const std::vector& TensorDistAttribute::dims_mapping() const { + return storage()->dims_mapping; +} + +std::set TensorDistAttribute::partial_dims() const { + auto& partial = partial_status(); + std::set keys; + for (auto& kv : partial) { + keys.emplace(kv.first); + } + return keys; +} + +const flat_hash_map& +TensorDistAttribute::partial_status() const { + return storage()->partial_status; +} + +TensorDistAttribute TensorDistAttribute::get( + pir::IrContext* ctx, + ProcessMeshAttribute mesh, + const std::vector& dims_mapping, + const flat_hash_map& partial_status) { + PADDLE_ENFORCE_NOT_NULL(mesh, + common::errors::PreconditionNotMet( + "Building tensor_dist_attr through a nullptr " + "mesh attribute is currently not supported.")); + return Base::get(ctx, mesh, dims_mapping, partial_status); +} + +/// +/// \brief OperationDistAttribute interface. +/// +ProcessMeshAttribute OperationDistAttribute::process_mesh_attr() const { + return storage()->mesh_attr; +} +const std::vector& +OperationDistAttribute::operand_dist_attrs() const { + return storage()->operand_dist_attrs; +} +TensorDistAttribute OperationDistAttribute::operand_dist_attr( + uint32_t index) const { + return operand_dist_attrs().at(index); +} +uint32_t OperationDistAttribute::num_operand_dist_attrs() const { + return operand_dist_attrs().size(); +} + +const std::vector& +OperationDistAttribute::result_dist_attrs() const { + return storage()->result_dist_attrs; +} +TensorDistAttribute OperationDistAttribute::result_dist_attr( + uint32_t index) const { + return result_dist_attrs().at(index); +} +uint32_t OperationDistAttribute::num_result_dist_attrs() const { + return result_dist_attrs().size(); +} +OperationDistAttribute OperationDistAttribute::get( + pir::IrContext* ctx, + ProcessMeshAttribute mesh, + const std::vector& operand_dist_attrs, + const std::vector& result_dist_attrs) { + for (const auto& iter : operand_dist_attrs) { + // NOTE: The operand dist attr maybe empty while the corresponding input is + // optional. + if (iter) { + PADDLE_ENFORCE_EQ(mesh, + iter.process_mesh_attr(), + common::errors::PreconditionNotMet( + "operand_dist_attrs element's mesh(%s) not equal " + "to input mesh(%s)", + iter.process_mesh_attr(), + mesh)); + } + } + return Base::get(ctx, mesh, operand_dist_attrs, result_dist_attrs); +} + +} // namespace dialect +} // namespace paddle +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OperationDistAttribute) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h b/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h new file mode 100644 index 0000000000000..2b2be781c9ca8 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h @@ -0,0 +1,133 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/reduce_type.h" +#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" +#include "paddle/pir/include/core/attribute.h" +#include "paddle/pir/include/core/builtin_attribute_storage.h" +#include "paddle/pir/include/core/utils.h" +#include "paddle/utils/flat_hash_map.h" + +namespace paddle { +namespace dialect { +class ProcessMeshAttrStorage; +class TensorDistAttrStorage; +class OperationDistAttrStorage; + +class ProcessMeshAttribute : public pir::AttrBase { + public: + using Base::Base; + const phi::distributed::ProcessMesh& process_mesh() const; + const std::vector& shape() const { return process_mesh().shape(); } + const std::vector& process_ids() const { + return process_mesh().process_ids(); + } + const std::vector& dim_names() const { + return process_mesh().dim_names(); + } + int64_t size() const { return process_mesh().size(); } + int64_t ndim() const { return process_mesh().ndim(); } + int64_t dim_size(int64_t dim) const { return process_mesh().dim_size(dim); } + int64_t dim_size(const std::string& dim_name) const { + return process_mesh().dim_size(dim_name); + } + bool empty() const { return process_mesh().empty(); } + bool contains(int64_t process_id) const { + return process_mesh().contains(process_id); + } + size_t hash() const { return process_mesh().hash(); } + + std::string to_string() const { return process_mesh().to_string(); } + + static ProcessMeshAttribute get(pir::IrContext* ctx, + const phi::distributed::ProcessMesh& mesh); + static ProcessMeshAttribute get(pir::IrContext* ctx, + const std::vector& shape, + const std::vector& process_ids, + const std::vector& dim_names); +}; + +class TensorDistAttribute : public pir::AttrBase { + public: + using Base::Base; + ProcessMeshAttribute process_mesh_attr() const; + const std::vector& dims_mapping() const; + + // return vector of mesh dims on which the this tensor is partial on + std::set partial_dims() const; + + const flat_hash_map& partial_status() const; + + static TensorDistAttribute get( + pir::IrContext* ctx, + ProcessMeshAttribute mesh, + const std::vector& dims_mapping, + const flat_hash_map& partial_status = {}); + static TensorDistAttribute get( + pir::IrContext* ctx, + const phi::distributed::ProcessMesh& mesh, + const std::vector& dims_mapping, + const flat_hash_map& partial_status = {}) { + return get(ctx, + ProcessMeshAttribute::get(ctx, mesh), + dims_mapping, + partial_status); + } +}; + +class OperationDistAttribute : public pir::AttrBase { + public: + using Base::Base; + ProcessMeshAttribute process_mesh_attr() const; + + const std::vector& operand_dist_attrs() const; + TensorDistAttribute operand_dist_attr(uint32_t index) const; + uint32_t num_operand_dist_attrs() const; + + const std::vector& result_dist_attrs() const; + TensorDistAttribute result_dist_attr(uint32_t index) const; + uint32_t num_result_dist_attrs() const; + + static OperationDistAttribute get( + pir::IrContext* ctx, + ProcessMeshAttribute mesh, + const std::vector& operand_dist_attrs, + const std::vector& result_dist_attrs); + + static OperationDistAttribute get( + pir::IrContext* ctx, + const phi::distributed::ProcessMesh& mesh, + const std::vector& operand_dist_attrs, + const std::vector& result_dist_attrs) { + return get(ctx, + ProcessMeshAttribute::get(ctx, mesh), + operand_dist_attrs, + result_dist_attrs); + } +}; + +} // namespace dialect +} // namespace paddle + 
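A minimal usage sketch for the attribute API declared above (illustrative only and not part of the patch; ctx is the usual IR context singleton):

    pir::IrContext* ctx = pir::IrContext::Instance();
    // A 2x2 mesh over ranks 0-3 with axes named "x" and "y".
    auto mesh = paddle::dialect::ProcessMeshAttribute::get(
        ctx, {2, 2}, {0, 1, 2, 3}, {"x", "y"});
    // Shard tensor dim 0 along mesh axis 0 ("x"); keep dim 1 replicated.
    auto dist = paddle::dialect::TensorDistAttribute::get(ctx, mesh, {0, -1});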
+IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ProcessMeshAttribute) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::TensorDistAttribute) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::OperationDistAttribute) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc new file mode 100644 index 0000000000000..0ea42bf6e093d --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.cc @@ -0,0 +1,182 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h" + +#include "paddle/fluid/pir/dialect/distributed/ir/attribute_storage.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_op.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" +#include "paddle/fluid/pir/dialect/distributed/ir/type_storage.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" + +REGISTER_FILE_SYMBOLS(dist_dialect); +namespace paddle { +namespace dialect { + +DistDialect::DistDialect(pir::IrContext *context) + : pir::Dialect(name(), context, pir::TypeId::get()) { + initialize(); +} + +void DistDialect::initialize() { + RegisterAttributes(); + RegisterTypes(); + RegisterOps(); +} + +void DistDialect::PrintType(pir::Type type, std::ostream &os) const { + if (auto dist_dense_tensor_type = type.dyn_cast()) { + // Todo: Design the dist dense tensor type print format. 
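// (Editorial note) With the interim format implemented below, a 64x32 float32
// tensor laid out on a [2, 2] mesh with dims_mapping [0, -1] prints roughly as
//   pd_dist.tensor<64x32xf32, mesh_shape:[2,2],dims_mappings:[0,-1]>
// (the exact dtype spelling comes from the builtin type printer).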
+ os << type.dialect().name(); + os << '.'; + if (auto tensor_type = type.dyn_cast()) { + os << "tensor<"; + for (auto d : common::vectorize(tensor_type.dims())) { + os << d; + os << "x"; + } + tensor_type.dtype().Print(os); + os << ", "; + PrintAttribute(dist_dense_tensor_type.tensor_dist_attr(), os); + os << ">"; + } + } else { + os << "error_type!"; + } +} + +void DistDialect::PrintAttribute(pir::Attribute attr, std::ostream &os) const { + if (auto process_mesh_attr = attr.dyn_cast()) { + os << "mesh_shape:[" + + phi::distributed::auto_parallel::str_join( + process_mesh_attr.shape()) + + "]"; + os << ",process_ids:[" + + phi::distributed::auto_parallel::str_join( + process_mesh_attr.process_ids()) + + "]"; + } else if (auto tensor_dist_attr = attr.dyn_cast()) { + os << "mesh_shape:[" + + phi::distributed::auto_parallel::str_join( + tensor_dist_attr.process_mesh_attr().shape()) + + "]"; + os << ",dims_mappings:[" + + phi::distributed::auto_parallel::str_join( + tensor_dist_attr.dims_mapping()) + + "]"; + if (tensor_dist_attr.partial_status().size() > 0) { + std::vector partial_status_strs; + for (auto &itr : tensor_dist_attr.partial_status()) { + std::string s = "partial(" + std::to_string(itr.first) + "," + + phi::ReduceTypeStrings[static_cast(itr.second)] + + ")"; + partial_status_strs.emplace_back(s); + } + os << ", " + << phi::distributed::auto_parallel::str_join(partial_status_strs); + } + } else if (auto op_dist_attr = attr.dyn_cast()) { + os << "{mesh:{shape:[" + + phi::distributed::auto_parallel::str_join( + op_dist_attr.process_mesh_attr().shape()) + + "]"; + os << ",process_ids:[" + + phi::distributed::auto_parallel::str_join( + op_dist_attr.process_mesh_attr().process_ids()) + + "]}"; + auto num_operand_dist_attrs = op_dist_attr.num_operand_dist_attrs(); + for (uint32_t i = 0; i < num_operand_dist_attrs; ++i) { + auto dist_attr = op_dist_attr.operand_dist_attr(i); + os << ",operand(" + std::to_string(i) + "):{"; + if (!dist_attr) { + os << "null}"; + continue; + } + if (dist_attr.process_mesh_attr() != op_dist_attr.process_mesh_attr()) { + os << "mesh_shape:[" + + phi::distributed::auto_parallel::str_join( + dist_attr.process_mesh_attr().shape()) + + "],"; + } + os << "dims_maping:[" + + phi::distributed::auto_parallel::str_join( + dist_attr.dims_mapping()) + + "]"; + if (dist_attr.partial_status().size() > 0) { + std::vector partial_status_strs; + for (auto &itr : dist_attr.partial_status()) { + std::string s = "partial(" + std::to_string(itr.first) + "," + + phi::ReduceTypeStrings[static_cast(itr.second)] + + ")"; + partial_status_strs.emplace_back(s); + } + os << "," + + phi::distributed::auto_parallel::str_join( + partial_status_strs) + + "}"; + } else { + os << "}"; + } + } + auto num_result_dist_attrs = op_dist_attr.num_result_dist_attrs(); + for (uint32_t i = 0; i < num_result_dist_attrs; ++i) { + auto dist_attr = op_dist_attr.result_dist_attr(i); + os << ",result(" + std::to_string(i) + "):{"; + if (!dist_attr) { + os << "null}"; + continue; + } + if (dist_attr.process_mesh_attr() != op_dist_attr.process_mesh_attr()) { + os << "mesh_shape:[" + + phi::distributed::auto_parallel::str_join( + dist_attr.process_mesh_attr().shape()) + + "],"; + } + os << "dims_maping:[" + + phi::distributed::auto_parallel::str_join( + dist_attr.dims_mapping()) + + "]"; + if (dist_attr.partial_status().size() > 0) { + std::vector partial_status_strs; + for (auto &itr : dist_attr.partial_status()) { + std::string s = "partial(" + std::to_string(itr.first) + "," + + 
phi::ReduceTypeStrings[static_cast(itr.second)] + + ")"; + partial_status_strs.emplace_back(s); + } + os << "," + + phi::distributed::auto_parallel::str_join( + partial_status_strs) + + "}"; + } else { + os << "}"; + } + } + os << "}"; + } else { + os << "error_attribute_type"; + } +} + +pir::OpPrintFn DistDialect::PrintOperation(pir::Operation *op) const { + return nullptr; +} + +} // namespace dialect +} // namespace paddle + +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DistDialect) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h new file mode 100644 index 0000000000000..2a7420b0a495a --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h @@ -0,0 +1,41 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/include/core/dialect.h" + +namespace paddle { +namespace dialect { + +class DistDialect : public pir::Dialect { + public: + explicit DistDialect(pir::IrContext* context); + + static const char* name() { return "pd_dist"; } + + void PrintType(pir::Type type, std::ostream& os) const override; + + void PrintAttribute(pir::Attribute attr, std::ostream& os) const override; + + pir::OpPrintFn PrintOperation(pir::Operation* op) const override; // NOLINT + + private: + void initialize(); +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DistDialect) diff --git a/paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h b/paddle/fluid/pir/dialect/distributed/ir/dist_interface.cc similarity index 76% rename from paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h rename to paddle/fluid/pir/dialect/distributed/ir/dist_interface.cc index 417f3c86c7e43..17e5caa6a22db 100644 --- a/paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_interface.cc @@ -12,12 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once +#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" +namespace paddle::dialect {} // namespace paddle::dialect -namespace symbol { - -IR_API DimExpr SimplifyDimExpr(const DimExpr& dim_expr); - -} // namespace symbol +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DistTypeInterface) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_interface.h b/paddle/fluid/pir/dialect/distributed/ir/dist_interface.h new file mode 100644 index 0000000000000..6fca7d4442b7c --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_interface.h @@ -0,0 +1,76 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/pir/include/core/cast_utils.h" +#include "paddle/pir/include/core/dll_decl.h" +#include "paddle/pir/include/core/type.h" + +namespace paddle { +namespace dialect { + +class IR_API DistTypeInterface + : public pir::TypeInterfaceBase { + public: + struct Concept { + /// Defined these methods with the interface. + explicit Concept(pir::Type (*local_type)(pir::Type), + ProcessMeshAttribute (*process_mesh_attr)(pir::Type), + TensorDistAttribute (*tensor_dist_attr)(pir::Type)) + : local_type(local_type), + process_mesh_attr(process_mesh_attr), + tensor_dist_attr(tensor_dist_attr) {} + pir::Type (*local_type)(pir::Type); + ProcessMeshAttribute (*process_mesh_attr)(pir::Type); + TensorDistAttribute (*tensor_dist_attr)(pir::Type); + }; + + template + struct Model : public Concept { + static Type local_type(Type type) { + return pir::cast(type).local_type(); + } + static ProcessMeshAttribute process_mesh_attr(Type type) { + return pir::cast(type).process_mesh_attr(); + } + + static TensorDistAttribute tensor_dist_attr(Type type) { + return pir::cast(type).tensor_dist_attr(); + } + + Model() : Concept(local_type, process_mesh_attr, tensor_dist_attr) {} + }; + + DistTypeInterface(pir::Type type, Concept *impl) + : pir::TypeInterfaceBase(type), impl_(impl) {} + + pir::Type local_type() { return impl_->local_type(*this); } + + ProcessMeshAttribute process_mesh_attr() { + return impl_->process_mesh_attr(*this); + } + + TensorDistAttribute tensor_dist_attr() { + return impl_->tensor_dist_attr(*this); + } + + private: + Concept *impl_; +}; + +} // namespace dialect +} // namespace paddle + +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DistTypeInterface) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_op.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_op.cc new file mode 100644 index 0000000000000..cc06461e66d55 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_op.cc @@ -0,0 +1,279 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
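Before the op definitions, a short sketch of how these two ops are normally reached through the dist_api helpers declared earlier (illustrative only; x is assumed to be a freshly built pir::Value of DenseTensorType with no users yet, which ShardTensorOp::Build below requires):

    phi::distributed::ProcessMesh mesh({2, 2}, {0, 1, 2, 3}, {"x", "y"});
    // dist_op.shard_tensor: annotate a dense tensor with a distribution.
    pir::Value sharded =
        paddle::dialect::shard_tensor(x, mesh, /*dims_mapping=*/{0, -1});
    // dist_op.reshard: change the distribution, here back to fully replicated.
    pir::Value replicated =
        paddle::dialect::reshard(sharded, mesh, /*dims_mapping=*/{-1, -1});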
+ +#include "paddle/fluid/pir/dialect/distributed/ir/dist_op.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/core/ir_context.h" + +namespace paddle { +namespace dialect { + +const char* ShardTensorOp::attributes_name[1] = {"op_dist_attr"}; +const char* ReShardOp::attributes_name[1] = {"op_dist_attr"}; + +void ShardTensorOp::VerifySig() { + VLOG(4) + << "Start Verifying inputs, outputs and attributes for: ShardTensorOp."; + VLOG(4) << "Verifying inputs:"; + { + auto input_size = num_operands(); + PADDLE_ENFORCE_EQ( + input_size, + 1u, + common::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", input_size)); + PADDLE_ENFORCE_EQ((*this) + ->operand_source(0) + .type() + .isa(), + true, + common::errors::PreconditionNotMet( + "Type validation failed for the 0th input.")); + } + VLOG(4) << "Verifying attributes:"; + { + auto& attributes = this->attributes(); + PADDLE_ENFORCE_EQ((attributes.count("op_dist_attr") > 0 && + attributes.at("op_dist_attr") + .isa()), + true, + common::errors::PreconditionNotMet( + "Type of attribute: op_dist_attr is not right.")); + } + VLOG(4) << "Verifying outputs:"; + { + auto output_size = num_results(); + PADDLE_ENFORCE_EQ( + output_size, + 1u, + common::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", output_size)); + PADDLE_ENFORCE_EQ( + (*this)->result(0).type().isa(), + true, + common::errors::PreconditionNotMet( + "Type validation failed for the 0th output.")); + } + + VLOG(4) << "Verifying op dist attrs:"; + { + auto op_dist_attr = + this->attribute( + "op_dist_attr"); + PADDLE_ENFORCE_EQ(op_dist_attr.num_operand_dist_attrs(), + 0u, + common::errors::PreconditionNotMet( + "The op_dist_attr input size %d must be equal to 0.", + op_dist_attr.num_operand_dist_attrs())); + + PADDLE_ENFORCE_EQ(op_dist_attr.num_result_dist_attrs(), + num_results(), + common::errors::PreconditionNotMet( + "The op_dist_attr output size %d must " + "be equal to op output size %d.", + op_dist_attr.num_result_dist_attrs(), + num_results())); + } + VLOG(4) << "End Verifying for: ShardTensorOp."; +} + +void ShardTensorOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + pir::Value input, + pir::AttributeMap attributes) { + VLOG(4) << "Start build ShardOp"; + + // Temporary restriction, will support input use_empty false in the future + PADDLE_ENFORCE_EQ( + input.use_empty(), + true, + common::errors::PreconditionNotMet("'input' use_empty is not true")); + + paddle::dialect::DenseTensorType input_tensor_type; + if (input.type().isa()) { + input_tensor_type = + input.type().dyn_cast(); + } else { + PADDLE_THROW(common::errors::Unimplemented( + "Only support paddle::dialect::DenseTensorType")); + } + + PADDLE_ENFORCE_NE( + attributes.find("tensor_dist_attr"), + attributes.end(), + common::errors::NotFound( + "'tensor_dist_attr' Attribute is expected for ShardOp")); + paddle::dialect::TensorDistAttribute tensor_dist_attr = + attributes.at("tensor_dist_attr") + .dyn_cast(); + + VLOG(4) << "Builder construction inputs"; + argument.AddInput(input); + 
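// (Editorial note) The op_dist_attr assembled below intentionally carries an
// empty operand_dist_attrs list: the input of shard_tensor is still a plain
// DenseTensorType, so only the single result receives a TensorDistAttribute.
// This mirrors the VerifySig() checks above (zero operand dist attrs, one
// result dist attr per op result).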
+ VLOG(4) << "Builder construction attributes"; + auto process_mesh_attr = tensor_dist_attr.process_mesh_attr(); + auto dims_mapping = tensor_dist_attr.dims_mapping(); + + pir::Attribute op_dist_attr = OperationDistAttribute::get( + pir::IrContext::Instance(), + process_mesh_attr, + std::vector(), + std::vector{tensor_dist_attr}); + argument.AddAttribute("op_dist_attr", op_dist_attr); + + VLOG(4) << "Builder construction outputs"; + auto global_dims = input_tensor_type.dims(); + auto process_mesh_shape = process_mesh_attr.shape(); + PADDLE_ENFORCE_EQ(static_cast(dims_mapping.size()), + global_dims.size(), + common::errors::PreconditionNotMet( + "dims_mapping size %d does not match input size %d", + dims_mapping.size(), + global_dims.size())); + auto local_shape = InferLocalDDim(global_dims, tensor_dist_attr); + pir::Type out_dist_tensor_type = + paddle::dialect::DistDenseTensorType::get(pir::IrContext::Instance(), + input_tensor_type, + tensor_dist_attr, + local_shape); + argument.AddOutput(out_dist_tensor_type); + ::pir::PassStopGradientsDefaultly(argument); +} + +void ReShardOp::VerifySig() { + VLOG(4) << "Start Verifying inputs, outputs and attributes for: ReShardOp."; + VLOG(4) << "Verifying inputs:"; + { + auto input_size = num_operands(); + PADDLE_ENFORCE_EQ( + input_size, + 1u, + common::errors::PreconditionNotMet( + "The size %d of inputs must be equal to 1.", input_size)); + PADDLE_ENFORCE_EQ((*this) + ->operand_source(0) + .type() + .isa(), + true, + common::errors::PreconditionNotMet( + "Type validation failed for the 0th input.")); + } + VLOG(4) << "Verifying attributes:"; + { + auto& attributes = this->attributes(); + PADDLE_ENFORCE_EQ((attributes.count("op_dist_attr") > 0 && + attributes.at("op_dist_attr") + .isa()), + true, + common::errors::PreconditionNotMet( + "Type of attribute: op_dist_attr is not right.")); + } + VLOG(4) << "Verifying outputs:"; + { + auto output_size = num_results(); + PADDLE_ENFORCE_EQ( + output_size, + 1u, + common::errors::PreconditionNotMet( + "The size %d of outputs must be equal to 1.", output_size)); + PADDLE_ENFORCE_EQ( + (*this)->result(0).type().isa(), + true, + common::errors::PreconditionNotMet( + "Type validation failed for the 0th output.")); + } + + VLOG(4) << "Verifying op dist attrs:"; + { + auto op_dist_attr = + this->attribute( + "op_dist_attr"); + PADDLE_ENFORCE_EQ(op_dist_attr.num_operand_dist_attrs(), + 1u, + common::errors::PreconditionNotMet( + "The op_dist_attr input size %d must be equal to 1.", + op_dist_attr.num_operand_dist_attrs())); + + PADDLE_ENFORCE_EQ(op_dist_attr.num_result_dist_attrs(), + num_results(), + common::errors::PreconditionNotMet( + "The op_dist_attr output size %d must " + "be equal to op output size %d.", + op_dist_attr.num_result_dist_attrs(), + num_results())); + } + VLOG(4) << "End Verifying for: ShardTensorOp."; +} + +void ReShardOp::Build(pir::Builder& builder, + pir::OperationArgument& argument, + pir::Value input, + TensorDistAttribute tensor_dist_attr) { + VLOG(4) << "Start build ReShardOp"; + + paddle::dialect::DistDenseTensorType input_tensor_type; + if (input.type().isa()) { + input_tensor_type = + input.type().dyn_cast(); + } else { + PADDLE_THROW(common::errors::Unimplemented( + "Only support paddle::dialect::DistDenseTensorType")); + } + + VLOG(4) << "Builder construction inputs"; + argument.AddInput(input); + + VLOG(4) << "Builder construction attributes"; + pir::Attribute op_dist_attr = OperationDistAttribute::get( + pir::IrContext::Instance(), + 
input_tensor_type.tensor_dist_attr().process_mesh_attr(), + std::vector{input_tensor_type.tensor_dist_attr()}, + std::vector{tensor_dist_attr}); + argument.AddAttribute("op_dist_attr", op_dist_attr); + + VLOG(4) << "Builder construction outputs"; + auto global_dims = input_tensor_type.global_ddim(); + auto process_mesh_attr = tensor_dist_attr.process_mesh_attr(); + auto dims_mapping = tensor_dist_attr.dims_mapping(); + + auto process_mesh_shape = process_mesh_attr.shape(); + PADDLE_ENFORCE_EQ(static_cast(dims_mapping.size()), + global_dims.size(), + common::errors::PreconditionNotMet( + "dst dims_mapping size %d does not match input size %d", + dims_mapping.size(), + global_dims.size())); + + auto local_shape = InferLocalDDim(global_dims, tensor_dist_attr); + pir::Type out_dist_tensor_type = paddle::dialect::DistDenseTensorType::get( + pir::IrContext::Instance(), + input_tensor_type.dense_tensor_type(), + tensor_dist_attr, + local_shape); + argument.AddOutput(out_dist_tensor_type); +} + +} // namespace dialect +} // namespace paddle + +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ReShardOp) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_op.h b/paddle/fluid/pir/dialect/distributed/ir/dist_op.h new file mode 100644 index 0000000000000..7ae81a0040702 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_op.h @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
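The header below declares both dist ops. Their op_dist_attr layouts differ: ShardTensorOp records no operand dist attrs and one result dist attr, while ReShardOp records the source distribution as its single operand dist attr and the target distribution as its single result dist attr. A direct-builder sketch (illustrative only; block, value and dst_dist_attr are assumed to exist):

    pir::Builder builder(pir::IrContext::Instance(), block);
    auto reshard_op =
        builder.Build<paddle::dialect::ReShardOp>(value, dst_dist_attr);
    pir::Value out = reshard_op.result(0);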
+ +#pragma once +#include + +#include "paddle/pir/include/core/builder.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/core/op_base.h" +#include "paddle/pir/include/core/operation_utils.h" + +namespace paddle { +namespace dialect { +class TensorDistAttribute; + +class ShardTensorOp : public pir::Op { + public: + using Op::Op; + static const char* name() { return "dist_op.shard_tensor"; } + static const char* attributes_name[1]; + static constexpr uint32_t attributes_num = 1; + TEST_API static void Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT + pir::Value input, + pir::AttributeMap attributes); + pir::Value input() { return operand_source(0); } + pir::Value out() { return result(0); } + void VerifySig(); +}; + +class ReShardOp : public pir::Op { + public: + using Op::Op; + static const char* name() { return "dist_op.reshard"; } + static const char* attributes_name[1]; + static constexpr uint32_t attributes_num = 1; + TEST_API static void Build(pir::Builder& builder, // NOLINT + pir::OperationArgument& argument, // NOLINT + pir::Value input, + TensorDistAttribute tensor_dist_attr); + void VerifySig(); +}; +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ShardTensorOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ReShardOp) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc new file mode 100644 index 0000000000000..9741a76714816 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h" +#include "paddle/common/enforce.h" +#include "paddle/pir/include/core/operation.h" + +namespace paddle { +namespace dialect { + +bool HasDistInput(const std::vector& inputs, + ProcessMeshAttribute* p_mesh_attr) { + for (auto value : inputs) { + if (auto dist_type = value.type().dyn_cast()) { + if (p_mesh_attr) { + *p_mesh_attr = dist_type.process_mesh_attr(); + } + return true; + } + } + return false; +} + +void CvtAllInputsToDist(const std::vector& inputs, + ProcessMeshAttribute mesh_attr) { + for (auto value : inputs) { + if (auto type = value.type()) { + if (type.isa()) continue; + auto dense_type = type.dyn_cast(); + if (!dense_type) { + PADDLE_THROW(common::errors::Unimplemented( + "Currently only support convert dense_tensor_type to dist type.")); + } + auto ctx = pir::IrContext::Instance(); + auto dist_type = DistDenseTensorType::get(ctx, dense_type, mesh_attr); + value.set_type(dist_type); + if (auto define_op = value.defining_op()) { + if (define_op->num_operands() != 0u) { + PADDLE_THROW(common::errors::InvalidArgument( + "Currently only allowed add dist attribue for leaf nodes " + "operation. 
The current op is %s", + define_op->name())); + } + if (define_op->num_results() != 1u) { + PADDLE_THROW(common::errors::InvalidArgument( + "Currently only allowed add dist attribue for operation with " + "single output. The current op is %s", + define_op->name())); + } + define_op->set_attribute( + kAttrOpDistAttr, + OperationDistAttribute::get( + ctx, mesh_attr, {}, {dist_type.tensor_dist_attr()})); + } + } + } +} + +phi::distributed::DistMetaTensor CvtToDistMetaTensor(DistDenseTensorType type) { + auto pir_attr = type.tensor_dist_attr(); + phi::distributed::TensorDistAttr phi_attr; + phi_attr.set_process_mesh(pir_attr.process_mesh_attr().process_mesh()); + phi_attr.set_dims_mapping(pir_attr.dims_mapping()); + phi_attr.set_partial_status(pir_attr.partial_status()); + return phi::distributed::DistMetaTensor(type.global_ddim(), phi_attr); +} + +TensorDistAttribute CvtToPirDistAttr( + const phi::distributed::ArgDistAttr& dist_attr) { + auto& attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr); + if (attr.process_mesh().empty()) return nullptr; + return TensorDistAttribute::get(pir::IrContext::Instance(), + attr.process_mesh(), + attr.dims_mapping(), + attr.partial_status()); +} + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h new file mode 100644 index 0000000000000..24d8d2d2143b0 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_tools.h @@ -0,0 +1,35 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" +#include "paddle/pir/include/core/value.h" + +namespace paddle { +namespace dialect { + +bool HasDistInput(const std::vector& inputs, + ProcessMeshAttribute* p_mesh_attr = nullptr); + +void CvtAllInputsToDist(const std::vector& inputs, + ProcessMeshAttribute mesh_attr); + +phi::distributed::DistMetaTensor CvtToDistMetaTensor(DistDenseTensorType type); +TensorDistAttribute CvtToPirDistAttr( + const phi::distributed::ArgDistAttr& dist_attr); + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc b/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc new file mode 100644 index 0000000000000..5753608c85256 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_type.cc @@ -0,0 +1,75 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" +#include "paddle/fluid/pir/dialect/distributed/ir/type_storage.h" +#include "paddle/pir/include/core/ir_context.h" + +namespace paddle { +namespace dialect { + +pir::DenseTensorType DistDenseTensorType::dense_tensor_type() const { + return storage()->dense_tensor_type; +} + +TensorDistAttribute DistDenseTensorType::tensor_dist_attr() const { + return storage()->tensor_dist_attr; +} + +const common::DDim& DistDenseTensorType::local_ddim() const { + return storage()->local_ddim; +} + +DistDenseTensorType DistDenseTensorType::get( + pir::IrContext* ctx, + pir::DenseTensorType dense_tensor_type, + TensorDistAttribute tensor_dist_attr, + const common::DDim& local_ddim) { + return Base::get(ctx, dense_tensor_type, tensor_dist_attr, local_ddim); +} + +common::DDim InferLocalDDim(const common::DDim& global_ddim, + TensorDistAttribute dist_attr) { + auto& mesh_dim = dist_attr.process_mesh_attr().shape(); + auto& dim_mapping = dist_attr.dims_mapping(); + PADDLE_ENFORCE_EQ(global_ddim.size(), + dim_mapping.size(), + ::common::errors::PreconditionNotMet( + "The global ddim size must equal to dim_mapping's " + "size, but bot %d vs %d", + global_ddim.size(), + dim_mapping.size())); + common::DDim local_ddim(global_ddim); + for (size_t i = 0; i < dim_mapping.size(); ++i) { + if (dim_mapping[i] != -1) { + auto dim_size = mesh_dim.at(dim_mapping[i]); + local_ddim[i] = (global_ddim[i] + dim_size - 1) / dim_size; + } + } + return local_ddim; +} + +auto DistDenseTensorType::local_type() const -> Type { + return pir::DenseTensorType::get(pir::IrContext::Instance(), + dtype(), + local_ddim(), + data_layout(), + lod(), + offset()); +} + +} // namespace dialect +} // namespace paddle + +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DistDenseTensorType) diff --git a/paddle/fluid/pir/dialect/distributed/ir/dist_type.h b/paddle/fluid/pir/dialect/distributed/ir/dist_type.h new file mode 100644 index 0000000000000..2344a97399e34 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/dist_type.h @@ -0,0 +1,91 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
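The header that follows declares InferLocalDDim, whose definition above ceil-divides every sharded dimension by the size of the mesh axis it is mapped to. A small worked example (illustrative values; dist_attr is assumed to carry the stated mapping):

    // mesh shape [2, 3], global dims [8, 7], dims_mapping [0, 1]
    //   dim 0 -> mesh axis 0 (size 2): local = (8 + 2 - 1) / 2 = 4
    //   dim 1 -> mesh axis 1 (size 3): local = (7 + 3 - 1) / 3 = 3
    // A dims_mapping entry of -1 leaves that dimension unsharded, so [-1, -1]
    // keeps the full [8, 7] shape on every rank.
    auto local =
        paddle::dialect::InferLocalDDim(common::make_ddim({8, 7}), dist_attr);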
+ +#pragma once + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_interface.h" +#include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/core/type.h" + +namespace paddle { +namespace dialect { + +class DistDenseTensorTypeStorage; + +common::DDim InferLocalDDim(const common::DDim& global_ddim, + TensorDistAttribute dist_attr); +class DistDenseTensorType + : public pir::Type::TypeBase { + public: + using Base::Base; + using LoD = pir::DenseTensorTypeStorage::LoD; + + pir::DenseTensorType dense_tensor_type() const; + TensorDistAttribute tensor_dist_attr() const; + const common::DDim& global_ddim() const { return dense_tensor_type().dims(); } + const common::DDim& local_ddim() const; + Type dtype() const { return dense_tensor_type().dtype(); } + DataLayout data_layout() const { return dense_tensor_type().data_layout(); } + const LoD& lod() const { return dense_tensor_type().lod(); } + size_t offset() const { return dense_tensor_type().offset(); } + + Type prim_type() { return dense_tensor_type(); } + Type local_type() const; + + ProcessMeshAttribute process_mesh_attr() const { + return tensor_dist_attr().process_mesh_attr(); + } + const std::vector& dims_mapping() const { + return tensor_dist_attr().dims_mapping(); + } + std::set partial_dims() const { + return tensor_dist_attr().partial_dims(); + } + const flat_hash_map& partial_status() const { + return tensor_dist_attr().partial_status(); + } + + static DistDenseTensorType get(pir::IrContext* ctx, + pir::DenseTensorType dense_tensor_type, + TensorDistAttribute tensor_dist_attr, + const common::DDim& local_ddim); + static DistDenseTensorType get(pir::IrContext* ctx, + pir::DenseTensorType dense_tensor_type, + TensorDistAttribute tensor_dist_attr) { + if (!dense_tensor_type) return nullptr; + auto local_ddim = + InferLocalDDim(dense_tensor_type.dims(), tensor_dist_attr); + return get(ctx, dense_tensor_type, tensor_dist_attr, local_ddim); + } + + // return the replicated dist dense tensor type. + static DistDenseTensorType get(pir::IrContext* ctx, + pir::DenseTensorType dense_tensor_type, + ProcessMeshAttribute process_mesh_attr) { + auto& ddim = dense_tensor_type.dims(); + auto attr = TensorDistAttribute::get( + ctx, process_mesh_attr, std::vector(ddim.size(), -1)); + return get(ctx, dense_tensor_type, attr, ddim); + } +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DistDenseTensorType) diff --git a/paddle/fluid/pir/dialect/distributed/ir/type_storage.h b/paddle/fluid/pir/dialect/distributed/ir/type_storage.h new file mode 100644 index 0000000000000..e6dde5e0df0c9 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/ir/type_storage.h @@ -0,0 +1,82 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
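A short note on the replicated overload of get() above, sketched in Python (illustrative only): a fully replicated tensor maps every dimension to -1, so the InferLocalDDim loop never divides anything and the local shape equals the global shape, which is why that overload can pass ddim straight through as local_ddim.

def replicated_dims_mapping(ndim):
    # Mirrors std::vector(ddim.size(), -1) in the overload above:
    # no tensor dimension is bound to any mesh dimension.
    return [-1] * ndim

# replicated_dims_mapping(3) == [-1, -1, -1]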
+ +#pragma once + +#include + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/pir/include/core/builtin_type.h" + +namespace paddle { +namespace dialect { +/// +/// \brief Define Parametric TypeStorage for DistDenseTensorType. +/// +class DistDenseTensorTypeStorage : public pir::TypeStorage { + public: + /// + /// \brief Declare ParamKey according to parameter type. + /// + using ParamKey = + std::tuple; + + DistDenseTensorTypeStorage(pir::DenseTensorType dense_tensor_type, + TensorDistAttribute tensor_dist_attr, + const common::DDim& local_ddim) + : dense_tensor_type(dense_tensor_type), + tensor_dist_attr(tensor_dist_attr), + local_ddim(local_ddim) {} + + /// + /// \brief Each derived TypeStorage must define a Construct method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static DistDenseTensorTypeStorage* Construct(ParamKey&& key) { + return new DistDenseTensorTypeStorage( + std::get<0>(key), std::get<1>(key), std::get<2>(key)); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. + /// + static std::size_t HashValue(const ParamKey& key) { + auto dense_tensor_type_hash = std::hash()(std::get<0>(key)); + auto tensor_dist_attr_hash = std::hash()(std::get<1>(key)); + auto local_ddim_hash = std::hash()(std::get<2>(key)); + auto value = pir::detail::hash_combine(dense_tensor_type_hash, + tensor_dist_attr_hash); + return pir::detail::hash_combine(value, local_ddim_hash); + } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey& key) const { + return dense_tensor_type == std::get<0>(key) && + tensor_dist_attr == std::get<1>(key) && + local_ddim == std::get<2>(key); + } + + /// + /// \brief DistDenseTensorTypeStorage include three parameters: + /// dense_tensor_type, tensor_dist_attr and local_ddim; + /// + pir::DenseTensorType dense_tensor_type; + TensorDistAttribute tensor_dist_attr; + common::DDim local_ddim; +}; + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.cc b/paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.cc new file mode 100644 index 0000000000000..60d42984c57b6 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.cc @@ -0,0 +1,149 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
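The HashValue/operator== pair above is the standard parametric type-storage recipe: the three constructor parameters form the identity of the type, and their hashes are folded into a single key. A standalone Python sketch of that folding step (boost-style combine; the constant is illustrative and pir::detail::hash_combine may differ internally):

def hash_combine(seed, value):
    # Fold one more component into an accumulated hash, as the storage above
    # does for dense_tensor_type, tensor_dist_attr and local_ddim in turn.
    return (seed ^ (hash(value) + 0x9E3779B9 + (seed << 6) + (seed >> 2))) & ((1 << 64) - 1)

def dist_type_key_hash(dense_tensor_type, tensor_dist_attr, local_ddim):
    # Pass hashable stand-ins (e.g. tuples) for the three parameters.
    h = hash(dense_tensor_type)
    h = hash_combine(h, tensor_dist_attr)
    return hash_combine(h, local_ddim)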
+ +#include "paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h" + +#include +#include +#include + +#include "paddle/common/flags.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/pir/include/core/attribute.h" + +using paddle::dialect::DistDenseTensorType; + +COMMON_DECLARE_bool(print_ir); + +namespace paddle { +namespace dialect { + +inline bool IsShardTensorOp(pir::Operation* op) { + std::string op_name = op->name(); + return op_name.find("shard_tensor") != op_name.npos; +} + +void ProcessBlock(pir::Block* block) { + std::vector deleted_ops; + + for (auto iter = block->begin(); iter != block->end(); ++iter) { + pir::Operation* op_item = &(*iter); + VLOG(6) << "mix_to_dist main loop over op name " << op_item->name(); + + if (paddle::dialect::IsShardTensorOp(op_item)) { + pir::Value shard_operand_value = op_item->operand_source(0); + pir::Value shard_result_value = op_item->result(0); + pir::Operation* shard_operand_define_op = + shard_operand_value.defining_op(); + std::string define_op_name = shard_operand_define_op->name(); + + // TODO(2024-Q2) Support more paddle op + if (define_op_name != "builtin.parameter" && + define_op_name != "pd_op.data") { + PADDLE_THROW(platform::errors::Unimplemented( + "op [%s] is not Supported by shard_tensor op in pir mode.", + define_op_name)); + } + + // TODO(2024-Q2) Support shard_tensor is called after tensor has been + // used. + if (shard_operand_value.use_count() != 1) { + PADDLE_THROW(platform::errors::Unimplemented( + "shard_tensor is supposed to be called right after tensor is " + "created, the use_count of tensor to be sharded is [%d] which is " + "not Supported in right now.", + shard_operand_value.use_count())); + } + shard_operand_value.set_type(shard_result_value.type()); + shard_result_value.ReplaceAllUsesWith(shard_operand_value); + + shard_operand_define_op->set_attribute( + kAttrOpDistAttr, op_item->attribute(kAttrOpDistAttr)); + deleted_ops.push_back(op_item); + } + + // TODO(2024-Q2) Handle other shard annotation op in future. + } + + for (auto* op : deleted_ops) { + // TODO(2024-Q2) Support control flow / region + VLOG(6) << "mix_to_dist pass delete op [" << op->name() << "]."; + op->Erase(); + } +} + +/* Verification: + 1. all operators have OperatorDistAttr. + 2. all Values (Results) are DistDenseTensorType. + 3. no shard_tensor in block. 
+*/ +void VerifyBlock(pir::Block* block) { + for (auto iter = block->begin(); iter != block->end(); ++iter) { + pir::Operation* op_item = &(*iter); + PADDLE_ENFORCE_EQ(paddle::dialect::IsShardTensorOp(op_item), + false, + phi::errors::PreconditionNotMet( + "Block still contain shard_tensor_op.")); + + if (op_item && !op_item->HasAttribute(kAttrOpDistAttr)) { + PADDLE_THROW(platform::errors::PreconditionNotMet( + "The op [%s] does not hase OperatorDistAttr after Mix2Dist Pass.", + op_item->name())); + } + + for (size_t i = 0; i < op_item->num_results(); ++i) { + PADDLE_ENFORCE_EQ(op_item->result(i).type().isa(), + true, + phi::errors::PreconditionNotMet( + "[%d]'s input of [%s] is NOT DistDenseTensorType", + i, + op_item->name())); + } + } +} + +std::shared_ptr MixToDistPass(pir::Program* prog) { + if (FLAGS_print_ir) { + std::cout << "IR before MixToDist Pass = " << *prog << std::endl; + } + + pir::IrMapping mapper; + auto new_prog = prog->Clone(mapper); + + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + + ProcessBlock(new_prog->block()); + VerifyBlock(new_prog->block()); + + if (FLAGS_print_ir) { + std::cout << "IR after MixToDist Pass = " << *new_prog << std::endl; + } + + return new_prog; +} + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h b/paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h new file mode 100644 index 0000000000000..978f64f12d2b1 --- /dev/null +++ b/paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h @@ -0,0 +1,30 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/pir/include/core/program.h" + +namespace paddle { +namespace dialect { + +// pir::Type ConvertOpTypeToKernelType(pir::Type op_type); + +TEST_API std::shared_ptr MixToDistPass(pir::Program* prog); + +void ProcessBlock(pir::Block* block); + +void VerifyBlock(pir::Block* block); + +} // namespace dialect +} // namespace paddle diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc index 0c8f007a51a9d..c3e44d4e3ef35 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_op.cc @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" +#include + #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" +#include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" #include "paddle/phi/core/enforce.h" #include "paddle/pir/include/core/builtin_attribute.h" diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc index f293bd5cf9baa..ef3a9a7c0b307 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.cc @@ -17,6 +17,10 @@ namespace paddle { namespace dialect { +pir::Type AllocatedDenseTensorType::prim_type() { + return storage()->dense_tensor_type_; +} + const phi::Place& AllocatedDenseTensorType::place() const { return storage()->place_; } @@ -41,6 +45,10 @@ size_t AllocatedDenseTensorType::offset() const { return storage()->dense_tensor_type_.offset(); } +pir::Type AllocatedSelectedRowsType::prim_type() { + return storage()->selected_rows_type_; +} + const phi::Place& AllocatedSelectedRowsType::place() const { return storage()->place_; } @@ -65,6 +73,10 @@ size_t AllocatedSelectedRowsType::offset() const { return storage()->selected_rows_type_.offset(); } +pir::Type AllocatedDenseTensorArrayType::prim_type() { + return storage()->dense_tensor_array_type_; +} + const phi::Place& AllocatedDenseTensorArrayType::place() const { return storage()->place_; } diff --git a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h index f8595c6ec68df..8bfdf0bae7906 100644 --- a/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h +++ b/paddle/fluid/pir/dialect/kernel/ir/kernel_type.h @@ -24,7 +24,8 @@ namespace dialect { class AllocatedDenseTensorType : public pir::Type::TypeBase { + AllocatedDenseTensorTypeStorage, + pir::WrapTypeInterface> { public: using Base::Base; @@ -49,6 +50,8 @@ class AllocatedDenseTensorType ctx, place, dense_tensor_type); } + pir::Type prim_type(); + const phi::Place &place() const; pir::Type dtype() const; @@ -65,7 +68,8 @@ class AllocatedDenseTensorType class AllocatedSelectedRowsType : public pir::Type::TypeBase { + AllocatedSelectedRowsTypeStorage, + pir::WrapTypeInterface> { public: using Base::Base; @@ -90,6 +94,8 @@ class AllocatedSelectedRowsType ctx, place, type); } + pir::Type prim_type(); + const phi::Place &place() const; pir::Type dtype() const; @@ -106,7 +112,8 @@ class AllocatedSelectedRowsType class AllocatedDenseTensorArrayType : public pir::Type::TypeBase { + AllocatedDenseTensorArrayTypeStorage, + pir::WrapTypeInterface> { public: using Base::Base; @@ -129,6 +136,8 @@ class AllocatedDenseTensorArrayType ctx, place, type); } + pir::Type prim_type(); + const phi::Place &place() const; const pir::Type &dtype() const; diff --git a/paddle/fluid/pir/dialect/op_generator/api_gen.py b/paddle/fluid/pir/dialect/op_generator/api_gen.py index d3c1a718a61b3..d049adc0ac4b1 100644 --- a/paddle/fluid/pir/dialect/op_generator/api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/api_gen.py @@ -105,7 +105,7 @@ auto op_name = phi::TransToFluidOpName("{op_name}"); paddle::small_vector, egr::kSlotSmallVectorSize> amp_values_vector = {{ {no_optional_inputs} }}; {optional_inputs} - auto amp_dst_dtype = paddle::imperative::GetAmpDestDtype("{op_name}", amp_values_vector); + auto amp_dst_dtype = paddle::imperative::GetAmpDestDtype(op_name, amp_values_vector); {new_inputs} {{ paddle::imperative::AutoCastGuard guard(egr::Controller::Instance().GetCurrentAmpAttrs(), 
paddle::imperative::AmpLevel::O0); @@ -656,10 +656,12 @@ def _gen_amp_logic(self, op_info, op_name, is_mutable_attr): input_list = op_info.input_name_list if not input_list: return ( - f'VLOG(7) << " No AMP for {op_name} because it has no input. ";' + f'VLOG(5) << " No AMP for {op_name} because it has no input. ";' ) if op_name.endswith(('_grad', '_grad_')): - return 'VLOG(7) << " No AMP for grad apis. ";' + return 'VLOG(5) << " No AMP for grad apis. ";' + if op_name.endswith('_') or op_name == 'cast': + return f'VLOG(5) << "No AMP for {op_name} because it is a inplace or cast api.";' return AMP_LOGIC_TEMPLATE.format( op_name=op_name, no_optional_inputs=self._gen_amp_no_optional_inputs(op_info), diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index 9af8dfa12d702..4d37aaf829861 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -24,6 +24,7 @@ "batch_norm", "batch_norm_", "dropout", + "elu", "embedding", "flatten", "full_like", @@ -39,7 +40,7 @@ "mean", "pow", "relu", - "rsqrt", + "relu6", "sigmoid", "silu", "swiglu", @@ -57,6 +58,7 @@ decomp_interface_implementation_gen_op_list = [ "add_n", "dropout", + "elu", "embedding", "flatten", "full_like", @@ -72,7 +74,7 @@ "mean", "pow", "relu", - "rsqrt", + "relu6", "sigmoid", "silu", "swiglu", diff --git a/paddle/fluid/pir/dialect/op_generator/gen_utils.py b/paddle/fluid/pir/dialect/op_generator/gen_utils.py new file mode 100644 index 0000000000000..79a1f99fca058 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/gen_utils.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def to_pascal_case(s): + words = s.split("_") + if s[-1] == "_": + return "".join([word.capitalize() for word in words]) + "_" + else: + return "".join([word.capitalize() for word in words]) + "" diff --git a/paddle/fluid/pir/dialect/op_generator/op_all_func_gen.py b/paddle/fluid/pir/dialect/op_generator/op_all_func_gen.py new file mode 100644 index 0000000000000..57cb95eec9eb7 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/op_all_func_gen.py @@ -0,0 +1,39 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
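A quick usage note for the to_pascal_case helper above, since several generators rely on its treatment of the trailing underscore that marks inplace ops (the import assumes the module layout introduced in this diff):

from gen_utils import to_pascal_case

assert to_pascal_case("batch_norm") == "BatchNorm"
assert to_pascal_case("batch_norm_") == "BatchNorm_"   # trailing "_" of inplace ops is kept
assert to_pascal_case("relu6") == "Relu6"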
+ +from op_infer_spmd_func_gen import gen_op_infer_spmd_func +from op_infermeta_func_gen import gen_op_infermeta_func +from op_member_access_func_gen import gen_op_member_access_func +from op_vjp_interface_func_gen import gen_op_vjp_interface_func + +all_gen_op_func_list = [ + gen_op_infer_spmd_func, + gen_op_infermeta_func, + gen_op_member_access_func, + gen_op_vjp_interface_func, +] + + +def gen_op_all_func(args, op_info, op_info_items): + interface_list = [] + declare_list = [] + impl_list = [] + for func in all_gen_op_func_list: + interface, declare, impl = func(args, op_info, op_info_items) + interface_list += interface + if declare is not None: + declare_list.append(declare) + if impl is not None: + impl_list.append(impl) + return interface_list, declare_list, impl_list diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index 3365421990f1b..ee45bdf338270 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -248,8 +248,9 @@ def GenBuildInputArgsStr( } -def GenBuildInserFullForMutableAttribute( - op_class_name, +def GenBuildInsertFullForMutableAttribute( + args, + op_info, op_attribute_name_list, op_attribute_build_arg_type_list, op_mutable_attribute_name_list, @@ -386,9 +387,7 @@ def GenBuildAttributes( op_attribute_type=op_non_mutable_attribute_type_list[idx], attr=op_non_mutable_attribute_name_list[idx], ) - attr_str += """ argument.AddAttribute("{attr_name}", attr_{attr_name});\n argument_attributes.insert({{"{attr_name}", attr_{attr_name}}});\n""".format( - attr_name=op_non_mutable_attribute_name_list[idx] - ) + attr_str += f""" argument_attributes.insert({{"{op_non_mutable_attribute_name_list[idx]}", attr_{op_non_mutable_attribute_name_list[idx]}}});\n""" return attr_str @@ -480,15 +479,15 @@ def GenBuildOutputs( """ - CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::IntArray {name}; + CREATE_INTARRAY_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE = """ phi::IntArray {name}; if ({name}_.isa() && {name}_.defining_op()->isa()) {{ - {name} = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( + {name} = phi::IntArray(paddle::dialect::GetInt64Vector( {name}_.defining_op() ->dyn_cast() - .attribute("value")))); + .attribute("value"))); }} else if ({name}_.type().isa()) {{ size_t {name}_size = {name}_.type().dyn_cast().size(); - {name} = std::move(phi::IntArray(std::vector({name}_size, -1))); + {name} = phi::IntArray(std::vector({name}_size, -1)); {name}.SetFromTensor(true); }} else if ({name}_.type().isa()) {{ common::DDim {name}_dim = {name}_.type().dyn_cast().dims(); @@ -496,13 +495,13 @@ def GenBuildOutputs( if (common::contain_unknown_dim({name}_dim)) {{ {name}_size = 1; }} - {name} = std::move(phi::IntArray(std::vector({name}_size, -1))); + {name} = phi::IntArray(std::vector({name}_size, -1)); {name}.SetFromTensor(true); }} else {{ PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType")); }}\n""" - CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector {name}; + CREATE_VECTOR_INT_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE = """ std::vector {name}; if ({name}_.isa() && {name}_.defining_op()->isa()) {{ {name} = paddle::dialect::GetInt64Vector( {name}_.defining_op() @@ -522,17 +521,17 @@ def GenBuildOutputs( PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType")); }}\n""" - 
CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name}; + CREATE_SCALAR_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE = """ phi::Scalar {name}; if ({name}_.isa() && {name}_.defining_op()->isa()) {{ - {name} = std::move(phi::Scalar({name}_.defining_op() + {name} = phi::Scalar({name}_.defining_op() ->dyn_cast() .attribute("value") .dyn_cast() .data() - .to())); + .to()); }} else {{ - {name} = std::move(phi::Scalar(-1)); + {name} = phi::Scalar(-1); {name}.SetFromTensor(true); }}\n""" @@ -557,15 +556,11 @@ def GenBuildOutputs( # is a vector if 'pir::VectorType' in op_input_type_list[idx]: if op_input_optional_list[idx] == 'false': - build_output_str += " pir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( - name=op_input_name_list[idx] - ) + build_output_str += f" pir::VectorType {op_input_name_list[idx]} = {op_input_name_list[idx]}_.type().dyn_cast(); (void){op_input_name_list[idx]};\n" # is a Tensor else: if op_input_optional_list[idx] == 'false': - build_output_str += " {type} {name} = {name}_.type().dyn_cast<{type}>(); (void){name};\n".format( - type=op_input_type_list[idx], name=op_input_name_list[idx] - ) + build_output_str += f" {op_input_type_list[idx]} {op_input_name_list[idx]} = {op_input_name_list[idx]}_.type().dyn_cast<{op_input_type_list[idx]}>(); (void){op_input_name_list[idx]};\n" # Prepare mutable attributes if mutable_attr_is_input: @@ -577,16 +572,16 @@ def GenBuildOutputs( op_class_name in _PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE ): - build_output_str += CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( + build_output_str += CREATE_VECTOR_INT_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE.format( name=op_mutable_attribute_name_list[idx] ) else: - build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( + build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE.format( name=op_mutable_attribute_name_list[idx] ) # scalar elif attr_dtype[0] == "paddle::dialect::ScalarAttribute": - build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( + build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE.format( name=op_mutable_attribute_name_list[idx], dtype=attr_dtype[1], ) @@ -594,7 +589,7 @@ def GenBuildOutputs( elif attr_dtype[0] == "pir::StrAttribute": build_output_str += "" else: - assert "mutable attribtue type is not right." + assert "mutable attribute type is not right." build_output_str += "\n" # Prepare inputs_meta_tensor & attributes for infer meta @@ -679,12 +674,12 @@ def GenBuildOutputs( CREATE_INFER_META_FUNC_TEMPLATE = """ phi::{func}({args}); """ - CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE = """ + CREATE_INFER_META_FUNC_WITH_META_CONFIG_TEMPLATE = """ phi::{func}({args}, phi::MetaConfig(false, false)); """ if op_infer_meta_map['func'] in _INFERMETA_NEED_META_CONFIG: build_output_str += ( - CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE.format( + CREATE_INFER_META_FUNC_WITH_META_CONFIG_TEMPLATE.format( func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) ) ) @@ -748,6 +743,7 @@ def GenBuildOutputs( type=op_output_type_list[idx], name=output_name ) + build_output_str += " argument.AddAttributes(argument_attributes);\n" build_output_str += " argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n" # NOTE(Aurelius84): PassStopGradients must be placed after argument.AddOutputs. 
build_output_str += " ::pir::PassStopGradientsDefaultly(argument);\n" @@ -756,10 +752,8 @@ def GenBuildOutputs( def gen_build_func_str( - op_class_name, - op_input_name_list, - op_input_type_list, - op_input_optional_list, + args, + op_info, op_attribute_name_list, op_attribute_type_list, op_attribute_build_arg_type_list, @@ -770,18 +764,13 @@ def gen_build_func_str( op_non_mutable_attribute_type_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - op_inplace_map, muta_attr_is_input=False, attr_args_is_map=False, ): + op_input_name_list = op_info.input_name_list build_args_for_declare = "" build_func = "" - build_info_str = OP_INFO_TEMPLATE.format(op_name=op_class_name) + build_info_str = OP_INFO_TEMPLATE.format(op_name=op_info.class_name) build_args_for_declare = GenBuildInputArgsStr( op_input_name_list, @@ -813,8 +802,9 @@ def gen_build_func_str( inset_full_for_mutable_attributes_str = "" if not muta_attr_is_input: inset_full_for_mutable_attributes_str = ( - GenBuildInserFullForMutableAttribute( - op_class_name, + GenBuildInsertFullForMutableAttribute( + args, + op_info, op_attribute_name_list, op_attribute_build_arg_type_list, op_mutable_attribute_name_list, @@ -830,44 +820,53 @@ def gen_build_func_str( op_non_mutable_attribute_type_list, ) - build_outputs_str = """ - std::vector argument_outputs = {op_name}::InferMeta(argument_inputs, argument_attributes); + build_outputs_str = f""" + std::vector argument_outputs = {op_info.class_name}::InferMeta(argument_inputs, &argument_attributes); + argument.AddAttributes(argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); - ::pir::PassStopGradientsDefaultly(argument);""".format( - op_name=op_class_name - ) + ::pir::PassStopGradientsDefaultly(argument);""" GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. "); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data(); """ GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. "); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().AsString(); """ GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. "); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. 
")); {attr_type} {attribute_name}; for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast().at(i).dyn_cast<{inner_type}>().{data_name}()); }} """ GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. "); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().GetData(); """ GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. "); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().to<{attr_type}>(); """ @@ -900,7 +899,7 @@ def gen_build_func_str( data_name = "AsString" get_attributes_str += ( GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], inner_type=inner_type, @@ -910,7 +909,7 @@ def gen_build_func_str( elif "paddle::dialect::IntArrayAttribute" in attr_types[idx]: get_attributes_str += ( GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], ) @@ -918,7 +917,7 @@ def gen_build_func_str( elif "paddle::dialect::ScalarAttribute" in attr_types[idx]: get_attributes_str += ( GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], ) @@ -926,7 +925,7 @@ def gen_build_func_str( elif "pir::StrAttribute" in attr_types[idx]: get_attributes_str += ( GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], attr_ir_type=attr_types[idx], @@ -934,14 +933,14 @@ def gen_build_func_str( ) else: get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], attr_ir_type=attr_types[idx], ) build_func = OP_BUILD_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, build_info=build_info_str, build_args=build_args_for_define, build_mutable_attributes=inset_full_for_mutable_attributes_str, diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 67462983fbf0a..37e620ab24589 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -17,22 +17,19 @@ import os import pathlib import sys +from distutils.util import strtobool import yaml from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list +from gen_utils import to_pascal_case from infer_symbolic_shape_gen import gen_infer_symbolic_shape_str +from op_all_func_gen import gen_op_all_func from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke -from op_infermeta_gen import ( - gen_infermeta_by_invoke_func_str, - gen_infermeta_func_str, -) from 
op_interface_gen import ( gen_exclusive_interface_str, - gen_op_infer_meta_str, gen_op_vjp_str, ) from op_kerneltype_gen import gen_kernel_type_for_var_str -from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str from ops_onednn_extra_parser import parse_data_format_tensors, parse_extra_args from parse_kernel_key_gen import gen_parse_kernel_key_str @@ -107,6 +104,11 @@ #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" #include "paddle/phi/core/infermeta_utils.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/phi/infermeta/spmd_rules/rules.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h" +#endif {only_pd_op_header_files} {other_info} @@ -147,7 +149,6 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ {get_kernel_type_for_var_declare} {parse_kernel_key_declare} {infer_symbolic_shape_declare} -{get_inputs_and_outputs} {exclusive_interface} }}; """ @@ -312,7 +313,6 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ PD_MANUAL_OP_LIST = { 'add_n', 'add_n_', - 'add_n_with_kernel', 'split_grad', 'expand', 'increment', @@ -504,8 +504,13 @@ def __init__(self, op_yaml_item, op_compat_item): # parse infermeta && kernel self.infer_meta_map = self.parse_infer_meta_map() self.invoke_map = self.parse_invoke_map() + self.spmd_rule_func = None if 'infer_meta' in self.op_yaml_item: self.infer_meta_func = self.op_yaml_item['infer_meta']["func"] + if 'spmd_rule' in self.op_yaml_item['infer_meta']: + self.spmd_rule_func = self.op_yaml_item['infer_meta'][ + 'spmd_rule' + ] else: self.infer_meta_func = None @@ -1075,14 +1080,6 @@ def get_phi_dtype_name(self, name): return name -def to_pascal_case(s): - words = s.split("_") - if s[-1] == "_": - return "".join([word.capitalize() for word in words]) + "_" - else: - return "".join([word.capitalize() for word in words]) + "" - - def get_input_grad_semantic(op_info, op_info_items): input_grad_semantics = [] num_inputs = len(op_info.input_name_list) @@ -1234,7 +1231,9 @@ def GenOneDnnExtraAttrsDefaultValue(onednn_extra_args): return attr_str -def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): +def AutoCodeGen( + args, op_info_items, all_op_info_items, namespaces, dialect_name +): # (3) CodeGen: Traverse op_info_items and generate ops_name_list = [] # all op class name store in this list ops_declare_list = [] # all op class declare store in this list @@ -1292,19 +1291,6 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): op_traits = op_info.traits_list op_interfaces = op_info.interfaces_list op_interfaces += ["paddle::dialect::OpYamlInfoInterface"] - - if op_info.infer_meta_func: - op_interfaces += ["paddle::dialect::InferMetaInterface"] - elif op_invoke_map and op_invoke_map['func'] in op_info_items: - if op_info_items[op_invoke_map['func']].infer_meta_func: - op_interfaces += ["paddle::dialect::InferMetaInterface"] - - if ( - op_info.backward_name - and op_info.op_phi_name[0] not in vjp_interface_black_list - and dialect_name != "onednn_op" - ): - op_interfaces += ["paddle::dialect::VjpInterface"] exclusive_interface_str = gen_exclusive_interface_str( op_info, op_info_items ) @@ -1381,10 +1367,6 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): # =================================== # # gen interface list str # # 
=================================== # - op_interfaces_str = "" - if len(op_interfaces) > 0: - op_interfaces_str = "," + ",".join(op_interfaces) - if len(func_list) == 1: op_class_name = to_pascal_case(op_name) + "Op" op_dialect_name = dialect_name + "." + op_name @@ -1410,14 +1392,27 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): kernel_func_name ] - # =================================== # - # gen get input/output methods str # - # =================================== # - op_get_inputs_outputs_str = gen_op_get_inputs_outputs_str( - op_input_name_list, - op_mutable_attribute_name_list, - op_output_name_list, + op_info.class_name = op_class_name + op_info.kernel_input_type_list = op_input_type_list + op_info.kernel_output_type_list = op_output_type_list + + ( + all_interface_list, + exclusive_declare_list, + exclusive_impl_list, + ) = gen_op_all_func(args, op_info, op_info_items) + all_interface_list += op_interfaces + + all_interface_str = "" + if len(all_interface_list) > 0: + all_interface_str = "," + ",".join(all_interface_list) + + all_declare_str = ( + exclusive_interface_str + + '\n' + + '\n'.join(exclusive_declare_list) ) + ops_defined_list += exclusive_impl_list # =================================== # # gen Build methods str # @@ -1438,13 +1433,16 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): ) parse_kernel_key_str = "" - if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces: + if ( + "paddle::dialect::ParseKernelKeyInterface" + in all_interface_list + ): parse_kernel_key_str = parse_kernel_key_template infer_symbolic_shape_str = "" if ( "paddle::dialect::InferSymbolicShapeInterface" - in op_interfaces + in all_interface_list ): infer_symbolic_shape_str = infer_symbolic_shape_template @@ -1453,10 +1451,8 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): build_args_with_muta_attr_not_input_for_declare, build_func_with_muta_attr_not_input, ) = gen_build_func_str( - op_class_name, - op_input_name_list, - op_input_type_list, - op_input_optional_list, + args, + op_info, op_attribute_name_list, op_attribute_type_list, op_attribute_build_arg_type_list, @@ -1467,12 +1463,6 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): op_non_mutable_attribute_type_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - op_inplace_map, muta_attr_is_input=False, ) if len(op_attribute_name_list) > 0: @@ -1480,10 +1470,8 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): build_args_with_attr_is_map_for_declare, build_func_with_attr_is_map, ) = gen_build_func_str( - op_class_name, - op_input_name_list, - op_input_type_list, - op_input_optional_list, + args, + op_info, op_attribute_name_list, op_attribute_type_list, op_attribute_build_arg_type_list, @@ -1494,12 +1482,6 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): op_non_mutable_attribute_type_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - op_inplace_map, muta_attr_is_input=False, attr_args_is_map=True, ) @@ -1510,10 +1492,8 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): 
build_args_with_muta_attr_is_input_for_declare, build_func_with_muta_attr_is_input, ) = gen_build_func_str( - op_class_name, - op_input_name_list, - op_input_type_list, - op_input_optional_list, + args, + op_info, op_attribute_name_list, op_attribute_type_list, op_attribute_build_arg_type_list, @@ -1524,18 +1504,10 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): op_non_mutable_attribute_type_list, op_non_mutable_attribute_build_arg_type_list, op_non_mutable_attribute_default_value_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - op_inplace_map, muta_attr_is_input=True, ) - build_mutable_attr_is_input = "static void Build({build_args});".format( - build_args=build_args_with_muta_attr_is_input_for_declare - ) + build_mutable_attr_is_input = f"static void Build({build_args_with_muta_attr_is_input_for_declare});" if (op_invoke_map is not None) and ( op_invoke_map['func'] in op_info_items ): @@ -1574,7 +1546,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): TEST_API=TEST_API, op_name=op_class_name, dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, + interfaces=all_interface_str, traits=op_traits_str, attribute_declare=op_0_attribute_declare_str, attribute_num=0, @@ -1582,8 +1554,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): build_mutable_attr_is_input=build_mutable_attr_is_input, build_attr_num_over_1=build_attr_num_over_1, build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1, - get_inputs_and_outputs=op_get_inputs_outputs_str, - exclusive_interface=exclusive_interface_str, + exclusive_interface=all_declare_str, get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, parse_kernel_key_declare=parse_kernel_key_str, infer_symbolic_shape_declare=infer_symbolic_shape_str, @@ -1594,7 +1565,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): TEST_API=TEST_API, op_name=op_class_name, dialect_op_name=op_dialect_name, - interfaces=op_interfaces_str, + interfaces=all_interface_str, traits=op_traits_str, attribute_declare=op_n_attribute_declare_str.format( attribute_num=len( @@ -1606,8 +1577,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): build_mutable_attr_is_input=build_mutable_attr_is_input, build_attr_num_over_1=build_attr_num_over_1, build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1, - get_inputs_and_outputs=op_get_inputs_outputs_str, - exclusive_interface=exclusive_interface_str, + exclusive_interface=all_declare_str, get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, parse_kernel_key_declare=parse_kernel_key_str, infer_symbolic_shape_declare=infer_symbolic_shape_str, @@ -1856,7 +1826,10 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): # generate op ParseKernelKeyInterface function str parse_kernel_key_define_str = '' - if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces: + if ( + "paddle::dialect::ParseKernelKeyInterface" + in all_interface_list + ): parse_kernel_key_define_str = gen_parse_kernel_key_str( op_class_name ) @@ -1865,7 +1838,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): infer_symbolic_shape_define_str = '' if ( "paddle::dialect::InferSymbolicShapeInterface" - in op_interfaces + in all_interface_list ): infer_symbolic_shape_define_str = ( 
gen_infer_symbolic_shape_str(op_class_name) @@ -1875,7 +1848,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): infer_symbolic_shape_define_str = '' if ( "paddle::dialect::InferSymbolicShapeInterface" - in op_interfaces + in all_interface_list ): infer_symbolic_shape_define_str = ( gen_infer_symbolic_shape_str(op_class_name) @@ -1893,52 +1866,6 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): ) ) - op_infer_meta_str = gen_op_infer_meta_str( - op_info, op_class_name, op_info_items - ) - - op_infer_meta_from_type_str = "" - if op_infer_meta_map is not None: - muta_attr_is_input = ( - True - if len(op_mutable_attribute_name_list) > 0 - else False - ) - op_infer_meta_from_type_str = gen_infermeta_func_str( - op_class_name, - op_input_name_list, - op_input_type_list, - op_input_optional_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - op_inplace_map, - op_attribute_name_list, - op_attribute_type_list, - op_attribute_build_arg_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_non_mutable_attribute_build_arg_type_list, - muta_attr_is_input, - attr_args_is_map=True, - ) - - if (op_invoke_map is not None) and ( - op_invoke_map['func'] in op_info_items - ): - op_invoke_class_name = ( - to_pascal_case(op_invoke_map['func']) + "Op" - ) - op_infer_meta_from_type_str = ( - gen_infermeta_by_invoke_func_str( - op_class_name, op_invoke_class_name - ) - ) - # =================================== # # gen Vjp func str # # =================================== # @@ -1979,8 +1906,6 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): ) ops_defined_list.append(op_verify_str) - ops_defined_list.append(op_infer_meta_str) - ops_defined_list.append(op_infer_meta_from_type_str) ops_defined_list.append(op_get_kernel_type_for_var_str) ops_defined_list.append(parse_kernel_key_define_str) ops_defined_list.append(infer_symbolic_shape_define_str) @@ -2060,6 +1985,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name): def OpGenerator( + args, op_yaml_files, op_compat_yaml_file, namespaces, @@ -2207,7 +2133,9 @@ def OpGenerator( source_file_str, op_to_multi_kernels_list, vjp_source_file_str, - ) = AutoCodeGen(items, all_op_info_items, namespaces, dialect_name) + ) = AutoCodeGen( + args, items, all_op_info_items, namespaces, dialect_name + ) op_list_strs.append(op_list_str) declare_type_id_strs.append(declare_type_id_str) define_type_id_strs.append(define_type_id_str) @@ -2361,6 +2289,7 @@ def ParseArguments(): parser.add_argument('--op_vjp_cc_file', type=str) parser.add_argument('--onednn_yaml_file', type=str) parser.add_argument('--ops_onednn_extra_yaml_file', type=str) + parser.add_argument('--with_distributed', type=strtobool) return parser.parse_args() @@ -2385,6 +2314,7 @@ def ParseArguments(): # auto code generate OpGenerator( + args, op_yaml_files, op_compat_yaml_file, namespaces, diff --git a/paddle/fluid/pir/dialect/op_generator/op_infer_spmd_func_gen.py b/paddle/fluid/pir/dialect/op_generator/op_infer_spmd_func_gen.py new file mode 100644 index 0000000000000..e8ab19ccf8863 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/op_infer_spmd_func_gen.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +OP_INFER_SPMD_TEMPLATE = """ + static phi::distributed::SpmdInfo InferSpmd({infer_spmd_args}) {{ + return phi::distributed::{func}({args}); + }} +""" + + +def gen_op_infer_spmd_func(args, op_info, op_info_items): + if not args.with_distributed or op_info.spmd_rule_func is None: + return [], None, None + input_types_map = { + 'paddle::dialect::DenseTensorType': 'const phi::distributed::DistMetaTensor&', + 'pir::VectorType': 'const std::vector&', + } + input_name_list = op_info.input_name_list + input_type_list = op_info.input_type_list + input_name_type_dict = {} + for attr_idx in range(len(input_name_list)): + input_name_type_dict[input_name_list[attr_idx]] = input_types_map[ + input_type_list[attr_idx] + ] + + attr_name_list = op_info.attribute_name_list + attr_type_list = op_info.attribute_gen_arg_type_list + + attr_name_type_dict = {} + for attr_idx in range(len(attr_type_list)): + attr_name_type_dict[attr_name_list[attr_idx]] = attr_type_list[attr_idx] + scalar_list = [ + "Scalar(int64_t)", + "Scalar(int)", + "Scalar(float)", + "Scalar(double)", + ] + if op_info.op_yaml_item['attrs'][attr_idx]['typename'] in scalar_list: + attr_name_type_dict[attr_name_list[attr_idx]] = "const phi::Scalar&" + + spmd_params = input_name_list + attr_name_list + if op_info.kernel_map is not None: + spmd_params = op_info.kernel_map['param'] + args_list_with_type = [] + args_list = [] + for param in spmd_params: + # is input + if param in op_info.input_name_list: + args_list_with_type.append( + input_name_type_dict[param] + " " + param + ) + args_list.append(param) + # is attribute + else: + param_type = attr_name_type_dict[param] + if param_type == "phi::IntArray": + param_type = "const std::vector&" + args_list_with_type.append(param_type + " " + param) + args_list.append(param) + + spmd_rule_func = op_info.spmd_rule_func + if spmd_rule_func is None: + spmd_rule_func = "VariadicReplicatedInferSpmdDynamic" + declare_str = OP_INFER_SPMD_TEMPLATE.format( + infer_spmd_args=', '.join(args_list_with_type), + func=spmd_rule_func, + args=', '.join(args_list), + ) + return [], declare_str, None diff --git a/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py b/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py similarity index 64% rename from paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py rename to paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py index 500e36881b3f1..0485d2b86a1b3 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py @@ -12,13 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
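To make OP_INFER_SPMD_TEMPLATE above concrete, here is a standalone Python sketch of the substitution the generator performs; the rule name and argument list are illustrative assumptions rather than values taken from this diff, and the template body is copied locally so the snippet runs on its own:

TEMPLATE = """
  static phi::distributed::SpmdInfo InferSpmd({infer_spmd_args}) {{
    return phi::distributed::{func}({args});
  }}
"""

# Roughly what would be emitted for an op whose yaml names a binary spmd rule.
print(TEMPLATE.format(
    infer_spmd_args="const phi::distributed::DistMetaTensor& x, "
                    "const phi::distributed::DistMetaTensor& y",
    func="MatmulInferSpmd",   # assumed spmd_rule entry, for illustration only
    args="x, y",
))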
+from gen_utils import to_pascal_case from op_build_gen import ( _INFERMETA_NEED_META_CONFIG, _PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE, ) -OP_INFERMETA_TEMPLATE = """ -std::vector {op_name}::InferMeta(const std::vector& input_values, const pir::AttributeMap& attributes) {{ +OP_INFERMETA_DECL_STRING = ( + " static void InferMeta(phi::InferMetaContext *infer_meta );\n" + " static std::vector InferMeta( const std::vector& input_values, pir::AttributeMap* p_attributes );" +) + +OP_INFERMETA_IMPL_TEMPLATE_1 = """ +void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ + auto fn = PD_INFER_META(phi::{infer_meta_func}); + fn(infer_meta); +}} +""" + +OP_INFERMETA_IMPL_TEMPLATE_2 = """ +std::vector {op_name}::InferMeta(const std::vector& input_values, pir::AttributeMap* p_attributes) {{ + PADDLE_ENFORCE_NOT_NULL( + p_attributes, common::errors::Fatal("AttrtibueMap pointer in InferMeta function is nullptr.")); + auto& attributes = *p_attributes; (void)attributes; {infermeta_inputs} {get_attributes_str} {infermeta_outputs} @@ -26,33 +42,24 @@ }} """ +OP_INFERMETA_IMPL_TEMPLATE_2_BY_INVOKE = """ +std::vector {op_name}::InferMeta(const std::vector& input_values, pir::AttributeMap* attributes) {{ + return {invoke_class}::InferMeta(input_values, attributes); +}} +""" + CREATE_INPUT_VALUE_TEMPLATE = """ pir::Value {input_name}_ = input_values[{index}]; (void){input_name}_;""" ENFORCE_INPUT_NUM_TEMPLATE = """ - IR_ENFORCE(input_values.size() == {op_input_name_list_size}, - "Num of inputs is expected to be {op_input_name_list_size} but got %d.", input_values.size()); -""" - -OP_INFERMETA_BY_INVOKE_TEMPLATE = """ -std::vector {op_name}::InferMeta(const std::vector& input_values, const pir::AttributeMap& attributes) {{ - return {invoke_class}::InferMeta(input_values, attributes); -}} + PADDLE_ENFORCE_EQ(input_values.size() == {op_input_name_list_size}, true, phi::errors::InvalidArgument( + "Num of inputs is expected to be {op_input_name_list_size} but got %d.", input_values.size())); """ GET_INPUT_TYPE_TEMPLATE = """ {type} {name}; if ({name}_.type().isa<{type}>()) {{ {name} = {name}_.type().dyn_cast<{type}>(); (void){name}; - }} else if ({name}_.type().isa<{allocated_type}>()) {{ - {allocated_type} allocated_{name} = {name}_.type().dyn_cast<{allocated_type}>(); - {name} = {type}::get(pir::IrContext::Instance(), - allocated_{name}.dtype(), - allocated_{name}.dims(), - allocated_{name}.data_layout(), - allocated_{name}.lod(), - allocated_{name}.offset()); - (void){name}; }} else {{ PADDLE_THROW(phi::errors::Unimplemented("Only support {type} or {allocated_type}")); }} @@ -60,6 +67,7 @@ def get_infermeta_inputs_str( + op_info, inuse_infer_meta_args, op_input_name_list, op_input_type_list, @@ -67,7 +75,7 @@ def get_infermeta_inputs_str( op_mutable_attribute_name_list, mutable_attr_is_input, ): - op_input_name_list_size = len(op_input_name_list) + op_input_name_list_size = len(op_info.input_name_list) if mutable_attr_is_input: op_input_name_list_size += len(op_mutable_attribute_name_list) @@ -75,22 +83,17 @@ def get_infermeta_inputs_str( op_input_name_list_size=str(op_input_name_list_size), ) - for i in range(len(op_input_name_list)): - if op_input_name_list[i] not in inuse_infer_meta_args: + for i in range(len(op_info.input_name_list)): + if op_info.input_name_list[i] not in inuse_infer_meta_args: continue infermeta_inputs_str += CREATE_INPUT_VALUE_TEMPLATE.format( - input_name=op_input_name_list[i], index=str(i) + input_name=op_info.input_name_list[i], index=str(i) ) if 
mutable_attr_is_input: # add mutable attributes as inputs if len(op_mutable_attribute_name_list) > 0: for i in range(len(op_mutable_attribute_name_list)): - if ( - op_mutable_attribute_name_list[i] - not in inuse_infer_meta_args - ): - continue infermeta_inputs_str += CREATE_INPUT_VALUE_TEMPLATE.format( input_name=op_mutable_attribute_name_list[i], index=str(i + len(op_input_name_list)), @@ -108,9 +111,7 @@ def get_infermeta_inputs_str( # is a vector if 'pir::VectorType' in op_input_type_list[idx]: if op_input_optional_list[idx] == 'false': - infermeta_inputs_str += " pir::VectorType {name} = {name}_.type().dyn_cast(); (void){name};\n".format( - name=op_input_name_list[idx] - ) + infermeta_inputs_str += f" pir::VectorType {op_input_name_list[idx]} = {op_input_name_list[idx]}_.type().dyn_cast(); (void){op_input_name_list[idx]};\n" # is a Tensor else: if op_input_optional_list[idx] == 'false': @@ -128,7 +129,8 @@ def get_infermeta_inputs_str( def GenBuildOutputsPart2( - op_class_name, + args, + op_info, inuse_infer_meta_args, op_input_name_list, op_input_type_list, @@ -158,20 +160,11 @@ def GenBuildOutputsPart2( paddle::dialect::IrMetaTensor meta_{name}; paddle::dialect::IrTensor ir_tensor_{name}; - if ({name}_.impl() != nullptr) {{ VLOG(4) << "Builder construction dense_{name}"; {type} {name}; if ({name}_.type().isa<{type}>()) {{ {name} = {name}_.type().dyn_cast<{type}>(); - }} else if ({name}_.type().isa<{allocated_type}>()) {{ - {allocated_type} allocated_{name} = {name}_.type().dyn_cast<{allocated_type}>(); - {name} = {type}::get(pir::IrContext::Instance(), - allocated_{name}.dtype(), - allocated_{name}.dims(), - allocated_{name}.data_layout(), - allocated_{name}.lod(), - allocated_{name}.offset()); }} else {{ PADDLE_THROW(phi::errors::Unimplemented("Only support {type} or {allocated_type}")); }} @@ -195,13 +188,6 @@ def GenBuildOutputsPart2( {name}_type.data_layout(), {name}_type.lod(), {name}_type.offset())); - }} else if({name}[i].isa()){{ - auto {name}_type = {name}[i].dyn_cast(); - vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()), - {name}_type.dims(), - {name}_type.data_layout(), - {name}_type.lod(), - {name}_type.offset())); }} else {{ PADDLE_THROW(phi::errors::Unimplemented("Only support DenseTensorType or AllocatedDenseTensorType")); }} @@ -228,13 +214,6 @@ def GenBuildOutputsPart2( {name}_type.data_layout(), {name}_type.lod(), {name}_type.offset())); - }} else if({name}[i].isa()){{ - auto {name}_type = {name}[i].dyn_cast(); - vec_ir_tensor_{name}.push_back(paddle::dialect::IrTensor(paddle::dialect::TransToPhiDataType({name}_type.dtype()), - {name}_type.dims(), - {name}_type.data_layout(), - {name}_type.lod(), - {name}_type.offset())); }} else {{ PADDLE_THROW(phi::errors::Unimplemented("Only support DenseTensorType or AllocatedDenseTensorType")); }} @@ -253,11 +232,11 @@ def GenBuildOutputsPart2( """ - CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ is_from_tensor = false; - phi::IntArray {name} = std::move(phi::IntArray(paddle::dialect::ParseValueShape({name}_, &is_from_tensor))); + CREATE_INTARRAY_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE = """ is_from_tensor = false; + phi::IntArray {name} = phi::IntArray(paddle::dialect::ParseValueShape({name}_, &is_from_tensor)); if (is_from_tensor) {name}.SetFromTensor(true);\n""" - CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ std::vector {name}; + CREATE_VECTOR_INT_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE = """ std::vector 
{name}; if ({name}_.isa() && {name}_.defining_op()->isa()) {{ {name} = paddle::dialect::GetInt64Vector( {name}_.defining_op() @@ -273,28 +252,21 @@ def GenBuildOutputsPart2( {name}_size = 1; }} {name} = std::vector({name}_size, -1); - }} else if ({name}_.type().isa()) {{ - common::DDim {name}_dim = {name}_.type().dyn_cast().dims(); - size_t {name}_size = common::product({name}_dim); - if (common::contain_unknown_dim({name}_dim)) {{ - {name}_size = 1; - }} - {name} = std::vector({name}_size, -1); }} else {{ PADDLE_THROW(phi::errors::Unimplemented("Only support VectorType or DenseTensorType or AllocatedDenseTensorType")); }}\n""" - CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE = """ phi::Scalar {name}; + CREATE_SCALAR_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE = """ phi::Scalar {name}; if ({name}_.isa() && {name}_.defining_op()->isa()) {{ - {name} = std::move(phi::Scalar({name}_.defining_op() + {name} = phi::Scalar({name}_.defining_op() ->dyn_cast() .attribute("value") .dyn_cast() .data() - .to())); + .to()); }} else {{ - {name} = std::move(phi::Scalar(-1)); + {name} = phi::Scalar(-1); {name}.SetFromTensor(true); }}\n""" @@ -318,25 +290,23 @@ def GenBuildOutputsPart2( # Prepare mutable attributes if mutable_attr_is_input: for idx in range(len(op_mutable_attribute_name_list)): - if op_mutable_attribute_name_list[idx] not in inuse_infer_meta_args: - continue attr_dtype = op_mutable_attribute_type_list[idx] # int_array if attr_dtype[0] == "paddle::dialect::IntArrayAttribute": if ( - op_class_name + op_info.class_name in _PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE ): - build_output_str += CREATE_VECTOR_INT_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( + build_output_str += CREATE_VECTOR_INT_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE.format( name=op_mutable_attribute_name_list[idx] ) else: - build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( + build_output_str += CREATE_INTARRAY_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE.format( name=op_mutable_attribute_name_list[idx] ) # scalar elif attr_dtype[0] == "paddle::dialect::ScalarAttribute": - build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUE_WITH_UNKONW_DATA_TEMPLATE.format( + build_output_str += CREATE_SCALAR_MUTABLE_ATTRIBUTE_WITH_UNKNOWN_DATA_TEMPLATE.format( name=op_mutable_attribute_name_list[idx], dtype=attr_dtype[1], ) @@ -436,12 +406,12 @@ def GenBuildOutputsPart2( CREATE_INFER_META_FUNC_TEMPLATE = """ phi::{func}({args}); """ - CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE = """ + CREATE_INFER_META_FUNC_WITH_META_CONFIG_TEMPLATE = """ phi::{func}({args}, phi::MetaConfig(false, false)); """ if op_infer_meta_map['func'] in _INFERMETA_NEED_META_CONFIG: build_output_str += ( - CREATE_INFER_META_FUNC_WITH_METACINFIG_TEMPLATE.format( + CREATE_INFER_META_FUNC_WITH_META_CONFIG_TEMPLATE.format( func=op_infer_meta_map['func'], args=", ".join(infer_meta_args) ) ) @@ -454,28 +424,21 @@ def GenBuildOutputsPart2( build_output_str += "\n std::vector argument_outputs;" CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """ - pir::Type {name}_dense_tensor_type = {type}::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset()); - argument_outputs.push_back({name}_dense_tensor_type); + pir::Type {name}_type = CvtTo{type}(dense_{name}); """ - CREATE_OUTPUT_INPLACE_OPTIONAL_DENSE_TENSOR_TEMPLATE = """ + pir::Type {name}_type; if ({input_name}_.impl() != nullptr) {{ - pir::Type 
{output_name}_dense_tensor_type = {type}::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(dense_{output_name}.dtype()), dense_{output_name}.dims(), dense_{output_name}.layout(), dense_{output_name}.lod(), dense_{output_name}.offset()); - argument_outputs.push_back({output_name}_dense_tensor_type); - }} else {{ - pir::Type {output_name}_type; - argument_outputs.push_back({output_name}_type); + {name}_type = CvtTo{type}(dense_{name}); }} - """ CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """ std::vector {name}_types; for (size_t i=0; i < static_cast({output_size}); i++) {{ - {name}_types.push_back(paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), paddle::dialect::TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset())); + {name}_types.push_back(CvtToDenseTensorType(vec_dense_{name}[i])); }} - pir::Type {name}_vector_type = pir::VectorType::get(pir::IrContext::Instance(), {name}_types); - argument_outputs.push_back({name}_vector_type); + pir::Type {name}_type = pir::VectorType::get(pir::IrContext::Instance(), {name}_types); """ for idx in range(len(op_output_name_list)): # is a vector @@ -496,60 +459,73 @@ def GenBuildOutputsPart2( build_output_str += ( CREATE_OUTPUT_INPLACE_OPTIONAL_DENSE_TENSOR_TEMPLATE.format( input_name=op_inplace_map[output_name], - output_name=output_name, - type=op_output_type_list[idx], + name=output_name, + type=op_output_type_list[idx][17:], ) ) else: build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format( - type=op_output_type_list[idx], name=output_name + type=op_output_type_list[idx][17:], name=output_name ) + build_output_str += GenDistBranch(args, op_info) + + PUSH_BACK_OUTPUT_TYPE_TEMPLATE = """ + argument_outputs.push_back({name}); +""" + for idx in range(len(op_output_name_list)): + build_output_str += PUSH_BACK_OUTPUT_TYPE_TEMPLATE.format( + name=op_output_name_list[idx] + "_type", + ) return build_output_str def GetAttributes( - op_class_name, - muta_attr_is_input, + op_info, + mutable_attr_is_input, inuse_infer_meta_args, - op_attribute_name_list, - op_attribute_type_list, - op_attribute_build_arg_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_non_mutable_attribute_build_arg_type_list, attr_args_is_map, ): GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. "); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data(); """ GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. "); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().AsString(); """ GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. 
"); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name}; for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast().at(i).dyn_cast<{inner_type}>().{data_name}()); }} """ GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. "); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().GetData(); """ GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE = """ - IR_ENFORCE( - attributes.find("{attribute_name}") != attributes.end(), - "'{attribute_name}' Attribute is expected for {op_name}. "); + PADDLE_ENFORCE_NE( + attributes.find("{attribute_name}"), + attributes.end(), + phi::errors::InvalidArgument( + "'{attribute_name}' Attribute is expected for {op_name}. ")); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast().data().to<{attr_type}>(); """ @@ -559,14 +535,14 @@ def GetAttributes( attr_names = [] attr_types = [] attr_build_arg_types = [] - if not muta_attr_is_input: - attr_names = op_attribute_name_list - attr_types = op_attribute_type_list - attr_build_arg_types = op_attribute_build_arg_type_list + if not mutable_attr_is_input: + attr_names = op_info.attribute_name_list + attr_types = op_info.attribute_type_list + attr_build_arg_types = op_info.attribute_build_arg_type_list else: - attr_names = op_non_mutable_attribute_name_list - attr_types = op_non_mutable_attribute_type_list - attr_build_arg_types = op_non_mutable_attribute_build_arg_type_list + attr_names = op_info.non_mutable_attribute_name_list + attr_types = op_info.non_mutable_attribute_type_list + attr_build_arg_types = op_info.non_mutable_attribute_build_arg_type_list if attr_args_is_map: for idx in range(len(attr_names)): if attr_names[idx] not in inuse_infer_meta_args: @@ -584,7 +560,7 @@ def GetAttributes( data_name = "AsString" get_attributes_str += ( GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], inner_type=inner_type, @@ -594,7 +570,7 @@ def GetAttributes( elif "paddle::dialect::IntArrayAttribute" in attr_types[idx]: get_attributes_str += ( GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], ) @@ -602,7 +578,7 @@ def GetAttributes( elif "paddle::dialect::ScalarAttribute" in attr_types[idx]: get_attributes_str += ( GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], ) @@ -610,7 +586,7 @@ def GetAttributes( elif "pir::StrAttribute" in attr_types[idx]: get_attributes_str += ( GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, attribute_name=attr_names[idx], attr_ir_type=attr_types[idx], @@ -618,7 +594,7 @@ def GetAttributes( ) else: get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format( - op_name=op_class_name, + op_name=op_info.class_name, attr_type=attr_type, 
attribute_name=attr_names[idx], attr_ir_type=attr_types[idx], @@ -626,81 +602,179 @@ def GetAttributes( return get_attributes_str -def gen_infermeta_func_str( - op_class_name, - op_input_name_list, - op_input_type_list, - op_input_optional_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - op_inplace_map, - op_attribute_name_list, - op_attribute_type_list, - op_attribute_build_arg_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_non_mutable_attribute_build_arg_type_list, - muta_attr_is_input=False, - attr_args_is_map=True, -): +def GenDistBranch(args, op_info): + if not args.with_distributed or op_info.spmd_rule_func is None: + return "" + TEMPLATE = """ + // Auto Parallel condition + ProcessMeshAttribute op_mesh; + if(HasDistInput(input_values, &op_mesh)) {{ + CvtAllInputsToDist(input_values, op_mesh); + auto ctx = pir::IrContext::Instance(); + std::vector operand_dist_attrs, result_dist_attrs;""" + dist_branch_str = TEMPLATE.format() + infer_spmd_args_list = [] + # Prepare inputs_meta_tensor & attributes for infer spmd + for name in op_info.spmd_params: + # is input + if name in op_info.input_name_list: + input_index = op_info.input_name_list.index(name) + # is a vector + if 'pir::VectorType' in op_info.input_type_list[input_index]: + TEMPLATE = """ + std::vector vec_dist_meta_{name}; + for(auto& sub_ir_tensor: {name}.data()) {{ + vec_dist_meta_{name}.push_back(CvtToDistMetaTensor(sub_ir_tensor.dyn_cast())); + }}""" + dist_branch_str += TEMPLATE.format(name=name) + infer_spmd_args_list.append("vec_dist_meta_" + name) + # is a Tensor + else: + if op_info.input_optional_list[input_index] == 'true': + TEMPLATE = """ + phi::distributed::DistMetaTensor dist_meta_{name}; + if({name}_) {{ + dist_meta_{name} = CvtToDistMetaTensor({name}_.type().dyn_cast()); + }}""" + dist_branch_str += TEMPLATE.format(name=name) + else: + TEMPLATE = """ + auto dist_meta_{name} = CvtToDistMetaTensor({name}_.type().dyn_cast());""" + dist_branch_str += TEMPLATE.format(name=name) + infer_spmd_args_list.append("dist_meta_" + name) + else: + attr_index = op_info.attribute_name_list.index(name) + param_type = op_info.attribute_gen_arg_type_list[attr_index] + infer_spmd_args_list.append(name) + if param_type == "phi::IntArray": + if name in op_info.mutable_attribute_name_list: + attr_index = op_info.mutable_attribute_name_list.index(name) + attr_type = op_info.mutable_attribute_type_list[attr_index] + if attr_type[0] == "paddle::dialect::IntArrayAttribute": + infer_spmd_args_list[-1] = name + ".GetData()" + TEMPLATE = """ + auto spmd_info = InferSpmd({args}); + PADDLE_ENFORCE_EQ(spmd_info.first.size(), {input_size}u, common::errors::Unavailable( + "Size of spmd_info.first for op[{op_name}]is unexpected.")); + for(auto& arg_dist : spmd_info.first) {{ + operand_dist_attrs.push_back(CvtToPirDistAttr(arg_dist)); + }} +""" + dist_branch_str += TEMPLATE.format( + args=', '.join(infer_spmd_args_list), + input_size=len(op_info.input_name_list), + op_name=op_info.class_name, + ) + + if len(op_info.mutable_attribute_name_list) > 0: + TEMPLATE = """ + for(int i = {input_size}; i < {all_input_size}; ++i) {{ + if(auto dist_type = input_values[i].type().dyn_cast()) {{ + operand_dist_attrs.push_back(dist_type.tensor_dist_attr()); + }} + else {{ + operand_dist_attrs.push_back(nullptr); + }} + }} +""" + dist_branch_str += TEMPLATE.format( + 
input_size=len(op_info.input_name_list), + all_input_size=len(op_info.input_name_list) + + len(op_info.mutable_attribute_name_list), + ) + + for idx, output_name in enumerate(op_info.output_name_list): + # is a vector + if 'pir::VectorType' in op_info.output_type_list[idx]: + # Todo: support vector case + dist_branch_str += "" + # is a Tensor + else: + TEMPLATE = """ + auto dist_attr_{name} = CvtToPirDistAttr(spmd_info.second[{idx}]); + result_dist_attrs.push_back(dist_attr_{name}); + argument_outputs.push_back(DistDenseTensorType::get(ctx, {name}_type.dyn_cast(), dist_attr_{name})); +""" + dist_branch_str += TEMPLATE.format(idx=idx, name=output_name) + TEMPLATE = """ + attributes[kAttrOpDistAttr] = OperationDistAttribute::get( + ctx, + op_mesh, + operand_dist_attrs, + result_dist_attrs + ); + return argument_outputs; + }} +""" + dist_branch_str += TEMPLATE.format() + return dist_branch_str + + +def gen_infermeta_func_str(args, op_info): + attr_args_is_map = True + mutable_attr_is_input = ( + True if len(op_info.mutable_attribute_name_list) > 0 else False + ) inuse_infer_meta_args = [] - for idx in range(len(op_infer_meta_map['param'])): - inuse_infer_meta_args.append(op_infer_meta_map['param'][idx]) + for idx in range(len(op_info.infer_meta_map['param'])): + inuse_infer_meta_args.append(op_info.infer_meta_map['param'][idx]) # Prepare outputs_meta_tensor for infer meta - for idx in range(len(op_output_name_list)): - if op_output_name_list[idx].endswith('_grad'): - inuse_infer_meta_args.append(f"{op_output_name_list[idx][0:-5]}") - if op_output_name_list[idx].endswith('_grad_'): - inuse_infer_meta_args.append(f"{op_output_name_list[idx][0:-6]}") - inuse_infer_meta_args.append(f"{op_output_name_list[idx]}") + for idx in range(len(op_info.output_name_list)): + if op_info.output_name_list[idx].endswith('_grad'): + inuse_infer_meta_args.append( + f"{op_info.output_name_list[idx][0:-5]}" + ) + if op_info.output_name_list[idx].endswith('_grad_'): + inuse_infer_meta_args.append( + f"{op_info.output_name_list[idx][0:-6]}" + ) + inuse_infer_meta_args.append(f"{op_info.output_name_list[idx]}") + + spmd_params = [] + if args.with_distributed and op_info.spmd_rule_func is not None: + spmd_params = op_info.input_name_list + op_info.attribute_name_list + if op_info.kernel_map is not None: + spmd_params = op_info.kernel_map['param'] + op_info.spmd_params = spmd_params infermeta_inputs_str = get_infermeta_inputs_str( - inuse_infer_meta_args, - op_input_name_list, - op_input_type_list, - op_input_optional_list, - op_mutable_attribute_name_list, - muta_attr_is_input, + op_info, + inuse_infer_meta_args + spmd_params, + op_info.input_name_list, + op_info.kernel_input_type_list, + op_info.input_optional_list, + op_info.mutable_attribute_name_list, + mutable_attr_is_input, ) get_attributes_str = GetAttributes( - op_class_name, - muta_attr_is_input, - inuse_infer_meta_args, - op_attribute_name_list, - op_attribute_type_list, - op_attribute_build_arg_type_list, - op_non_mutable_attribute_name_list, - op_non_mutable_attribute_type_list, - op_non_mutable_attribute_build_arg_type_list, + op_info, + mutable_attr_is_input, + inuse_infer_meta_args + spmd_params, attr_args_is_map, ) infermeta_outputs_str = GenBuildOutputsPart2( - op_class_name, - inuse_infer_meta_args, - op_input_name_list, - op_input_type_list, - op_input_optional_list, - op_mutable_attribute_name_list, - op_mutable_attribute_type_list, - op_output_name_list, - op_output_type_list, - op_output_size_list, - op_output_optional_list, - op_infer_meta_map, - 
op_inplace_map, - muta_attr_is_input, + args, + op_info, + inuse_infer_meta_args + spmd_params, + op_info.input_name_list, + op_info.kernel_input_type_list, + op_info.input_optional_list, + op_info.mutable_attribute_name_list, + op_info.mutable_attribute_type_list, + op_info.output_name_list, + op_info.kernel_output_type_list, + op_info.output_size_list, + op_info.output_optional_list, + op_info.infer_meta_map, + op_info.inplace_map, + mutable_attr_is_input, ) - infermeta_func = OP_INFERMETA_TEMPLATE.format( - op_name=op_class_name, + infermeta_func = OP_INFERMETA_IMPL_TEMPLATE_2.format( + op_name=op_info.class_name, infermeta_inputs=infermeta_inputs_str, get_attributes_str=get_attributes_str, infermeta_outputs=infermeta_outputs_str, @@ -709,7 +783,45 @@ def gen_infermeta_func_str( return infermeta_func -def gen_infermeta_by_invoke_func_str(op_class_name, invoke_class_name): - return OP_INFERMETA_BY_INVOKE_TEMPLATE.format( - op_name=op_class_name, invoke_class=invoke_class_name +def gen_infermeta_impl_str(args, op_info): + return ( + OP_INFERMETA_IMPL_TEMPLATE_1.format( + op_name=op_info.class_name, + infer_meta_func=op_info.infer_meta_func, + ) + + "\n" + + gen_infermeta_func_str(args, op_info) ) + + +def gen_infermeta_by_invoke_impl_str(op_info, op_info_items): + invoke_class_name = to_pascal_case(op_info.invoke_map['func']) + "Op" + return ( + OP_INFERMETA_IMPL_TEMPLATE_1.format( + op_name=op_info.class_name, + infer_meta_func=op_info_items[ + op_info.invoke_map['func'] + ].infer_meta_func, + ) + + "\n" + + OP_INFERMETA_IMPL_TEMPLATE_2_BY_INVOKE.format( + op_name=op_info.class_name, invoke_class=invoke_class_name + ) + ) + + +def gen_op_infermeta_func(args, op_info, op_info_items): + interface = [] + declare_str = "" + impl_str = "" + if op_info.infer_meta_func: + interface = ["paddle::dialect::InferMetaInterface"] + declare_str = OP_INFERMETA_DECL_STRING + impl_str = gen_infermeta_impl_str(args, op_info) + elif op_info.invoke_map and op_info.invoke_map['func'] in op_info_items: + if op_info_items[op_info.invoke_map['func']].infer_meta_func: + interface = ["paddle::dialect::InferMetaInterface"] + declare_str = OP_INFERMETA_DECL_STRING + impl_str = gen_infermeta_by_invoke_impl_str(op_info, op_info_items) + + return interface, declare_str, impl_str diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 0a0cae38ec2e5..ce9990350e486 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -15,12 +15,6 @@ # generator interfaces from vjp_interface_black_list import vjp_interface_black_list -OP_INFER_SHAPE_TEMPLATE = """ -void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ - auto fn = PD_INFER_META(phi::{infer_meta_func}); - fn(infer_meta); -}} -""" CHECK_INPUT_TEMPLATE = """ PADDLE_ENFORCE_EQ( inputs_.size(), @@ -272,37 +266,8 @@ def gen_op_vjp_str( return str -def gen_op_infer_meta_str(op_info, op_class_name, op_info_items): - op_infer_meta_str = "" - if op_info.infer_meta_func: - op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format( - op_name=op_class_name, - infer_meta_func=op_info.infer_meta_func, - ) - elif op_info.invoke_map and op_info.invoke_map['func'] in op_info_items: - if op_info_items[op_info.invoke_map['func']].infer_meta_func: - op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format( - op_name=op_class_name, - infer_meta_func=op_info_items[ - op_info.invoke_map['func'] - ].infer_meta_func, - ) - return 
op_infer_meta_str - - def gen_exclusive_interface_str(op_info, op_info_items): exclusive_interface_str = "" - if op_info.infer_meta_func: - exclusive_interface_str += ( - " static void InferMeta( phi::InferMetaContext *infer_meta );\n" - " static std::vector InferMeta( const std::vector& input_values, const pir::AttributeMap& attributes );" - ) - elif op_info.invoke_map and op_info.invoke_map['func'] in op_info_items: - if op_info_items[op_info.invoke_map['func']].infer_meta_func: - exclusive_interface_str += ( - " static void InferMeta( phi::InferMetaContext *infer_meta );\n" - " static std::vector InferMeta( const std::vector& input_values, const pir::AttributeMap& attributes );" - ) if op_info.op_phi_name[0] not in vjp_interface_black_list: exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& inputs_, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/pir/dialect/op_generator/op_kerneltype_gen.py b/paddle/fluid/pir/dialect/op_generator/op_kerneltype_gen.py index e5a8b2c9eb15c..646392cb57e5c 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_kerneltype_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_kerneltype_gen.py @@ -67,7 +67,7 @@ def get_data_transform_check_str(op_data_transform_map): ) if "support_trans_dtype" in op_data_transform_map: args = op_data_transform_map["support_trans_dtype"] - # TODO:(chenxi67) comlete SUPPORT logic + # TODO:(chenxi67) complete SUPPORT logic if args is not None: if_cond_args = [] for support_arg in args: diff --git a/paddle/fluid/pir/dialect/op_generator/op_member_func_gen.py b/paddle/fluid/pir/dialect/op_generator/op_member_access_func_gen.py similarity index 79% rename from paddle/fluid/pir/dialect/op_generator/op_member_func_gen.py rename to paddle/fluid/pir/dialect/op_generator/op_member_access_func_gen.py index dd060692bd078..98e4e8de66e80 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_member_func_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_member_access_func_gen.py @@ -20,9 +20,13 @@ """ -def gen_op_get_inputs_outputs_str( - op_input_name_list, op_mutable_attribute_name_list, op_output_name_list -): +# =================================== # +# gen get input/output methods str # +# =================================== # +def gen_op_member_access_func(args, op_info, op_info_items): + op_input_name_list = op_info.input_name_list + op_mutable_attribute_name_list = op_info.mutable_attribute_name_list + op_output_name_list = op_info.output_name_list op_get_inputs_outputs_str = "" for idx in range(len(op_input_name_list)): op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format( @@ -39,4 +43,4 @@ def gen_op_get_inputs_outputs_str( output_name=op_output_name_list[idx], output_index=idx, ) - return op_get_inputs_outputs_str + return [], op_get_inputs_outputs_str, None diff --git a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py index 70770c64e0aaa..dbde0802f9982 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_verify_gen.py @@ -19,8 +19,8 @@ VLOG(4) << "Verifying inputs:"; {{ auto input_size = num_operands(); - IR_ENFORCE(input_size == {inputs_size}u, - "The size %d of inputs must be equal to {inputs_size}.", input_size);{inputs_type_check} + PADDLE_ENFORCE_EQ(input_size == {inputs_size}u, true, phi::errors::InvalidArgument( + "The size %d of inputs must be 
equal to {inputs_size}.", input_size));{inputs_type_check} }} VLOG(4) << "Verifying attributes:"; {{{attributes_check} @@ -28,8 +28,8 @@ VLOG(4) << "Verifying outputs:"; {{ auto output_size = num_results(); - IR_ENFORCE(output_size == {outputs_size}u, - "The size %d of outputs must be equal to {outputs_size}.", output_size);{outputs_type_check} + PADDLE_ENFORCE_EQ(output_size == {outputs_size}u, true, phi::errors::InvalidArgument( + "The size %d of outputs must be equal to {outputs_size}.", output_size));{outputs_type_check} }} VLOG(4) << "End Verifying for: {op_name}."; }} @@ -40,83 +40,83 @@ """ INPUT_TYPE_CHECK_TEMPLATE = """ - IR_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), - "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type());""" + PADDLE_ENFORCE_EQ((*this)->operand_source({index}).type().isa<{standard}>(), true, + phi::errors::InvalidArgument("Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()));""" INPUT_VECTORTYPE_CHECK_TEMPLATE = """ if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); ++i) {{ - IR_ENFORCE(vec_type[i].isa<{standard}>(), - "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()); + PADDLE_ENFORCE_EQ(vec_type[i].isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type())); }} }} else {{ - IR_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(), - "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()); + PADDLE_ENFORCE_EQ((*this)->operand_source({index}).type().isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type())); }}""" INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ if (auto val = (*this)->operand({index})) {{ - IR_ENFORCE(val.type().isa<{standard}>(), - "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()); + PADDLE_ENFORCE_EQ(val.type().isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type())); }}""" INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ if (auto val = (*this)->operand({index})) {{ if (auto vec_type = val.type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ - IR_ENFORCE(vec_type[i].isa<{standard}>(), - "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()); + PADDLE_ENFORCE_EQ(vec_type[i].isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type())); }} }} else {{ - IR_ENFORCE(val.type().isa<{standard}>(), - "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()); + PADDLE_ENFORCE_EQ(val.type().isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type())); }} }}""" ATTRIBUTE_CHECK_TEMPLATE = """ - IR_ENFORCE(attributes.count("{attribute_name}")>0, - "{attribute_name} does not exist."); - IR_ENFORCE(attributes.at("{attribute_name}").isa<{standard}>(), - "Type of attribute: {attribute_name} is not {standard}."); + 
PADDLE_ENFORCE_GT(attributes.count("{attribute_name}"), 0, phi::errors::InvalidArgument( + "{attribute_name} does not exist.")); + PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type of attribute: {attribute_name} is not {standard}.")); """ ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """ - IR_ENFORCE(attributes.count("{attribute_name}")>0, - "{attribute_name} does not exist."); - IR_ENFORCE(attributes.at("{attribute_name}").isa(), - "Type of attribute: {attribute_name} is not pir::ArrayAttribute."); + PADDLE_ENFORCE_GT(attributes.count("{attribute_name}"), 0, phi::errors::InvalidArgument( + "{attribute_name} does not exist.")); + PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa(), true, phi::errors::InvalidArgument( + "Type of attribute: {attribute_name} is not pir::ArrayAttribute.")); for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast().size(); i++) {{ - IR_ENFORCE(attributes.at("{attribute_name}").dyn_cast().at(i).isa<{standard}>(), - "Type of attribute: {attribute_name} is not right."); + PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast().at(i).isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type of attribute: {attribute_name} is not right.")); }}""" OUTPUT_TYPE_CHECK_TEMPLATE = """ - IR_ENFORCE((*this)->result({index}).type().isa<{standard}>(), - "Type validation failed for the {index}th output.");""" + PADDLE_ENFORCE_EQ((*this)->result({index}).type().isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th output."));""" OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """ auto output_{index}_type = (*this)->result({index}).type(); if (auto vec_type = output_{index}_type.dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ - IR_ENFORCE(vec_type[i].isa<{standard}>(), - "Type validation failed for the {index}th output."); + PADDLE_ENFORCE_EQ(vec_type[i].isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th output.")); }} }} else {{ - IR_ENFORCE(output_{index}_type.isa<{standard}>(), - "Type validation failed for the {index}th output."); + PADDLE_ENFORCE_EQ(output_{index}_type.isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th output.")); }}""" OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ if (auto output_{index}_type = (*this)->result({index}).type()) {{ - IR_ENFORCE(output_{index}_type.isa<{standard}>(), - "Type validation failed for the {index}th output."); + PADDLE_ENFORCE_EQ(output_{index}_type.isa<{standard}>(),true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th output.")); }}""" OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ if (auto output_{index}_type = (*this)->result({index}).type()) {{ if (auto vec_type = output_{index}_type.dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); ++i) {{ - IR_ENFORCE(vec_type[i].isa<{standard}>(), - "Type validation failed for the {index}th output."); + PADDLE_ENFORCE_EQ(vec_type[i].isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th output.")); }} }} else {{ - IR_ENFORCE(output_{index}_type.isa<{standard}>(), - "Type validation failed for the {index}th output."); + PADDLE_ENFORCE_EQ(output_{index}_type.isa<{standard}>(), true, phi::errors::InvalidArgument( + "Type validation failed for the {index}th output.")); }} }}""" diff --git a/paddle/fluid/pir/dialect/op_generator/op_vjp_interface_func_gen.py 
b/paddle/fluid/pir/dialect/op_generator/op_vjp_interface_func_gen.py new file mode 100644 index 0000000000000..53ff6b8e50eb4 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/op_vjp_interface_func_gen.py @@ -0,0 +1,26 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from vjp_interface_black_list import vjp_interface_black_list + + +def gen_op_vjp_interface_func(args, op_info, op_info_items): + if ( + op_info.backward_name + and op_info.op_phi_name[0] not in vjp_interface_black_list + and args.dialect_name != "onednn_op" + ): + return ["paddle::dialect::VjpInterface"], None, None + else: + return [], None, None diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py index 54b56a2e3c887..5ad1c5b562740 100644 --- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py @@ -69,8 +69,12 @@ {{"{name}", (PyCFunction)(void (*)(void)){name}, METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},""" NEED_GEN_STATIC_ONLY_APIS = [ + 'c_allreduce_avg_', + 'c_allreduce_min_', + 'c_allreduce_prod_', + 'distributed_fused_lamb_init', + 'distributed_fused_lamb_init_', 'fetch', - 'fused_bias_dropout_residual_layer_norm', 'fused_embedding_eltwise_layernorm', 'fused_fc_elementwise_layernorm', 'fused_multi_transformer_xpu', @@ -114,56 +118,84 @@ 'quantize_linear_', 'dequantize_linear', 'dequantize_linear_', + 'coalesce_tensor_', ] NO_NEED_GEN_STATIC_ONLY_APIS = [ 'add_n_', - 'add_n_with_kernel', + 'all_reduce', + 'all_reduce_', + 'batch_fc', + 'barrier', 'c_allgather', + 'c_allreduce_avg', 'c_allreduce_max', 'c_allreduce_min', - 'c_allreduce_min_', 'c_allreduce_sum', 'c_allreduce_prod', - 'c_allreduce_prod_', 'c_embedding', 'c_identity', 'c_reduce_sum', 'c_reducescatter', 'c_softmax_with_cross_entropy', + 'c_split', 'decayed_adagrad', + 'distributed_push_sparse', 'distributed_lookup_table', 'dpsgd', 'embedding_grad_sparse', 'ftrl', + 'fused_adam_', 'fused_batch_norm_act_', 'fused_bn_add_activation_', 'fused_elemwise_add_activation', 'fused_scale_bias_relu_conv_bn', 'fused_scale_bias_add_relu', + 'fused_token_prune', 'fused_dconv_drelu_dbn', 'fused_dot_product_attention', 'nce', 'lars_momentum', 'lars_momentum_', 'max_pool2d_v2', + 'partial_sum', + 'random_routing', + 'rank_attention', 'recv_v2', 'rnn_', 'row_conv', 'seed', 'send_v2', 'shadow_feed', + 'shadow_feed_tensors', 'shuffle_batch', 'sparse_momentum', 'tdm_sampler', 'soft_relu', 'uniform_random_batch_size_like', 'match_matrix_tensor', + 'c_reduce_avg', + 'c_reduce_avg_', + 'c_reduce_max', + 'c_reduce_max_', 'c_reduce_min', 'c_reduce_min_', + 'c_reduce_prod', + 'c_reduce_prod_', + 'c_scatter', + 'prune_gate_by_capacity', 'push_sparse_v2', 'push_sparse_v2_', + 'partial_concat', 'partial_send', + 'partial_recv', + 'partial_allgather', + 'partial_allgather_', + 'nop', + 'nop_', + 'push_dense', + 'limit_by_capacity', + 'global_scatter', ] diff --git 
a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py index 38619ec22e049..1fc2987ec4ea2 100644 --- a/paddle/fluid/pir/dialect/op_generator/python_c_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/python_c_gen.py @@ -52,6 +52,7 @@ #include "paddle/fluid/pybind/op_function_common.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/enforce.h" +#include "paddle/fluid/pybind/op_callstack_utils.h" {body} @@ -71,8 +72,10 @@ {attrs} // Call ir static api + CallStackRecorder callstack_recorder("{api_name}"); + callstack_recorder.Record(); auto static_api_out = paddle::dialect::{api_name}({args}); - + callstack_recorder.AttachToOps(); return ToPyObject(static_api_out); }} catch (...) {{ ThrowExceptionToPython(std::current_exception()); @@ -94,8 +97,10 @@ {attrs} // Call ir static api + CallStackRecorder callstack_recorder("{api_name}"); + callstack_recorder.Record(); paddle::dialect::{api_name}({args}); - + callstack_recorder.AttachToOps(); return nullptr; }} catch (...) {{ ThrowExceptionToPython(std::current_exception()); @@ -129,7 +134,10 @@ {cast_attrs} // Call ir static api + CallStackRecorder callstack_recorder("{api_name}"); + callstack_recorder.Record(); auto static_api_out = paddle::dialect::{api_name}({args_with_mutable_attrs}); + callstack_recorder.AttachToOps(); return ToPyObject(static_api_out); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc new file mode 100644 index 0000000000000..42b3567290cda --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -0,0 +1,535 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h" +#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" + +namespace { + +inline void UpdatePaddingAndDilation( + std::vector *paddings, + std::vector *dilation, + const std::string padding_algorithm, + const std::vector data_dims, + const std::vector &strides, + const std::vector &ksize) { + // set padding size == data_dims.size() * 2 + if (paddings->size() == data_dims.size()) { + for (size_t i = 0; i < data_dims.size(); ++i) { + symbol::DimExpr copy_pad = *(paddings->begin() + 2 * i); + paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); + } + } + + // when padding_algorithm is "VALID" or "SAME" + symbol::DimExpr zero{0}; + symbol::DimExpr one{1}; + symbol::DimExpr two{2}; + if (padding_algorithm == "SAME") { + symbol::DimExprBuilder builder{nullptr}; + for (size_t i = 0; i < data_dims.size(); ++i) { + symbol::DimExpr out_size = (data_dims[i] + strides[i] - 1) / strides[i]; + symbol::DimExpr pad_sum = builder.Max( + (out_size - one) * strides[i] + ksize[i] - data_dims[i], zero); + + symbol::DimExpr pad_0 = pad_sum / two; + symbol::DimExpr pad_1 = pad_sum - pad_0; + + *(paddings->begin() + i * 2) = pad_0; + *(paddings->begin() + i * 2 + 1) = pad_1; + + // dilation + *(dilation->begin() + i) = one; + } + + } else if (padding_algorithm == "VALID") { + for (auto it = paddings->begin(); it != paddings->end(); it++) { + *it = zero; + } + } +} + +} // namespace +namespace paddle::dialect { + +bool Conv2dOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const std::vector strides = + paddle::dialect::details::GetVectorAttr(op, "strides"); + + std::vector paddings = + paddle::dialect::details::GetVectorAttr(op, "paddings"); + + std::vector dilations = + paddle::dialect::details::GetVectorAttr(op, "dilations"); + + const auto &attributes = op->attributes(); + const std::string data_format = + attributes.at("data_format").dyn_cast().AsString(); + + const std::string padding_algorithm = attributes.at("padding_algorithm") + .dyn_cast() + .AsString(); + + const auto in_s_or_d = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + const auto filter_s_or_d = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + + const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC"); + + std::vector in_data_dims = + channel_last ? 
std::vector(in_s_or_d.shape().begin() + 1, + in_s_or_d.shape().end() - 1) + : std::vector(in_s_or_d.shape().begin() + 2, + in_s_or_d.shape().end()); + + std::vector filter_data_dims = std::vector( + filter_s_or_d.shape().begin() + 2, filter_s_or_d.shape().end()); + + std::vector ksize = filter_data_dims; + + std::vector new_paddings; + for (const auto &i : paddings) { + new_paddings.push_back(symbol::DimExpr{i}); + } + std::vector new_dilations; + for (const auto &i : dilations) { + new_dilations.push_back(symbol::DimExpr{i}); + } + + UpdatePaddingAndDilation(&new_paddings, + &new_dilations, + padding_algorithm, + in_data_dims, + strides, + ksize); + + const symbol::ShapeOrDataDimExprs &shape_data = [&] { + std::vector out_s_or_d({in_s_or_d.shape()[0]}); + if (!channel_last) { + out_s_or_d.push_back(filter_s_or_d.shape()[0]); + } + + for (size_t i = 0; i < in_data_dims.size(); ++i) { + if (!in_data_dims[i].isa() || + !filter_s_or_d.shape()[i + 2].isa()) { + out_s_or_d.push_back(shape_analysis->GetNextSymName()); + } else { + const symbol::DimExpr dkernel = + new_dilations[i] * (filter_data_dims[i] - 1) + 1; + symbol::DimExpr output_size = (in_data_dims[i] + new_paddings[2 * i] + + new_paddings[2 * i + 1] - dkernel) / + strides[i] + + 1; + out_s_or_d.push_back(output_size); + } + } + if (channel_last) { + out_s_or_d.push_back(filter_s_or_d.shape()[0]); + } + + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_s_or_d)}; + }(); + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + + return true; +} + +bool Conv3dOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return Conv2dOpInferSymbolicShape(op, shape_analysis); +} + +bool EmbeddingOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + const auto weight_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + const std::vector &x_dims = [&] { + std::vector dims; + if (x_shape_or_data.data().has_value()) { + dims = x_shape_or_data.data().value(); + } else { + dims = x_shape_or_data.shape(); + } + return dims; + }(); + + const std::vector &weight_dims = [&] { + std::vector dims; + if (weight_shape_or_data.data().has_value()) { + dims = weight_shape_or_data.data().value(); + } else { + dims = weight_shape_or_data.shape(); + } + return dims; + }(); + + const symbol::ShapeOrDataDimExprs &shape_data = [&] { + std::vector out_dims = x_dims; + // no need to check validation of weight_dims index, since all checks have + // been done at corresponding InferMeta + out_dims.emplace_back(weight_dims[1]); + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + }(); + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + + return true; +} + +bool SparseWeightEmbeddingOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool ExpandAsOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool GatherOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto 
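Conv2dOpInferSymbolicShape above resolves the padding algorithm first and then applies the usual convolution arithmetic per spatial dimension: with dilated kernel dk = dilation * (k - 1) + 1, the output extent is (in + pad_begin + pad_end - dk) / stride + 1, and a fresh symbol is emitted when an extent is not a known constant. A small concrete sketch of the same arithmetic on plain integers (helper name is illustrative, not a Paddle API):

#include <cassert>
#include <cstdint>
#include <string>

// Output extent of one spatial dimension of a convolution.
int64_t ConvOutDim(int64_t in, int64_t k, int64_t stride, int64_t dilation,
                   int64_t pad_begin, int64_t pad_end,
                   const std::string& padding_algorithm) {
  if (padding_algorithm == "SAME") {
    // SAME: output is ceil(in / stride); padding is derived and dilation reset to 1.
    return (in + stride - 1) / stride;
  }
  if (padding_algorithm == "VALID") {
    pad_begin = pad_end = 0;
  }
  const int64_t dk = dilation * (k - 1) + 1;  // dilated kernel extent
  return (in + pad_begin + pad_end - dk) / stride + 1;
}

int main() {
  // 224x224 input, 3x3 kernel, stride 1, dilation 1, explicit padding 1 -> 224.
  assert(ConvOutDim(224, 3, 1, 1, 1, 1, "EXPLICIT") == 224);
  // Same kernel with VALID padding -> 222.
  assert(ConvOutDim(224, 3, 1, 1, 0, 0, "VALID") == 222);
  // SAME padding with stride 2 -> ceil(224 / 2) = 112.
  assert(ConvOutDim(224, 3, 2, 1, 0, 0, "SAME") == 112);
}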
&input_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + const auto &index_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + + const auto &numel = [&] { + symbol::DimExpr numel{1}; + for (const auto &dim_expr : index_shape_or_data.shape()) { + numel = numel * dim_expr; + } + return numel; + }(); + + const auto &axis_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); + + const std::vector &input_sym_shape = + input_shape_or_data.data().has_value() + ? input_shape_or_data.data().value() + : input_shape_or_data.shape(); + + const std::vector &index_sym_shape = + index_shape_or_data.data().has_value() + ? index_shape_or_data.data().value() + : index_shape_or_data.shape(); + + int axis = + static_cast(axis_shape_or_data.data().value()[0].Get()); + if (axis < 0) axis += input_sym_shape.size(); + + const auto &out_sym_shape = [&] { + std::vector out_sym_shape; + + if (index_sym_shape.size() == 0) { + if (input_sym_shape.size() == 1) { + out_sym_shape.push_back(symbol::DimExpr{0}); + } else { + for (int i = 0; i < axis; ++i) { + out_sym_shape.push_back(input_sym_shape[i]); + } + for (size_t i = axis + 1; i < input_sym_shape.size(); ++i) { + out_sym_shape.push_back(input_sym_shape[i]); + } + } + } else { + for (int i = 0; i < axis; ++i) { + out_sym_shape.push_back(input_sym_shape[i]); + } + out_sym_shape.push_back(numel); + for (size_t i = axis + 1; i < input_sym_shape.size(); ++i) { + out_sym_shape.push_back(input_sym_shape[i]); + } + } + return out_sym_shape; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + + return true; +} + +bool GatherNdOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + const auto &index_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + + const std::vector &x_sym_shape = + x_shape_or_data.data().has_value() ? x_shape_or_data.data().value() + : x_shape_or_data.shape(); + + const std::vector &index_sym_shape = + index_shape_or_data.data().has_value() + ? 
index_shape_or_data.data().value() + : index_shape_or_data.shape(); + + int x_dims_size = x_sym_shape.size(); + int index_dims_size = index_sym_shape.size(); + + std::vector result_sym_dims; + // The result dims is + // Index.shape[:-1] + X.shape[Index.shape[-1]:] + for (int i = 0; i < index_dims_size - 1; ++i) { + result_sym_dims.emplace_back(index_sym_shape[i]); + } + + PADDLE_ENFORCE_EQ( + index_sym_shape[index_dims_size - 1].Has(), + true, + phi::errors::InvalidArgument( + "in GatherNdOpInferSymbolicShape: index[-1] should be unknown")); + + for (int i = static_cast( + index_sym_shape[index_dims_size - 1].Get()); + i < x_dims_size; + ++i) { + result_sym_dims.emplace_back(x_sym_shape[i]); + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(result_sym_dims)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + + return true; +} + +bool KronOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + const auto &y_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)).shape(); + const int rank_x = x_shape_or_data.size(); + const int rank_y = y_shape_or_data.size(); + const int rank = (rank_x > rank_y) ? rank_x : rank_y; + + std::vector dim_out; + dim_out.reserve(rank); + const auto one = symbol::DimExpr{1}; + const auto minus_one = symbol::DimExpr{-1}; + for (int i = 0; i < rank; i++) { + symbol::DimExpr dim_xi = + (i < rank - rank_x) ? one : x_shape_or_data.at(i - (rank - rank_x)); + symbol::DimExpr dim_yi = + (i < rank - rank_y) ? one : y_shape_or_data.at(i - (rank - rank_y)); + dim_out.push_back(dim_xi * dim_yi); + } + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(dim_out)}; + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; +} + +bool MaskedSelectOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool MatmulOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + // x_dims can't be const or ref here, in case to be broadcasted + std::vector x_dims = [&] { + std::vector dims; + const auto &x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + if (x_shape_or_data.data().has_value()) { + dims = x_shape_or_data.data().value(); + } else { + dims = x_shape_or_data.shape(); + } + return dims; + }(); + + // y_dims can't be const or ref here, in case to be broadcasted + std::vector y_dims = [&] { + std::vector dims; + const auto y_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + if (y_shape_or_data.data().has_value()) { + dims = y_shape_or_data.data().value(); + } else { + dims = y_shape_or_data.shape(); + } + return dims; + }(); + + size_t ndims_x = x_dims.size(); + size_t ndims_y = y_dims.size(); + + const bool x_broadcasted = [&] { + bool broadcasted = false; + if (ndims_x == 1) { + x_dims.insert(x_dims.begin(), 1); + ndims_x = 2; + broadcasted = true; + } + return broadcasted; + }(); + + const bool y_broadcasted = [&] { + bool broadcasted = false; + if (ndims_y == 1) { + y_dims.emplace_back(1); + ndims_y = 2; + broadcasted = true; + } + return broadcasted; + }(); + + std::vector 
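GatherNdOpInferSymbolicShape above builds the result shape as Index.shape[:-1] + X.shape[Index.shape[-1]:], which is why the last index dimension must be a known constant. A plain-integer sketch of that shape rule (function name is illustrative):

#include <cassert>
#include <cstdint>
#include <vector>

// Result shape of gather_nd: Index.shape[:-1] ++ X.shape[Index.shape[-1]:].
std::vector<int64_t> GatherNdOutShape(const std::vector<int64_t>& x_shape,
                                      const std::vector<int64_t>& index_shape) {
  std::vector<int64_t> out(index_shape.begin(), index_shape.end() - 1);
  const int64_t last = index_shape.back();  // must be known and <= rank of x
  out.insert(out.end(), x_shape.begin() + last, x_shape.end());
  return out;
}

int main() {
  // x: [10, 20, 30], index: [5, 2] -> gathers rank-2 prefixes -> [5, 30].
  assert(GatherNdOutShape({10, 20, 30}, {5, 2}) == (std::vector<int64_t>{5, 30}));
  // x: [10, 20, 30], index: [4, 3] -> full-coordinate gather -> [4].
  assert(GatherNdOutShape({10, 20, 30}, {4, 3}) == (std::vector<int64_t>{4}));
}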
out_dims; + if (ndims_x > ndims_y) { + out_dims.assign(x_dims.begin(), x_dims.end() - 2); + } else if (ndims_x < ndims_y) { + out_dims.assign(y_dims.begin(), y_dims.end() - 2); + } else { + symbol::DimExprBuilder builder{nullptr}; + for (size_t i = 0; i < ndims_x - 2; ++i) { + out_dims.emplace_back(builder.Broadcast(x_dims[i], y_dims[i])); + } + } + + bool transpose_x_attr = GetBoolAttr(op, "transpose_x"); + bool transpose_y_attr = GetBoolAttr(op, "transpose_y"); + symbol::DimExpr out_M = + transpose_x_attr ? x_dims[ndims_x - 1] : x_dims[ndims_x - 2]; + symbol::DimExpr out_N = + transpose_y_attr ? y_dims[ndims_y - 2] : y_dims[ndims_y - 1]; + if (!x_broadcasted) { + out_dims.emplace_back(out_M); + } + if (!y_broadcasted) { + out_dims.emplace_back(out_N); + } + + shape_analysis->SetShapeOrDataForValue(op->result(0), + ShapeOrData{TensorExprs(out_dims)}); + + if ((ndims_x == ndims_y) && ndims_x >= 2) { + if (transpose_x_attr == false && transpose_y_attr == false) { + shape_analysis->DimExprBuilder().CstrEq(x_dims[ndims_x - 1], + y_dims[ndims_x - 2]); + } else if (transpose_x_attr == false && transpose_y_attr == true) { + shape_analysis->DimExprBuilder().CstrEq(x_dims[ndims_x - 1], + y_dims[ndims_x - 1]); + } else if (transpose_x_attr == true && transpose_y_attr == false) { + shape_analysis->DimExprBuilder().CstrEq(x_dims[ndims_x - 2], + y_dims[ndims_x - 2]); + } else { + shape_analysis->DimExprBuilder().CstrEq(x_dims[ndims_x - 2], + y_dims[ndims_x - 1]); + } + + for (size_t i = 0; i < ndims_x - 2; ++i) { + shape_analysis->DimExprBuilder().CstrEq(x_dims[i], y_dims[i]); + } + } + return true; +} + +bool SearchsortedOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool TakeAlongAxisOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + // input + const auto &arr_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + const auto &indices_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + const auto &attributes = op->attributes(); + int axis = attributes.at("axis").dyn_cast().data(); + + const std::vector &arr_sym_shape = + arr_shape_or_data.data().has_value() ? arr_shape_or_data.data().value() + : arr_shape_or_data.shape(); + const std::vector &indices_sym_shape = + indices_shape_or_data.data().has_value() + ? 
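MatmulOpInferSymbolicShape above follows the familiar matmul broadcasting rule: rank-1 operands are temporarily promoted to rank 2, leading batch dimensions are broadcast, and the trailing M/N extents are chosen by the transpose flags. A compact integer-only sketch of that rule, assuming equal-or-1 batch dimensions so max() stands in for symbolic broadcasting:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

Shape MatmulOutShape(Shape x, Shape y, bool trans_x, bool trans_y) {
  const bool x_vec = x.size() == 1, y_vec = y.size() == 1;
  if (x_vec) x.insert(x.begin(), 1);   // [K] -> [1, K]
  if (y_vec) y.push_back(1);           // [K] -> [K, 1]

  // Broadcast the leading batch dimensions.
  Shape out;
  const size_t nx = x.size(), ny = y.size();
  if (nx > ny) {
    out.assign(x.begin(), x.end() - 2);
  } else if (ny > nx) {
    out.assign(y.begin(), y.end() - 2);
  } else {
    for (size_t i = 0; i + 2 < nx; ++i) out.push_back(std::max(x[i], y[i]));
  }

  // Trailing dims: M from x, N from y, respecting the transpose flags.
  const int64_t m = trans_x ? x[nx - 1] : x[nx - 2];
  const int64_t n = trans_y ? y[ny - 2] : y[ny - 1];
  if (!x_vec) out.push_back(m);
  if (!y_vec) out.push_back(n);
  return out;
}

int main() {
  assert(MatmulOutShape({2, 3, 4}, {4, 5}, false, false) == (Shape{2, 3, 5}));
  assert(MatmulOutShape({3, 4}, {3, 5}, true, false) == (Shape{4, 5}));
  assert(MatmulOutShape({4}, {4}, false, false) == (Shape{}));  // dot product
}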
indices_shape_or_data.data().value() + : indices_shape_or_data.shape(); + + if (axis < 0) axis += arr_sym_shape.size(); + + const auto &out_sym_shape = [&] { + std::vector out_sym_shape; + for (int i = 0; i < axis; ++i) { + out_sym_shape.push_back(arr_sym_shape[i]); + } + out_sym_shape.push_back(indices_sym_shape[axis]); + for (size_t i = axis + 1; i < arr_sym_shape.size(); ++i) { + out_sym_shape.push_back(arr_sym_shape[i]); + } + return out_sym_shape; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + + return true; +} + +bool TopPSamplingOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &x_dims = [op, shape_analysis] { + const auto &shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + if (shape_or_data.data().has_value()) { + return shape_or_data.data().value(); + } else { + return shape_or_data.shape(); + } + }(); + + // all the result have the same shape + for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { + const std::vector out_dims{x_dims[0], 1}; + shape_analysis->SetShapeOrDataForValue( + op->result(rst_idx), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_dims)}); + } + + return true; +} + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h new file mode 100644 index 0000000000000..fb8bbf11ac08a --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
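TakeAlongAxisOpInferSymbolicShape above keeps every dimension of the input array except the gathered axis, whose extent comes from the indices tensor, with a negative axis wrapped around first. The same rule on plain integers (function name is illustrative):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> TakeAlongAxisOutShape(const std::vector<int64_t>& arr,
                                           const std::vector<int64_t>& indices,
                                           int axis) {
  if (axis < 0) axis += static_cast<int>(arr.size());  // wrap negative axis
  std::vector<int64_t> out = arr;
  out[axis] = indices[axis];  // only the gathered axis changes
  return out;
}

int main() {
  // arr: [8, 100, 16], indices: [8, 5, 16], axis = 1 -> [8, 5, 16].
  assert(TakeAlongAxisOutShape({8, 100, 16}, {8, 5, 16}, 1) ==
         (std::vector<int64_t>{8, 5, 16}));
  // Negative axis counts from the end: axis = -1 picks the last dimension.
  assert(TakeAlongAxisOutShape({8, 100, 16}, {8, 100, 4}, -1) ==
         (std::vector<int64_t>{8, 100, 4}));
}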
+ +#pragma once + +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace paddle::dialect { + +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(SparseWeightEmbedding) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gather) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(GatherNd) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kron) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaskedSelect) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Matmul) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Searchsorted) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(TakeAlongAxis) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(TopPSampling) + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc index 0e8240434e070..be9e14eef1bb1 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" namespace cinn::dialect { @@ -41,6 +42,25 @@ bool ConcatOpInferSymbolicShape( const auto input_values = op->operands_source(); const auto input_size = input_values.size(); + if (shape_analysis->GetShapeOrDataForValue(input_values[0]) + .data() + .has_value()) { + std::vector out_data; + for (const auto &value : input_values) { + const auto &shape_or_data = shape_analysis->GetShapeOrDataForValue(value); + for (size_t i = 0; i < shape_or_data.data().value().size(); ++i) { + out_data.emplace_back(shape_or_data.data().value()[i]); + } + } + const std::vector shape{std::int64_t(out_data.size())}; + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(shape, out_data)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; + } + int axis = op->attributes().at("axis").dyn_cast().data(); const auto &GetOutDimExprs = [&]() -> std::vector { @@ -56,7 +76,7 @@ bool ConcatOpInferSymbolicShape( out_dims[axis] = out_dims[axis] + operand_shape_or_data.shape()[axis]; } - for (size_t i = 1; i < rank; ++i) { + for (size_t i = 0; i < rank; ++i) { if (i == static_cast(axis)) continue; paddle::dialect::details::BuildCstrEqForTensorListAlongAxis( shape_analysis, input_values, i); @@ -65,6 +85,9 @@ bool ConcatOpInferSymbolicShape( return out_dims; }; + VLOG(3) << "constraints size:" + << shape_analysis->DimExprBuilder().constraints().size(); + symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(GetOutDimExprs())}; @@ -74,16 +97,11 @@ bool ConcatOpInferSymbolicShape( bool ReduceInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto &attr_map = op->attributes(); - PADDLE_ENFORCE( - attr_map.count("keep_dim"), - phi::errors::PreconditionNotMet( - "attr [keep_dim] MUST in attribute map for [%s] op", op->name())); - bool keepdim = attr_map.at("keep_dim").dyn_cast().data(); + bool keep_dim = GetBoolAttr(op, "keep_dim"); auto axis = paddle::dialect::details::GetVectorAttr(op, "dim"); bool reduce_all = axis.size() == 
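The concat rule the hunk above generalizes is: the output matches the inputs on every dimension except the concat axis, where the extents are summed, while the new early return additionally concatenates the known data of fully value-known 1-D inputs. A plain-integer sketch of the shape part only:

#include <cassert>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

Shape ConcatOutShape(const std::vector<Shape>& inputs, int axis) {
  Shape out = inputs.front();
  if (axis < 0) axis += static_cast<int>(out.size());
  for (size_t i = 1; i < inputs.size(); ++i) {
    // Non-axis dimensions are constrained to be equal; axis extents add up.
    out[axis] += inputs[i][axis];
  }
  return out;
}

int main() {
  assert(ConcatOutShape({{2, 3}, {2, 5}}, 1) == (Shape{2, 8}));
  assert(ConcatOutShape({{4, 7}, {1, 7}, {3, 7}}, 0) == (Shape{8, 7}));
}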
0 ? true : false; return paddle::dialect::details::ReduceInferDim( - op, shape_analysis, axis, keepdim, reduce_all); + op, shape_analysis, axis, keep_dim, reduce_all); } bool ReduceMaxOpInferSymbolicShape( @@ -111,10 +129,73 @@ bool ReshapeOpInferSymbolicShape( std::vector shape = paddle::dialect::details::GetVectorAttr(op, "shape"); - std::vector out_dims; - for (int dim : shape) { - out_dims.emplace_back(static_cast(dim)); + const symbol::ShapeOrDataDimExprs &x_dim_expr = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + if (x_dim_expr.data().has_value()) { + if (shape.size() == 1 && shape.front() == 1) { + shape_analysis->SetShapeOrDataForValue( + op->result(0), + symbol::TensorShapeOrDataDimExprs(std::vector{1}, + x_dim_expr.data().value())); + return true; + } } + + const auto &GetProduct = [&](const auto &dim_exprs, const auto &Filter) { + symbol::DimExpr product{1}; + for (const auto &dim_expr : dim_exprs) { + if (Filter(dim_expr)) { + product = product * dim_expr; + } + } + return product; + }; + + const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) { + if (dim_expr.isa()) { + return dim_expr.dyn_cast() != static_cast(-1); + } + return true; + }; + + const auto &IsZero = [&](const symbol::DimExpr &dim_expr) { + if (dim_expr.isa()) { + return dim_expr.dyn_cast() == static_cast(0); + } + return false; + }; + + const auto &target_shape = [&] { + std::vector target_shape; + for (int dim : shape) { + target_shape.emplace_back(static_cast(dim)); + } + return target_shape; + }(); + + const auto &original_shape = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + + const auto &out_dims = [&] { + const auto &numel = + GetProduct(original_shape, [](const auto &) { return true; }); + + const auto &product_exclude_minus_one = + GetProduct(target_shape, IsNotMinusOne); + + std::vector out_dims; + out_dims.reserve(target_shape.size()); + for (size_t i = 0; i < target_shape.size(); ++i) { + auto out_dim_expr = IsNotMinusOne(target_shape[i]) + ? target_shape[i] + : (numel / product_exclude_minus_one); + out_dim_expr = IsZero(target_shape[i]) ? original_shape[i] : out_dim_expr; + out_dims.emplace_back(out_dim_expr); + } + + return out_dims; + }(); + symbol::ShapeOrDataDimExprs shape_data{ symbol::TensorShapeOrDataDimExprs(out_dims)}; shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); @@ -124,52 +205,30 @@ bool ReshapeOpInferSymbolicShape( bool SliceOpInferSymbolicShape(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - // TODO(zhangbopd): Not implemented yet, different from the one in paddle - // dialect. And Currently only support start/end/axis with single value. 
- pir::AttributeMap attributes = op->attributes(); - - auto GetAttrInt64Value = [&](const std::string &name) -> int64_t { - std::vector attr = - attributes[name].dyn_cast().AsVector(); - PADDLE_ENFORCE_GT( - attr.size(), - 0, - phi::errors::PreconditionNotMet( - "Only Support [%s] op len(%s) == 1 , but received %d.", - op->name(), - name, - attr.size())); - return attr[0].dyn_cast().data(); - }; - - const int64_t start = GetAttrInt64Value("starts"); - const int64_t end = GetAttrInt64Value("ends"); - const int64_t axis = GetAttrInt64Value("axes"); - - const pir::Value operand_source = op->operand_source(0); - const auto &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); + const std::vector starts_raw = + paddle::dialect::details::GetVectorAttr(op, "starts"); + const std::vector ends_raw = + paddle::dialect::details::GetVectorAttr(op, "ends"); + const std::vector axes_raw = + paddle::dialect::details::GetVectorAttr(op, "axes"); + const std::vector infer_flags_raw = + paddle::dialect::details::GetVectorAttr(op, "infer_flags"); + const std::vector decrease_axis_raw = + paddle::dialect::details::GetVectorAttr(op, "decrease_axis"); + + const ExprVec starts = paddle::dialect::details::VecInt642Expr(starts_raw); + const ExprVec ends = paddle::dialect::details::VecInt642Expr(ends_raw); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), + paddle::dialect::slice_utils::SliceRawInferSymbolicShape( + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)), + starts, + ends, + axes_raw, + infer_flags_raw, + decrease_axis_raw)); - const auto GetOutDimExprs = [&]() -> symbol::TensorShapeOrDataDimExprs { - std::vector out_sym_shape = operand_shape_or_data.shape(); - if (end == std::numeric_limits::max()) { - out_sym_shape[axis] = out_sym_shape[axis] - start; - } else { - out_sym_shape[axis] = end - start; - } - symbol::TensorShapeOrDataDimExprs shape_dim_expr(out_sym_shape); - if (operand_shape_or_data.data().has_value()) { - std::vector out_data; - for (int64_t i = start; i < end; i++) { - out_data.push_back(operand_shape_or_data.data().value()[i]); - } - shape_dim_expr.SetData(out_data); - } - return shape_dim_expr; - }; - symbol::ShapeOrDataDimExprs shape_data{GetOutDimExprs()}; - - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); return true; } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h index b98f8e02d66e9..b3cc2232a1f91 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h @@ -16,32 +16,12 @@ #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" namespace cinn::dialect { - -bool BroadcastOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool SliceOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ConcatOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ReduceMaxOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ReduceMinOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ReduceProdOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ReduceSumOpInferSymbolicShape( - 
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ReshapeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool SliceOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Broadcast) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceMax) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceMin) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceProd) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ReduceSum) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Slice) } // namespace cinn::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.cc new file mode 100644 index 0000000000000..170143307dc06 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.cc @@ -0,0 +1,135 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" + +bool ShouldUseData(pir::Value val) { + if (!val.defining_op()) return false; + if (val.defining_op()->isa()) { + return true; + } + return false; +} + +bool InferSymbolicShapeElementWiseBinary( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &x_shapeordata = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + std::vector shape_0; + // For ElementWiseBinary ops, if the input tensor is from full op, the value + // of fullop is useless, only the shape need doing broadcast + if (ShouldUseData(op->operand_source(0)) && + x_shapeordata.data().has_value()) { + shape_0 = x_shapeordata.data().value(); + } else { + shape_0 = x_shapeordata.shape(); + } + + const auto &y_shapeordata = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + std::vector shape_1; + if (ShouldUseData(op->operand_source(1)) && + y_shapeordata.data().has_value()) { + shape_1 = y_shapeordata.data().value(); + } else { + shape_1 = y_shapeordata.shape(); + } + + int diff = shape_0.size() - shape_1.size(); + if (diff > 0) { + for (int i = 0; i < diff; i++) { + shape_1.emplace(shape_1.begin(), 1); + } + } else { + for (int i = 0; i < -diff; i++) { + shape_0.emplace(shape_0.begin(), 1); + } + } + + const std::vector shapes = [&] { + std::vector shapes; + symbol::DimExprBuilder builder{nullptr}; + for (size_t i = 0; i < shape_0.size(); i++) { + if (shape_0[i] == shape_1[i]) { + shapes.emplace_back(shape_0[i]); + } else if (shape_0[i] == 1) { + shapes.emplace_back(shape_1[i]); + } else if (shape_1[i] == 1) { + shapes.emplace_back(shape_0[i]); + } else { + shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i])); + } + } + return shapes; + }(); + + // TODO(lanxianghit): fill 
data when the operation is on shape computation + // std::vector data; + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(shapes)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + + return true; +} + +#define OP_ELEMENT_WISE_BINARY(name) \ + bool name##OpInferSymbolicShape( \ + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { \ + return InferSymbolicShapeElementWiseBinary(op, shape_analysis); \ + } + +namespace paddle::dialect { +OP_ELEMENT_WISE_BINARY(Add) +OP_ELEMENT_WISE_BINARY(Add_) +OP_ELEMENT_WISE_BINARY(BitwiseAnd) +OP_ELEMENT_WISE_BINARY(BitwiseAnd_) +OP_ELEMENT_WISE_BINARY(BitwiseXor) +OP_ELEMENT_WISE_BINARY(BitwiseXor_) +OP_ELEMENT_WISE_BINARY(Complex) +OP_ELEMENT_WISE_BINARY(Divide) +OP_ELEMENT_WISE_BINARY(Divide_) +OP_ELEMENT_WISE_BINARY(ElementwisePow) +OP_ELEMENT_WISE_BINARY(Fmax) +OP_ELEMENT_WISE_BINARY(Fmin) +OP_ELEMENT_WISE_BINARY(GreaterEqual) +OP_ELEMENT_WISE_BINARY(GreaterEqual_) +OP_ELEMENT_WISE_BINARY(GreaterThan) +OP_ELEMENT_WISE_BINARY(GreaterThan_) +OP_ELEMENT_WISE_BINARY(LessEqual) +OP_ELEMENT_WISE_BINARY(LessEqual_) +OP_ELEMENT_WISE_BINARY(LessThan) +OP_ELEMENT_WISE_BINARY(LessThan_) +OP_ELEMENT_WISE_BINARY(LogicalAnd) +OP_ELEMENT_WISE_BINARY(LogicalAnd_) +OP_ELEMENT_WISE_BINARY(LogicalOr) +OP_ELEMENT_WISE_BINARY(LogicalOr_) +OP_ELEMENT_WISE_BINARY(LogicalXor) +OP_ELEMENT_WISE_BINARY(LogicalXor_) +OP_ELEMENT_WISE_BINARY(Maximum) +OP_ELEMENT_WISE_BINARY(Minimum) +OP_ELEMENT_WISE_BINARY(Multiply) +OP_ELEMENT_WISE_BINARY(MultiplySr) +OP_ELEMENT_WISE_BINARY(MultiplySr_) +OP_ELEMENT_WISE_BINARY(Multiply_) +OP_ELEMENT_WISE_BINARY(NotEqual) +OP_ELEMENT_WISE_BINARY(NotEqual_) +OP_ELEMENT_WISE_BINARY(Remainder) +OP_ELEMENT_WISE_BINARY(Remainder_) +OP_ELEMENT_WISE_BINARY(Subtract) +OP_ELEMENT_WISE_BINARY(Subtract_) + +} // namespace paddle::dialect + +#undef OP_ELEMENT_WISE_BINARY diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.h new file mode 100644 index 0000000000000..aaa6ebf1d5836 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.h @@ -0,0 +1,59 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
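Note: the broadcast rule implemented by InferSymbolicShapeElementWiseBinary above — align ranks by left-padding the shorter shape with 1, keep equal dims, resolve a dim of 1 against the other operand's dim, and otherwise fall back to builder.Broadcast — can be sketched numerically. The standalone snippet below is illustrative only and not part of the patch; names such as BroadcastShapes are hypothetical, and -1 stands in for a symbolic dimension.

// Illustrative sketch of the element-wise broadcast rule (not part of the patch).
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> BroadcastShapes(std::vector<int64_t> a,
                                     std::vector<int64_t> b) {
  // Align ranks by prepending 1s to the shorter shape.
  while (a.size() < b.size()) a.insert(a.begin(), 1);
  while (b.size() < a.size()) b.insert(b.begin(), 1);
  std::vector<int64_t> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] == b[i]) {
      out[i] = a[i];   // equal dims (possibly both symbolic)
    } else if (a[i] == 1) {
      out[i] = b[i];   // a 1 broadcasts against the other dim
    } else if (b[i] == 1) {
      out[i] = a[i];
    } else {
      out[i] = -1;     // the real code emits builder.Broadcast(a[i], b[i]) here
    }
  }
  return out;
}

int main() {
  // [8, 1, 4] op [3, 4] -> [8, 3, 4]
  assert((BroadcastShapes({8, 1, 4}, {3, 4}) == std::vector<int64_t>{8, 3, 4}));
  // [S0, 4] op [2, 4] -> [Broadcast(S0, 2), 4], shown here as -1
  assert((BroadcastShapes({-1, 4}, {2, 4}) == std::vector<int64_t>{-1, 4}));
  return 0;
}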
+ +#pragma once + +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace paddle::dialect { +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Add) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Add_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BitwiseAnd) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BitwiseAnd_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BitwiseXor) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BitwiseXor_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Complex) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Divide) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Divide_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ElementwisePow) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Fmax) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Fmin) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(GreaterEqual) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(GreaterEqual_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(GreaterThan) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(GreaterThan_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LessEqual) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LessEqual_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LessThan) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LessThan_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalAnd) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalAnd_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalOr) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalOr_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalXor) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalXor_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Maximum) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Minimum) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multiply) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(MultiplySr) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(MultiplySr_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multiply_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(NotEqual) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(NotEqual_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Remainder) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Remainder_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Subtract) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Subtract_) + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.cc deleted file mode 100644 index 21da5351c617d..0000000000000 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.cc +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" - -bool InferSymbolicShapeElementWiseBinary( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto &x_shapeordata = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - std::vector shape_0; - // For ElementWiseBinary ops, if the input tensor is from full op, the value - // of fullop is useless, only the shape need doing broadcast - bool x_from_fullop = - op->operand_source(0).defining_op()->isa(); - if (!x_from_fullop && x_shapeordata.data().has_value()) { - shape_0 = x_shapeordata.data().value(); - } else { - shape_0 = x_shapeordata.shape(); - } - - const auto &y_shapeordata = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - std::vector shape_1; - bool y_from_fullop = - op->operand_source(1).defining_op()->isa(); - if (!y_from_fullop && y_shapeordata.data().has_value()) { - shape_1 = y_shapeordata.data().value(); - } else { - shape_1 = y_shapeordata.shape(); - } - - int diff = shape_0.size() - shape_1.size(); - if (diff > 0) { - for (int i = 0; i < diff; i++) { - shape_1.emplace(shape_1.begin(), 1); - } - } else { - for (int i = 0; i < -diff; i++) { - shape_0.emplace(shape_0.begin(), 1); - } - } - - const std::vector shapes = [&] { - std::vector shapes; - symbol::DimExprBuilder builder{nullptr}; - for (size_t i = 0; i < shape_0.size(); i++) { - if (shape_0[i] == shape_1[i]) { - shapes.emplace_back(shape_0[i]); - } else if (shape_0[i] == 1) { - shapes.emplace_back(shape_1[i]); - } else if (shape_1[i] == 1) { - shapes.emplace_back(shape_0[i]); - } else { - shapes.emplace_back(builder.Broadcast(shape_0[i], shape_1[i])); - } - } - return shapes; - }(); - - // TODO(lanxianghit): fill data when the operation is on shape computation - // std::vector data; - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(shapes)}; - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - - return true; -} - -namespace paddle::dialect { - -bool AddOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool Add_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool BitwiseAndOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool BitwiseAnd_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return BitwiseAndOpInferSymbolicShape(op, shape_analysis); -} - -bool DivideOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} -bool Divide_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool ElementwisePowOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool GreaterThanOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool 
GreaterThan_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return GreaterThanOpInferSymbolicShape(op, shape_analysis); -} - -bool LessThanOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool LessThan_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return LessThanOpInferSymbolicShape(op, shape_analysis); -} - -bool LogicalAndOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool LogicalAnd_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return LogicalAndOpInferSymbolicShape(op, shape_analysis); -} - -bool MultiplyOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} -bool MultiplySrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} -bool Multiply_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} -bool MultiplySr_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool NotEqualOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return InferSymbolicShapeElementWiseBinary(op, shape_analysis); -} - -bool NotEqual_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return NotEqualOpInferSymbolicShape(op, shape_analysis); -} - -} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h deleted file mode 100644 index e15d769fc8b02..0000000000000 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" - -namespace paddle::dialect { -bool AddOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool Add_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool BitwiseAndOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool BitwiseAnd_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool DivideOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Divide_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ElementwisePowOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool GreaterThanOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool GreaterThan_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool LessThanOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool LessThan_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool LogicalAndOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool LogicalAnd_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool MultiplyOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool MultiplySrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool Multiply_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool MultiplySr_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool NotEqualOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool NotEqual_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h new file mode 100644 index 0000000000000..345c55e1a116b --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h @@ -0,0 +1,191 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
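Note: the slice helpers defined below (CheckAndUpdateSliceAttrs, GetSliceDims, GetDecreasedDims) normalize open-ended and negative starts/ends before the output dim is computed as ends[i] - starts[i]. The standalone snippet that follows is a numeric sketch of that normalization for the simple cases only; it is illustrative, not part of the patch, and NormalizeStartEnd is a hypothetical name.

// Illustrative sketch of the start/end normalization (not part of the patch).
#include <cassert>
#include <cstdint>
#include <limits>

void NormalizeStartEnd(int64_t dim, int64_t* start, int64_t* end) {
  // An end of INT64_MAX means "slice up to the end of this axis".
  if (*end == std::numeric_limits<int64_t>::max()) *end = dim;
  // A negative start paired with a non-negative end is shifted into range;
  // the remaining sign combinations are handled inside CheckAndUpdateSliceAttrs.
  if (*start < 0 && *end >= 0) *start += dim;
}

int main() {
  int64_t dim = 10;
  int64_t s = 2, e = std::numeric_limits<int64_t>::max();
  NormalizeStartEnd(dim, &s, &e);
  assert(e - s == 8);  // x[2:] over a dim of 10 has length 8
  s = -3; e = 10;
  NormalizeStartEnd(dim, &s, &e);
  assert(e - s == 3);  // x[-3:10] -> start becomes 7, length 3
  return 0;
}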
+ +#pragma once + +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" + +namespace paddle::dialect::slice_utils { + +inline ExprVec GetExprVecFromData(const ShapeOrData &shapeordata) { + if (shapeordata.isa()) { + ExprVec result; + TensorListExprs list = + shapeordata.dyn_cast(); + for (size_t i = 0; i < list.size(); i++) { + for (auto expr : list[i].data().value()) { + result.emplace_back(expr); + } + } + return result; + } else { + return shapeordata.data().value(); + } +} + +inline void CheckAndUpdateSliceAttrs( + const ExprVec &in_dims, + const std::vector &axes, + ExprVec *starts_p, + ExprVec *ends_p, + std::vector *infer_flags = nullptr) { + ExprVec &starts = *starts_p; + ExprVec &ends = *ends_p; + auto IsMaxInt = [](const symbol::DimExpr &expr) { + return expr.isa() && + expr.Get() == + static_cast(std::numeric_limits::max()); + }; + + for (size_t i = 0; i < axes.size(); ++i) { + int64_t axis = axes[i]; + int64_t start_i = 0; + if (starts[i].isa()) { + start_i = starts[i].Get(); + } + int64_t end_i = 0; + if (ends[i].isa()) { + end_i = ends[i].Get(); + } + + // For both start and end can be negative or positive, we need to handle the + // following different arrangements. + ends[i] = IsMaxInt(ends[i]) ? in_dims[axis] : ends[i]; + + bool both_negative_or_positive = + (start_i >= 0 && end_i >= 0) || (start_i <= 0 && end_i <= 0); + bool start_negative_end_positive = start_i <= 0 && end_i >= 0; + bool start_positive_end_negative = start_i >= 0 && end_i <= 0; + + if (both_negative_or_positive) { + continue; + } else if (start_negative_end_positive) { + starts[i] = starts[i] + in_dims[axis]; + } else if (start_positive_end_negative) { + starts[i] = starts[i] - in_dims[axis]; + } else { + PADDLE_THROW(phi::errors::Fatal("Dead code")); + } + } +} + +inline ExprVec GetSliceDims(const ExprVec &in_dims, + const std::vector &axes, + const ExprVec &starts, + const ExprVec &ends, + std::vector *infer_flags = nullptr) { + ExprVec slice_dims(in_dims); + + for (size_t i = 0; i < axes.size(); ++i) { + int64_t axis = axes[i]; + slice_dims[axis] = ends[i] - starts[i]; + } + + return slice_dims; +} + +inline ExprVec GetDecreasedDims(const ExprVec &slice_dims, + const std::vector &decrease_axes) { + ExprVec decreased_dims(slice_dims); + std::vector decrease_flag(slice_dims.size(), 0); + if (decrease_axes.size() > 0) { + for (size_t i = 0; i < decrease_axes.size(); ++i) { + int64_t axis = decrease_axes[i]; + decrease_flag[axis] = 1; + } + ExprVec new_shape; + for (size_t i = 0; i < slice_dims.size(); ++i) { + if (decrease_flag[i] == 0) { + new_shape.emplace_back(slice_dims[i]); + } + } + decreased_dims = new_shape; + } + return decreased_dims; +} + +inline std::vector FormatSliceAxes( + const std::vector &axes_raw, int64_t rank) { + std::vector axes_vec(axes_raw.size(), 0); + std::transform( + axes_raw.begin(), axes_raw.end(), axes_vec.begin(), [rank](int64_t axis) { + return axis >= 0 ? axis : std::max(int64_t(0), axis + rank); + }); + return axes_vec; +} + +inline ShapeOrData SliceRawInferSymbolicShape( + const ShapeOrData &in_shapeordata, + const ExprVec &starts_expr, + const ExprVec &ends_expr, + const std::vector &axes_raw, + const std::vector &infer_flags_raw, + const std::vector &decrease_axis) { + ExprVec starts = starts_expr; + ExprVec ends = ends_expr; + std::vector infer_flags = [&infer_flags_raw, &axes_raw] { + return infer_flags_raw.empty() ? 
std::vector(axes_raw.size(), 1) + : infer_flags_raw; + }(); + + const auto &GetShapeDimExprs = [&]() -> symbol::ShapeOrDataDimExprs { + const ExprVec &in_dims = in_shapeordata.shape(); + std::vector axes = FormatSliceAxes(axes_raw, in_dims.size()); + CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &infer_flags); + ExprVec slice_dims = + GetSliceDims(in_dims, axes, starts, ends, &infer_flags); + ExprVec out_dims = GetDecreasedDims(slice_dims, decrease_axis); + + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + }; + + // When `pd.slice` is operating on a tensor which is produced by a `pd.shape` + // op, the result should be written into data. + const auto &GetDataDimExprs = [&]() -> symbol::ShapeOrDataDimExprs { + std::vector out_data; + + // Currently, we DO NOT support the case that any element in `axes` `starts` + // or `ends` is a Symbol. + auto vec_int64 = details::VecExpr2Int64(starts); + IR_ENFORCE(vec_int64.has_value(), + "for slice op, all the elements in `starts` must be int64_t"); + std::vector starts_int = vec_int64.value(); + + vec_int64 = details::VecExpr2Int64(ends); + IR_ENFORCE(vec_int64.has_value(), + "for slice op, all the elements in `ends` must be int64_t"); + std::vector ends_int = vec_int64.value(); + + const int64_t start = + starts_int[0] < 0 ? starts_int[0] + in_shapeordata.data().value().size() + : starts_int[0]; + const int64_t end = + static_cast(std::numeric_limits::max()) == ends_int[0] + ? in_shapeordata.data().value().size() + : ends_int[0]; + + for (int64_t i = start; i < end; i++) { + out_data.push_back(in_shapeordata.data().value()[i]); + } + + const std::vector shape{std::int64_t(out_data.size())}; + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(shape, out_data)}; + }; + + return in_shapeordata.data().has_value() ? 
GetDataDimExprs() + : GetShapeDimExprs(); +} +} // namespace paddle::dialect::slice_utils diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.cc index 4e5f5df08732a..30730170e23a2 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.cc @@ -16,6 +16,27 @@ namespace paddle::dialect::details { +std::optional> VecExpr2Int64(const ExprVec &expr_vec) { + std::vector int64vec; + for (auto item : expr_vec) { + if (!item.isa()) { + return std::nullopt; + } + int64vec.push_back(item.Get()); + } + return int64vec; +} + +ExprVec VecInt642Expr(const std::vector &int_vec) { + ExprVec expr_vec(int_vec.size(), 0); + std::transform( + int_vec.begin(), + int_vec.end(), + expr_vec.begin(), + [](int64_t val) -> symbol::DimExpr { return symbol::DimExpr(val); }); + return expr_vec; +} + bool ReduceInferDim(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis, const std::vector &axis, @@ -24,18 +45,18 @@ bool ReduceInferDim(pir::Operation *op, auto x = op->operand_source(0); int x_rank = x.type().dyn_cast().dims().size(); - const std::vector formated_axis = [&] { - std::vector formated_axis = axis; + const std::vector formatted_axis = [&] { + std::vector formatted_axis = axis; for (size_t i = 0; i < axis.size(); ++i) { if (axis[i] < 0) { - formated_axis[i] = axis[i] + x_rank; + formatted_axis[i] = axis[i] + x_rank; } } - return formated_axis; + return formatted_axis; }(); bool full_dim = true; - std::set dims_set(formated_axis.begin(), formated_axis.end()); + std::set dims_set(formatted_axis.begin(), formatted_axis.end()); for (int64_t i = 0; i < x_rank; ++i) { if (dims_set.find(i) == dims_set.end()) { full_dim = false; @@ -83,8 +104,8 @@ void BuildCstrEqForTensorListAlongAxis( const symbol::TensorListShapeOrDataDimExprs &shape_data_list, int axis) { for (size_t i = 1; i < shape_data_list.size(); ++i) { - shape_analysis->CreateDimExprBuilder().CstrEq( - shape_data_list[0].shape()[axis], shape_data_list[i].shape()[axis]); + shape_analysis->DimExprBuilder().CstrEq(shape_data_list[0].shape()[axis], + shape_data_list[i].shape()[axis]); } } @@ -93,7 +114,7 @@ void BuildCstrEqForTensorListAlongAxis( const std::vector &values, int axis) { for (size_t i = 1; i < values.size(); ++i) { - shape_analysis->CreateDimExprBuilder().CstrEq( + shape_analysis->DimExprBuilder().CstrEq( shape_analysis->GetShapeOrDataForValue(values[0]).shape()[axis], shape_analysis->GetShapeOrDataForValue(values[i]).shape()[axis]); } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h index 8a14e40e6337a..42164c3c21254 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h @@ -14,9 +14,25 @@ #pragma once +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" +inline bool GetBoolAttr(const pir::Operation *op, const std::string &str) { + const auto &attr_map = op->attributes(); + PADDLE_ENFORCE( + attr_map.count(str), + phi::errors::PreconditionNotMet( + "attr [%s] MUST in attribute map for [%s] op", 
str, op->name())); + return attr_map.at(str).dyn_cast().data(); +} + +// To make codes shorter +using ExprVec = std::vector; +using ShapeOrData = symbol::ShapeOrDataDimExprs; +using TensorExprs = symbol::TensorShapeOrDataDimExprs; +using TensorListExprs = symbol::TensorListShapeOrDataDimExprs; + namespace paddle::dialect::details { template struct AttributeTrait; @@ -31,6 +47,11 @@ struct AttributeTrait { using value_type = ::pir::Int32Attribute; }; +template <> +struct AttributeTrait { + using value_type = ::pir::FloatAttribute; +}; + template std::vector GetVectorAttr(const ::pir::Operation *op, const std::string &name) { @@ -60,6 +81,47 @@ std::vector GetVectorAttr(const ::pir::Operation *op, return vec_res; } +inline ExprVec GetExprVecFromData(const ShapeOrData &shapeordata) { + if (shapeordata.isa()) { + ExprVec result; + TensorListExprs list = + shapeordata.dyn_cast(); + for (size_t i = 0; i < list.size(); i++) { + if (list[i].data().has_value()) { + for (auto expr : list[i].data().value()) { + result.emplace_back(expr); + } + } + } + return result; + } else { + return shapeordata.data().value(); + } +} + +inline ExprVec GetExprVecFromShape(const ShapeOrData &shapeordata) { + const auto GetShapeExprsFromList = [&]() { + ExprVec result; + TensorListExprs list = + shapeordata.dyn_cast(); + for (size_t i = 0; i < list.size(); i++) { + for (auto expr : list[i].data().value()) { + result.emplace_back(expr); + } + } + return result; + }; + if (shapeordata.isa()) { + return GetShapeExprsFromList(); + } else { + return shapeordata.shape(); + } +} + +std::optional> VecExpr2Int64(const ExprVec &expr_vec); + +ExprVec VecInt642Expr(const std::vector &int_vec); + bool ReduceInferDim(pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis, const std::vector &axis, diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h index 4e1946acd75f1..6ad4d6609da94 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h @@ -14,10 +14,13 @@ #pragma once +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h" #include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cinn_op_infer_sym.h" -#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_element_wise_binary.h" -#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h" -#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/element_wise_binary.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" // Type inference is currently modelled executionally for operation creation diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc 
b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc new file mode 100644 index 0000000000000..3a1c411caf1b3 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -0,0 +1,407 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h" +#include "paddle/common/ddim.h" +#include "paddle/common/layout.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" + +namespace paddle::dialect { + +bool BicubicInterpOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const symbol::ShapeOrDataDimExprs &x = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + + const auto &attributes = op->attributes(); + + const std::string data_format = + attributes.at("data_format").dyn_cast().AsString(); + int out_d = attributes.at("out_d").dyn_cast().data(); + int out_h = attributes.at("out_h").dyn_cast().data(); + int out_w = attributes.at("out_w").dyn_cast().data(); + const std::vector &scale = + paddle::dialect::details::GetVectorAttr(op, "scale"); + + std::vector size_tensor; + if (out_d != -1) size_tensor.push_back(out_d); + if (out_h != -1) size_tensor.push_back(out_h); + if (out_w != -1) size_tensor.push_back(out_w); + + const DataLayout data_layout = common::StringToDataLayout(data_format); + + if (x.shape().size() == 3) { + // shape check for 1D interpolate for input tensor shape NCHW + if (!size_tensor.empty()) { + // top priority size + std::vector dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {x.shape()[0], x.shape()[1], symbol::DimExpr{out_w}}; + } else { + dim_out = {x.shape()[0], symbol::DimExpr{out_w}, x.shape()[2]}; + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(dim_out)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; + } + + symbol::DimExpr out_w_tmp{0}; + const auto &next_sym = shape_analysis->GetNextSymName(); + out_w_tmp = symbol::DimExpr(next_sym); + + std::vector dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {x.shape()[0], x.shape()[1], out_w_tmp}; + } else { + dim_out = {x.shape()[0], out_w_tmp, x.shape()[2]}; + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(dim_out)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; + } else if (x.shape().size() == 4) { + // shape check for 2D interpolate for input tensor shape NCHW + if (!size_tensor.empty()) { + // top priority size + std::vector dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = 
{x.shape()[0], + x.shape()[1], + symbol::DimExpr{out_h}, + symbol::DimExpr{out_w}}; + } else { + dim_out = {x.shape()[0], + symbol::DimExpr{out_h}, + symbol::DimExpr{out_w}, + x.shape()[3]}; + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(dim_out)}; + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; + } + + symbol::DimExpr out_h_tmp{0}; + symbol::DimExpr out_w_tmp{0}; + const auto &next_sym = shape_analysis->GetNextSymName(); + out_h_tmp = symbol::DimExpr(next_sym); + out_w_tmp = symbol::DimExpr(next_sym); + + std::vector dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {x.shape()[0], x.shape()[1], out_h_tmp, out_w_tmp}; + } else { + dim_out = {x.shape()[0], out_h_tmp, out_w_tmp, x.shape()[3]}; + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(dim_out)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; + } else if (x.shape().size() == 5) { + // shape check for 3D interpolate for input tensor shape NCDHW + if (!size_tensor.empty()) { + // top priority size + std::vector dim_out; + if (data_layout == DataLayout::kNCHW) { + dim_out = {x.shape()[0], + x.shape()[1], + symbol::DimExpr{out_d}, + symbol::DimExpr{out_h}, + symbol::DimExpr{out_w}}; + } else { + dim_out = {x.shape()[0], + symbol::DimExpr{out_d}, + symbol::DimExpr{out_h}, + symbol::DimExpr{out_w}, + x.shape()[4]}; + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(dim_out)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; + } + + symbol::DimExpr out_d_tmp{0}; + symbol::DimExpr out_h_tmp{0}; + symbol::DimExpr out_w_tmp{0}; + const auto &next_sym = shape_analysis->GetNextSymName(); + out_d_tmp = symbol::DimExpr(next_sym); + out_h_tmp = symbol::DimExpr(next_sym); + out_w_tmp = symbol::DimExpr(next_sym); + + std::vector dim_out; + + if (data_layout == DataLayout::kNCHW) { + dim_out = {x.shape()[0], x.shape()[1], out_d_tmp, out_h_tmp, out_w_tmp}; + } else { + dim_out = {x.shape()[0], out_d_tmp, out_h_tmp, out_w_tmp, x.shape()[4]}; + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(dim_out)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; + + } else { + PADDLE_THROW(phi::errors::Fatal("Input(X) dimension must be 3, 4 or 5!")); + } + + return true; +} + +bool BilinearInterpOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return BicubicInterpOpInferSymbolicShape(op, shape_analysis); +} + +bool ConcatOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const auto &shape_data_list = + shape_analysis->GetShapeOrDataForValue(operand_source) + .dyn_cast(); + + CHECK(op->operand_source(1).defining_op()->isa()); + + int64_t axis = op->operand_source(1) + .defining_op() + .attributes() + .at("value") + .dyn_cast() + .data() + .to(); + size_t rank = shape_data_list[0].shape().size(); + axis = axis >= 0 ? 
axis : std::max(int64_t(0), int64_t(axis + rank)); + + if (shape_data_list[0].data().has_value()) { + if (rank == 1) { + const auto &s_or_d = + shape_analysis->GetShapeOrDataForValue(operand_source); + ExprVec data = details::GetExprVecFromData(s_or_d); + + const std::vector shape{std::int64_t(data.size())}; + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(shape, data)}; + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + + return true; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + + " 's InferSymbolicShape can NOT deal with rank > 1 now.")); + } + std::vector data; + data.reserve(shape_data_list.size()); + for (auto &data_elem : shape_data_list) { + data.push_back(data_elem.data().value()[0]); + } + const std::vector shape{std::int64_t(data.size())}; + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(shape, data)}; + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + + return true; + } + + const std::vector &out_dims = [&] { + std::vector out_dims = shape_data_list[0].shape(); + for (size_t i = 0; i < rank; ++i) { + if (i != static_cast(axis)) { + details::BuildCstrEqForTensorListAlongAxis( + shape_analysis, shape_data_list, i); + continue; + } + for (size_t j = 1; j < shape_data_list.size(); ++j) { + out_dims[axis] = out_dims[axis] + shape_data_list[j].shape()[axis]; + } + } + return out_dims; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + + return true; +} + +bool FullWithTensorOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + + const auto &out_shape = operand_shape_or_data.data().has_value() + ? 
operand_shape_or_data.data().value() + : operand_shape_or_data.shape(); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), symbol::TensorShapeOrDataDimExprs(out_shape)); + return true; +} + +bool FlashAttnOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &q = + shape_analysis->GetShapeOrDataForValue(operand_source); + + const symbol::ShapeOrDataDimExprs &v = + shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); + + std::vector out_shape = q.shape(); + + out_shape.back() = v.shape().back(); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), symbol::TensorShapeOrDataDimExprs(out_shape)); + return true; +} + +bool LinspaceOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &num_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); + const auto step = [&] { + symbol::DimExpr expr; + if (num_shape_or_data.data().has_value()) { + expr = num_shape_or_data.data().value()[0]; + } else { + expr = num_shape_or_data.shape()[0]; + } + return expr; + }(); + const symbol::ShapeOrDataDimExprs &shape_data = [&] { + std::vector out_dims{step}; + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + }(); + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} + +bool LinearInterpOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return BicubicInterpOpInferSymbolicShape(op, shape_analysis); +} + +bool LogspaceOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return LinspaceOpInferSymbolicShape(op, shape_analysis); +} + +bool NearestInterpOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return BicubicInterpOpInferSymbolicShape(op, shape_analysis); +} + +bool StackOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + + const auto &attributes = op->attributes(); + int axis = attributes.at("axis").dyn_cast().data(); + + const symbol::TensorListShapeOrDataDimExprs &shape_data_list = + shape_analysis->GetShapeOrDataForValue(operand_source) + .dyn_cast(); + + int rank = shape_data_list[0].shape().size(); + if (axis < 0) axis += rank + 1; + + const symbol::ShapeOrDataDimExprs shape_data = [&] { + std::vector shape_dim_exprs; + std::vector data_dim_exprs; + for (size_t i = 0; i < shape_data_list.size(); ++i) { + if (shape_data_list[i].data().has_value() && axis == 0) { + data_dim_exprs.emplace_back(shape_data_list[i].data().value()[0]); + } + } + + if (!data_dim_exprs.empty()) { + shape_dim_exprs.emplace_back( + static_cast(shape_data_list.size())); + } else { + for (int i = 0; i < rank; ++i) { + details::BuildCstrEqForTensorListAlongAxis( + shape_analysis, shape_data_list, i); + } + shape_dim_exprs.insert(shape_dim_exprs.begin() + axis, + static_cast(shape_data_list.size())); + } + + return symbol::ShapeOrDataDimExprs( + symbol::TensorShapeOrDataDimExprs(shape_dim_exprs, data_dim_exprs)); + }(); + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; +} + +bool TrilinearInterpOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return BicubicInterpOpInferSymbolicShape(op, shape_analysis); +} 
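Note: ConcatOpInferSymbolicShape above takes the first input's shape, sums the dims along the concatenated axis, and constrains every other axis to be equal across inputs. The standalone snippet below sketches the same rule with plain integer dims; it is illustrative only, not part of the patch, and ConcatShape is a hypothetical name.

// Illustrative sketch of the concat shape rule (not part of the patch).
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> ConcatShape(const std::vector<std::vector<int64_t>>& ins,
                                 int64_t axis) {
  std::vector<int64_t> out = ins.front();
  const int64_t rank = static_cast<int64_t>(out.size());
  if (axis < 0) axis += rank;  // negative axis wraps around, as in the op
  for (size_t j = 1; j < ins.size(); ++j) {
    for (int64_t i = 0; i < rank; ++i) {
      if (i == axis) {
        out[i] += ins[j][i];  // concatenated axis: dims are summed
      } else {
        // every other axis must match; the real code records a CstrEq
        // constraint between the symbolic dims instead of asserting
        assert(out[i] == ins[j][i]);
      }
    }
  }
  return out;
}

int main() {
  // concat([2, 3], [2, 5]) along axis 1 -> [2, 8]
  assert((ConcatShape({{2, 3}, {2, 5}}, 1) == std::vector<int64_t>{2, 8}));
  // the same result along axis -1
  assert((ConcatShape({{2, 3}, {2, 5}}, -1) == std::vector<int64_t>{2, 8}));
  return 0;
}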
+ +bool WhereOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + shape_analysis->SetShapeOrDataForValue( + op->result(0), + shape_analysis->GetShapeOrDataForValue(op->operand_source(0))); + + const std::vector &operands = {op->operand_source(0), + op->operand_source(1)}; + + size_t rank = shape_analysis->GetShapeOrDataForValue(op->operand_source(0)) + .shape() + .size(); + + for (size_t i = 0; i < rank; ++i) { + paddle::dialect::details::BuildCstrEqForTensorListAlongAxis( + shape_analysis, operands, i); + } + + return true; +} + +bool Where_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return WhereOpInferSymbolicShape(op, shape_analysis); +} + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h new file mode 100644 index 0000000000000..c5869cce7eb63 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -0,0 +1,35 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace paddle::dialect { + +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BicubicInterp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BilinearInterp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Concat) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullWithTensor) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttn) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Linspace) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LinearInterp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stack) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilinearInterp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where_) + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc new file mode 100644 index 0000000000000..0e294991449c1 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.cc @@ -0,0 +1,385 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
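Note: ArangeOpInferSymbolicShape below derives the output length as (end - start) / step; as its TODO comment says, the exact count is ceil((end - start) / step), which DimExpr cannot express yet. The standalone snippet that follows shows the difference for integer inputs; it is illustrative only, not part of the patch, and ArangeLen is a hypothetical name.

// Illustrative sketch of the arange length formula (not part of the patch).
#include <cassert>
#include <cstdint>

int64_t ArangeLen(int64_t start, int64_t end, int64_t step) {
  // ceil((end - start) / step) for a positive integer step
  return (end - start + step - 1) / step;
}

int main() {
  assert(ArangeLen(0, 10, 2) == 5);  // 0, 2, 4, 6, 8
  assert(ArangeLen(0, 10, 3) == 4);  // 0, 3, 6, 9; plain (end - start) / step gives 3
  return 0;
}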
+ +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" + +namespace paddle::dialect { + +bool ArangeOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &start_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + const auto &end_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + const auto &step_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); + + const auto start = [&] { + symbol::DimExpr expr; + if (start_shape_or_data.data().has_value()) { + expr = start_shape_or_data.data().value()[0]; + } else { + expr = start_shape_or_data.shape()[0]; + } + return expr; + }(); + + const auto end = [&] { + symbol::DimExpr expr; + if (end_shape_or_data.data().has_value()) { + expr = end_shape_or_data.data().value()[0]; + } else { + expr = end_shape_or_data.shape()[0]; + } + return expr; + }(); + + const auto step = [&] { + symbol::DimExpr expr; + if (step_shape_or_data.data().has_value()) { + expr = step_shape_or_data.data().value()[0]; + } else { + expr = step_shape_or_data.shape()[0]; + } + return expr; + }(); + + const symbol::ShapeOrDataDimExprs &shape_data = [&] { + std::vector out_dims; + // TODO(lanxianghit, jiahy0825): here should be ceil((end - start) / step), + // but DimExpr doesn't support ceil and float now + out_dims.emplace_back((end - start) / step); + return symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + }(); + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + + return true; +} + +bool AssignValueOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const std::vector shape = + paddle::dialect::details::GetVectorAttr(op, "shape"); + std::vector sym_dims; + sym_dims.reserve(shape.size()); + for (const int &dim : shape) { + sym_dims.emplace_back(symbol::DimExpr(static_cast(dim))); + } + + const auto &attributes = op->attributes(); + std::vector values; + for (size_t i = 0; + i < attributes.at("values").dyn_cast().size(); + i++) { + values.push_back(attributes.at("values") + .dyn_cast() + .at(i) + .dyn_cast() + .data() + .to()); + } + if (values.size() == 1) { + std::vector data{values[0]}; + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(sym_dims, data)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(sym_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} + +bool AssignValue_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return AssignValueOpInferSymbolicShape(op, shape_analysis); +} + +bool DataOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &attributes = op->attributes(); + pir::Attribute attr = attributes.at("shape"); + + const std::vector sym_dims = [&] { + std::vector sym_dims; + const std::vector &dims = + attr.dyn_cast().data().GetData(); + for (auto dim : dims) { + symbol::DimExpr dim_expr; + if (dim == pir::ShapedTypeInterface::kDynamic) { + symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName()); + dim_expr = symbolic_dim_expr; + } else { + symbol::DimExpr numeric_dim_expr(dim); + dim_expr = 
numeric_dim_expr; + } + sym_dims.push_back(dim_expr); + } + return sym_dims; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(sym_dims)}; + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + + return true; +} + +bool EmptyOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &shape_gen_op = op->operand_source(0).defining_op(); + if (shape_gen_op->isa()) { + std::vector shape = details::GetVectorAttr( + shape_gen_op->dyn_cast(), "value"); + std::vector sym_dims; + sym_dims.reserve(shape.size()); + for (const int64_t &dim : shape) { + sym_dims.emplace_back(symbol::DimExpr(dim)); + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(sym_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; + + } else { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + + shape_analysis->SetShapeOrDataForValue(op->result(0), + operand_shape_or_data); + return true; + } +} + +bool FeedOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const common::DDim &result_dims = + op->result(0).type().dyn_cast().dims(); + std::vector out_dims; + for (int i = 0; i < result_dims.size(); i++) { + if (result_dims[i] == -1) { + out_dims.emplace_back(shape_analysis->GetNextSymName()); + } else { + out_dims.emplace_back(result_dims[i]); + } + } + + shape_analysis->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + + return true; +} + +bool FullOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &attributes = op->attributes(); + + const std::vector shape = [&] { + pir::Attribute attr_shape = attributes.at("shape"); + const auto &shape_vec = + attr_shape.dyn_cast() + .data() + .GetData(); + std::vector shape(shape_vec.begin(), shape_vec.end()); + return shape; + }(); + + const auto shape_data = [&]() -> symbol::TensorShapeOrDataDimExprs { + // NOTE(Aurelius84): to is a risky operation when Scalar's dtype is + // not int32/int64. However, we found Full's Value could be like '3.0' but + // used as int. + const int64_t value = attributes.at("value") + .dyn_cast() + .data() + .to(); + const size_t shape_size = shape.size(); + // NOTE(Aurelius84): When shape.size()==1, a new std::vector with + // length = shape[0] will be constructed, but not all cases are used for + // ShapeAnalysis. Considering MAX_RANK < 9 in Paddle, we limit it below + // DATA_MAX_LENGTH = 128 and will not create this vector once length > + // DATA_MAX_LENGTH. 
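The NOTE above caps how much constant data FullOpInferSymbolicShape materializes: a 0-D full keeps its single value, a rank-1 full keeps shape[0] repeated values only while the static length stays at or below 128, and anything else records the symbolic shape alone. A minimal sketch of that branching with plain integers; std::optional stands in for the optional data carried by TensorShapeOrDataDimExprs, and the names are illustrative rather than Paddle API:

#include <cstdint>
#include <optional>
#include <vector>

constexpr int64_t kDataMaxLength = 128;  // mirrors DATA_MAX_LENGTH below

// Returns the element data to record, or nullopt when only the shape is kept.
std::optional<std::vector<int64_t>> MaterializedFullData(
    const std::vector<int64_t> &shape, int64_t value) {
  if (shape.empty()) {
    // 0-D full: a single element is always cheap to keep.
    return std::vector<int64_t>{value};
  }
  if (shape.size() == 1 && shape[0] <= kDataMaxLength) {
    // Rank-1 full with a small static length: keep the repeated value.
    return std::vector<int64_t>(shape[0], value);
  }
  // Otherwise only the symbolic shape is recorded, no element data.
  return std::nullopt;
}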
+ constexpr int64_t DATA_MAX_LENGTH = 128; + if (shape_size == 0U) { + std::vector data{value}; + return symbol::TensorShapeOrDataDimExprs(shape, data); + } else if (shape_size == 1U && + shape[0].template Get() <= DATA_MAX_LENGTH) { + std::vector data(shape[0].template Get(), + symbol::DimExpr(value)); + return symbol::TensorShapeOrDataDimExprs(shape, data); + } else { + return symbol::TensorShapeOrDataDimExprs(shape); + } + }(); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), symbol::ShapeOrDataDimExprs(shape_data)); + return true; +} + +bool FullIntArrayOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &attributes = op->attributes(); + pir::Attribute attr_value = attributes.at("value"); + const auto &vec = attr_value.dyn_cast().AsVector(); + + const std::vector data = [&] { + std::vector data; + for (auto item : vec) { + int64_t i = item.dyn_cast().data(); + data.push_back(symbol::DimExpr(i)); + } + return data; + }(); + + const std::vector shape{std::int64_t(vec.size())}; + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(shape, data)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + return true; +} + +bool GaussianOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &shape_gen_op = op->operand_source(0).defining_op(); + + if (shape_gen_op->isa()) { + std::vector shape = details::GetVectorAttr( + shape_gen_op->dyn_cast(), "value"); + std::vector sym_dims; + sym_dims.reserve(shape.size()); + for (const int64_t &dim : shape) { + sym_dims.emplace_back(symbol::DimExpr(dim)); + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(sym_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; + + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently shape must comes from FullIntArrayOp in GaussianOp's " + "InferSymbolicShape.")); + return true; + } +} + +bool RandintOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &shape_gen_op = op->operand_source(0).defining_op(); + + if (shape_gen_op->isa()) { + std::vector shape = details::GetVectorAttr( + shape_gen_op->dyn_cast(), "value"); + std::vector sym_dims; + sym_dims.reserve(shape.size()); + for (const int64_t &dim : shape) { + sym_dims.emplace_back(symbol::DimExpr(dim)); + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(sym_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; + + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Currently shape must comes from FullIntArrayOp in RandintOp's " + "InferSymbolicShape.")); + return true; + } +} + +bool TrilIndicesOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &attributes = op->attributes(); + int rows = attributes.at("rows").dyn_cast().data(); + int cols = attributes.at("cols").dyn_cast().data(); + int offset = attributes.at("offset").dyn_cast().data(); + + const auto &out_sym_shape = [&] { + std::vector out_sym_shape; + auto n_first_row = + offset > 0 ? 
std::min(cols, 1 + offset) : rows + offset > 0; + auto n_last_row = + std::max(0, std::min(cols, rows + offset)); + auto n_row_all = + std::max(0, std::min(rows, rows + offset)); + auto n_row_trapezoid = (n_last_row - n_first_row + 1); + auto tril_size = (n_first_row + n_last_row) * n_row_trapezoid >> 1; + auto diff_row = n_row_all - n_row_trapezoid; + if (diff_row > 0) { + tril_size += diff_row * cols; + } + out_sym_shape.emplace_back(std::int64_t(2)); + out_sym_shape.emplace_back(std::int64_t(tril_size)); + return out_sym_shape; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} +bool TriuIndicesOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const auto &attributes = op->attributes(); + int row = attributes.at("row").dyn_cast().data(); + int col = attributes.at("col").dyn_cast().data(); + int offset = attributes.at("offset").dyn_cast().data(); + + const auto &out_sym_shape = [&] { + std::vector out_sym_shape; + offset = offset - 1; + auto n_first_row = + offset > 0 ? std::min(col, 1 + offset) : row + offset > 0; + auto n_last_row = + std::max(0, std::min(col, row + offset)); + auto n_row_all = std::max(0, std::min(row, row + offset)); + auto n_row_trapezoid = (n_last_row - n_first_row + 1); + auto tril_size = (n_first_row + n_last_row) * n_row_trapezoid >> 1; + auto diff_row = n_row_all - n_row_trapezoid; + if (diff_row > 0) { + tril_size += diff_row * col; + } + out_sym_shape.emplace_back(std::int64_t(2)); + out_sym_shape.emplace_back(std::int64_t(row * col - tril_size)); + return out_sym_shape; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} +bool UniformOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return GaussianOpInferSymbolicShape(op, shape_analysis); +} + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h new file mode 100644 index 0000000000000..a221eec936528 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/nullary_infer_sym.h @@ -0,0 +1,33 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
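The TrilIndicesOpInferSymbolicShape code above counts lower-triangular entries in closed form (the trapezoid of partial rows plus any full rows below it), and TriuIndicesOpInferSymbolicShape reuses the same count as row * col - tril_size with the offset shifted by one. A standalone re-derivation with plain ints, checked against a brute-force count; the helper names are illustrative, not Paddle API:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Closed-form count of entries (i, j) with j <= i + offset, mirroring the
// arithmetic used by TrilIndicesOpInferSymbolicShape.
int64_t TrilSize(int64_t rows, int64_t cols, int64_t offset) {
  int64_t n_first_row = offset > 0
                            ? std::min(cols, 1 + offset)
                            : static_cast<int64_t>(rows + offset > 0);
  int64_t n_last_row = std::max<int64_t>(0, std::min(cols, rows + offset));
  int64_t n_row_all = std::max<int64_t>(0, std::min(rows, rows + offset));
  int64_t n_row_trapezoid = n_last_row - n_first_row + 1;
  int64_t size = (n_first_row + n_last_row) * n_row_trapezoid >> 1;
  int64_t diff_row = n_row_all - n_row_trapezoid;
  if (diff_row > 0) size += diff_row * cols;  // full rows below the trapezoid
  return size;
}

// Brute force reference: walk every entry of the rows x cols matrix.
int64_t TrilSizeBruteForce(int64_t rows, int64_t cols, int64_t offset) {
  int64_t count = 0;
  for (int64_t i = 0; i < rows; ++i)
    for (int64_t j = 0; j < cols; ++j)
      if (j <= i + offset) ++count;
  return count;
}

int main() {
  assert(TrilSize(4, 3, 0) == TrilSizeBruteForce(4, 3, 0));    // 9
  assert(TrilSize(3, 5, 1) == TrilSizeBruteForce(3, 5, 1));    // offset above main diagonal
  assert(TrilSize(5, 5, -2) == TrilSizeBruteForce(5, 5, -2));  // offset below main diagonal
  return 0;
}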
+ +#pragma once + +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace paddle::dialect { +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Arange) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignValue) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(AssignValue_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Data) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Empty) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Feed) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Full) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(FullIntArray) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Gaussian) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Randint) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilIndices) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(TriuIndices) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Uniform) +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc deleted file mode 100644 index 65e9770350c80..0000000000000 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.cc +++ /dev/null @@ -1,1851 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h" -#include "paddle/common/ddim.h" -#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" - -namespace paddle::dialect { - -// To make codes shorter -using ShapeOrData = symbol::ShapeOrDataDimExprs; -using TensorExprs = symbol::TensorShapeOrDataDimExprs; -using TensorListExprs = symbol::TensorListShapeOrDataDimExprs; - -bool DataOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto &attributes = op->attributes(); - pir::Attribute attr = attributes.at("shape"); - - const std::vector sym_dims = [&] { - std::vector sym_dims; - const std::vector &dims = - attr.dyn_cast().data().GetData(); - for (auto dim : dims) { - symbol::DimExpr dim_expr; - if (dim == pir::ShapedTypeInterface::kDynamic) { - symbol::DimExpr symbolic_dim_expr(shape_analysis->GetNextSymName()); - dim_expr = symbolic_dim_expr; - } else { - symbol::DimExpr numeric_dim_expr(dim); - dim_expr = numeric_dim_expr; - } - sym_dims.push_back(dim_expr); - } - return sym_dims; - }(); - - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(sym_dims)}; - - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - - return true; -} - -bool ShapeOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - - const std::vector sym_shape = [&] { - std::vector sym_shape; - symbol::DimExpr dim_expr( - op->result(0).type().dyn_cast().dims()[0]); - sym_shape.emplace_back(dim_expr); - return sym_shape; - }(); - - 
symbol::ShapeOrDataDimExprs shape_or_data{symbol::TensorShapeOrDataDimExprs( - sym_shape, operand_shape_or_data.shape())}; - - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_or_data); - - return true; -} - -bool ShapeSrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return ShapeOpInferSymbolicShape(op, shape_analysis); -} - -bool StackOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source = op->operand_source(0); - - const auto &attributes = op->attributes(); - int axis = attributes.at("axis").dyn_cast().data(); - - const symbol::TensorListShapeOrDataDimExprs &shape_data_list = - shape_analysis->GetShapeOrDataForValue(operand_source) - .dyn_cast(); - - int rank = shape_data_list[0].shape().size(); - if (axis < 0) axis += rank + 1; - - const symbol::ShapeOrDataDimExprs shape_data = [&] { - std::vector shape_dim_exprs; - std::vector data_dim_exprs; - for (size_t i = 0; i < shape_data_list.size(); ++i) { - if (shape_data_list[i].data().has_value() && axis == 0) { - data_dim_exprs.emplace_back(shape_data_list[i].data().value()[0]); - } - } - - if (!data_dim_exprs.empty()) { - shape_dim_exprs.emplace_back( - static_cast(shape_data_list.size())); - } else { - for (int i = 0; i < rank; ++i) { - if (i == axis) continue; - details::BuildCstrEqForTensorListAlongAxis( - shape_analysis, shape_data_list, i); - } - shape_dim_exprs.insert(shape_dim_exprs.begin() + axis, - static_cast(shape_data_list.size())); - } - - return symbol::ShapeOrDataDimExprs( - symbol::TensorShapeOrDataDimExprs(shape_dim_exprs, data_dim_exprs)); - }(); - - pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - return true; -} - -bool SumOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto &attributes = op->attributes(); - bool keepdim = attributes.at("keepdim").dyn_cast().data(); - - bool reduce_all = false; - - auto axis_gen_op = op->operand_source(1).defining_op(); - if (axis_gen_op->isa()) { - std::vector axis = details::GetVectorAttr( - axis_gen_op->dyn_cast(), "value"); - if (axis.size() == 0) { - reduce_all = true; - } - return details::ReduceInferDim( - op, shape_analysis, axis, keepdim, reduce_all); - } else { - // TODO(lanxianghit): deal with other source: pir::VectorType, - // paddle::dialect::DenseTensorType - PADDLE_THROW( - phi::errors::Unimplemented("SumOpInferSymbolicShape: 'axis' only " - "support FullIntArrayOp's result now.")); - } - - return true; -} - -bool ProdOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto &attributes = op->attributes(); - bool keepdim = - attributes.at("keep_dim").dyn_cast().data(); - - bool reduce_all = - attributes.at("reduce_all").dyn_cast().data(); - - auto axis_gen_op = op->operand_source(1).defining_op(); - if (axis_gen_op->isa()) { - std::vector axis = details::GetVectorAttr( - axis_gen_op->dyn_cast(), "value"); - return details::ReduceInferDim( - op, shape_analysis, axis, keepdim, reduce_all); - } else { - // TODO(lanxianghit): deal with other source: pir::VectorType, - // paddle::dialect::DenseTensorType - PADDLE_THROW( - phi::errors::Unimplemented("ProdOpInferSymbolicShape: 'axis' only " - "support FullIntArrayOp's result now.")); - } - - return true; -} - -bool ReshapeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source = 
op->operand_source(0); - if (shape_analysis->GetShapeOrDataForValue(operand_source) - .data() - .has_value()) { - const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); - shape_analysis->SetShapeOrDataForValue(op->result(0), - operand_shape_or_data); - return true; - } - - pir::Value operand_source_shape = op->operand_source(1); - - const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source_shape); - - const auto &GetProduct = [&](const auto &dim_exprs, const auto &Filter) { - symbol::DimExpr product{1}; - for (const auto &dim_expr : dim_exprs) { - if (Filter(dim_expr)) { - product = product * dim_expr; - } - } - return product; - }; - - const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) { - if (dim_expr.isa()) { - return dim_expr.dyn_cast() != static_cast(-1); - } - return true; - }; - - const std::vector out_dims = [&] { - const auto &original_shape = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); - - const auto &numel = - GetProduct(original_shape, [](const auto &) { return true; }); - - const auto &product_exclude_minus_one = - GetProduct(operand_shape_or_data.data().value(), IsNotMinusOne); - - const auto &input_dims = operand_shape_or_data.data().value(); - - std::vector out_dims; - out_dims.reserve(input_dims.size()); - for (const auto &dim_expr : input_dims) { - const auto &out_dim_expr = IsNotMinusOne(dim_expr) - ? dim_expr - : (numel / product_exclude_minus_one); - out_dims.emplace_back(out_dim_expr); - } - - return out_dims; - }(); - - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(out_dims)}; - - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - shape_analysis->SetShapeOrDataForValue( - op->result(1), - shape_analysis->GetShapeOrDataForValue(operand_source_shape)); - return true; -} - -bool Reshape_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return ReshapeOpInferSymbolicShape(op, shape_analysis); -} - -bool FullIntArrayOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto &attributes = op->attributes(); - pir::Attribute attr_value = attributes.at("value"); - const auto &vec = attr_value.dyn_cast().AsVector(); - - const std::vector data = [&] { - std::vector data; - for (auto item : vec) { - int64_t i = item.dyn_cast().data(); - data.push_back(symbol::DimExpr(i)); - } - return data; - }(); - - const std::vector shape{std::int64_t(vec.size())}; - - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(shape, data)}; - - pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - return true; -} - -bool SliceOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - // TODO(zhangbopd): Not implemented yet. 
- pir::Value operand_source = op->operand_source(0); - pir::Value operand_starts = op->operand_source(1); - pir::Value operand_ends = op->operand_source(2); - pir::Value res = op->result(0); - - const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); - const symbol::ShapeOrDataDimExprs &starts_shape_data = - shape_analysis->GetShapeOrDataForValue(operand_starts); - const symbol::ShapeOrDataDimExprs &ends_shape_data = - shape_analysis->GetShapeOrDataForValue(operand_ends); - - // Currently, we DO NOT support the case that any element in `axes` `starts` - // or `ends` is a Symbol. - const std::vector axes = [&] { - const auto &attributes = op->attributes(); - pir::Attribute attr_axes = attributes.at("axes"); - - const auto &axes_vec = attr_axes.dyn_cast().AsVector(); - std::vector axes; - int64_t rank = int64_t(operand_shape_or_data.shape().size()); - for (auto item : axes_vec) { - int64_t axis = item.dyn_cast().data(); - axes.emplace_back(axis >= 0 ? axis : std::max(int64_t(0), axis + rank)); - } - return axes; - }(); - - const std::vector starts = [&] { - std::vector starts; - for (auto item : starts_shape_data.data().value()) { - IR_ENFORCE(item.isa(), - "Currently, we DO NOT support the case that any element in " - "`starts` is a Symbol."); - starts.push_back(item.Get()); - } - return starts; - }(); - - const std::vector ends = [&] { - std::vector ends; - for (auto item : ends_shape_data.data().value()) { - IR_ENFORCE(item.isa(), - "Currently, we DO NOT support the case that any element in " - "`ends` is a Symbol."); - ends.push_back(item.Get()); - } - return ends; - }(); - - // When `pd.slice` is operating on a tensor which is produced by a `pd.shape` - // op, the reseult should be written into data. - const auto &GetDataDimExprs = [&]() -> symbol::ShapeOrDataDimExprs { - const std::vector out_data = [&] { - std::vector out_data; - const int64_t start = - starts[0] < 0 - ? starts[0] + operand_shape_or_data.data().value().size() - : starts[0]; - const int64_t end = - static_cast(std::numeric_limits::max()) == ends[0] - ? operand_shape_or_data.data().value().size() - : ends[0]; - - for (int64_t i = start; i < end; i++) { - out_data.push_back(operand_shape_or_data.data().value()[i]); - } - return out_data; - }(); - const std::vector shape{std::int64_t(out_data.size())}; - return symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(shape, out_data)}; - }; - - // Othewise, the reseult should be written into the shape. - const auto &GetShapeDimExprs = [&]() -> symbol::ShapeOrDataDimExprs { - std::vector out_shape = operand_shape_or_data.shape(); - - const std::vector &dim_expr_starts = - starts_shape_data.data().value(); - const std::vector &dim_expr_ends = - ends_shape_data.data().value(); - - // For both start and end can be negtive or positive, we need to handle the - // following different arrangements. - auto IsMaxInt = [](const symbol::DimExpr &expr) { - return expr.isa() && - expr.Get() == - static_cast(std::numeric_limits::max()); - }; - for (size_t i = 0; i < axes.size(); ++i) { - const int64_t axis = axes[i]; - auto end = - IsMaxInt(dim_expr_ends[i]) ? 
out_shape[axis] : dim_expr_ends[i]; - - bool both_negative_or_positive = - (starts[i] >= 0 && ends[i] >= 0) || (starts[i] <= 0 && ends[i] <= 0); - bool start_negative_end_positive = starts[i] <= 0 && ends[i] >= 0; - bool start_positive_end_negative = starts[i] >= 0 && ends[i] <= 0; - - if (both_negative_or_positive) { - out_shape[axis] = end - dim_expr_starts[i]; - } else if (start_negative_end_positive) { - out_shape[axis] = end - dim_expr_starts[i] - out_shape[axis]; - } else if (start_positive_end_negative) { - out_shape[axis] = out_shape[axis] - dim_expr_starts[i] + end; - } else { - LOG(FATAL) << "Dead code"; - } - } - - return symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(out_shape)}; - }; - - symbol::ShapeOrDataDimExprs shape_data = - operand_shape_or_data.data().has_value() ? GetDataDimExprs() - : GetShapeDimExprs(); - - shape_analysis->SetShapeOrDataForValue(res, shape_data); - return true; -} - -bool FullOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto &attributes = op->attributes(); - - const std::vector shape = [&] { - std::vector shape; - pir::Attribute attr_shape = attributes.at("shape"); - const auto &shape_vec = - attr_shape.dyn_cast() - .data() - .GetData(); - - for (auto &dim : shape_vec) { - shape.push_back(symbol::DimExpr(dim)); - } - return shape; - }(); - - // Keep shape info always with `int64_t` type. - int64_t value = attributes.at("value") - .dyn_cast() - .data() - .to(); - std::vector data{symbol::DimExpr(value)}; - - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(shape, data)}; - - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - return true; -} - -bool ConcatOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source = op->operand_source(0); - const auto &shape_data_list = - shape_analysis->GetShapeOrDataForValue(operand_source) - .dyn_cast(); - - CHECK(op->operand_source(1).defining_op()->isa()); - - int64_t axis = op->operand_source(1) - .defining_op() - .attributes() - .at("value") - .dyn_cast() - .data() - .to(); - size_t rank = shape_data_list[0].shape().size(); - axis = axis >= 0 ? 
axis : std::max(int64_t(0), int64_t(axis + rank)); - - if (shape_data_list[0].data().has_value()) { - std::vector data; - data.reserve(shape_data_list.size()); - for (auto &data_elem : shape_data_list) { - data.push_back(data_elem.data().value()[0]); - } - const std::vector shape{std::int64_t(data.size())}; - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(shape, data)}; - pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - - return true; - } - - const std::vector &out_dims = [&] { - std::vector out_dims = shape_data_list[0].shape(); - for (size_t i = 0; i < rank; ++i) { - if (i != static_cast(axis)) { - details::BuildCstrEqForTensorListAlongAxis( - shape_analysis, shape_data_list, i); - continue; - } - for (size_t j = 1; j < shape_data_list.size(); ++j) { - out_dims[axis] = out_dims[axis] + shape_data_list[j].shape()[axis]; - } - } - return out_dims; - }(); - - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(out_dims)}; - - pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - - return true; -} - -bool GatherNdOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - auto x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - auto index_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - - std::vector x_sym_shape; - if (x_shape_or_data.data().has_value()) { - x_sym_shape = x_shape_or_data.data().value(); - } else { - x_sym_shape = x_shape_or_data.shape(); - } - int x_dims_size = x_sym_shape.size(); - - std::vector index_sym_shape; - if (index_shape_or_data.data().has_value()) { - index_sym_shape = index_shape_or_data.data().value(); - } else { - index_sym_shape = index_shape_or_data.shape(); - } - int index_dims_size = index_sym_shape.size(); - - std::vector result_sym_dims; - // The result dims is - // Index.shape[:-1] + X.shape[Index.shape[-1]:] - for (int i = 0; i < index_dims_size - 1; ++i) { - result_sym_dims.emplace_back(index_sym_shape[i]); - } - - PADDLE_ENFORCE_EQ( - index_sym_shape[index_dims_size - 1].Has(), - true, - phi::errors::InvalidArgument( - "in GatherNdOpInferSymbolicShape: index[-1] should be unknown")); - - for (int i = static_cast( - index_sym_shape[index_dims_size - 1].Get()); - i < x_dims_size; - ++i) { - result_sym_dims.emplace_back(x_sym_shape[i]); - } - - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(result_sym_dims)}; - - pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - - return true; -} - -bool SqueezeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - IR_ENFORCE(op->num_operands() == 2, - "SqueezeOpInferSymbolicShape ONLY support num_operands() == 2 " - "now, but got %d operands", - op->num_operands()); - - auto x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - auto axes_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - - std::vector in_dims_sym; - if (x_shape_or_data.data().has_value()) { - in_dims_sym = x_shape_or_data.data().value(); - } else { - in_dims_sym = x_shape_or_data.shape(); - } - - std::vector squeeze_dims_sym; - if (axes_shape_or_data.data().has_value()) { - squeeze_dims_sym = axes_shape_or_data.data().value(); - } else { - squeeze_dims_sym = axes_shape_or_data.shape(); - } - - std::vector squeeze_dims; - for (auto squeeze_dim : 
squeeze_dims_sym) { - IR_ENFORCE(squeeze_dim.Has(), - "in SqueezeOpInferSymbolicShape, axes must be known int type, " - "but got: %s", - symbol::ToString(squeeze_dim)); - squeeze_dims.emplace_back( - static_cast(squeeze_dim.Get())); - } - - // GetOutputSqueezeShape - size_t num_squeeze_dims = squeeze_dims.size(); - std::vector should_squeeze(in_dims_sym.size(), false); - // Mark dimensions need to be squeezed. - if (num_squeeze_dims == 0) { - for (size_t i = 0; i < in_dims_sym.size(); ++i) { - // TODO(lanxianghit): if symbol here, maybe we need the result of dim expr - // simplification - if (in_dims_sym[i] == 1) { - should_squeeze[i] = true; - } - } - } else { - for (size_t i = 0; i < num_squeeze_dims; ++i) { - if (in_dims_sym.size() == 0) { - continue; - } - int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims_sym.size() - : squeeze_dims[i]; - - if (!should_squeeze[current]) { - // At compile time, dim of SYMBOL is allowed to squeeze? - if (in_dims_sym[current] == 1) { - should_squeeze[current] = true; - } else if (!in_dims_sym[current].Has()) { - PADDLE_THROW( - phi::errors::Unimplemented("SqueezeOpInferSymbolicShape CAN NOT " - "deal with symbol in axis now")); - } - } - } - } - - // Make output dimensions - std::vector output_shape_sym; - for (size_t i = 0; i < in_dims_sym.size(); ++i) { - if (!should_squeeze[i]) { - output_shape_sym.emplace_back(in_dims_sym[i]); - } - } - - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(output_shape_sym)}; - - pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - - return true; -} -bool Squeeze_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SqueezeOpInferSymbolicShape(op, shape_analysis); -} - -bool UnsqueezeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - IR_ENFORCE(op->num_operands() == 2, - "UnsqueezeOp InferSymbolicShape ONLY support num_operands() == 2 " - "now, but got %d operands", - op->num_operands()); - - auto x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - auto axes_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - - std::vector x_sym_shape; - if (x_shape_or_data.data().has_value()) { - x_sym_shape = x_shape_or_data.data().value(); - } else { - x_sym_shape = x_shape_or_data.shape(); - } - int x_dims_size = x_sym_shape.size(); - - std::vector axes_sym; - if (axes_shape_or_data.data().has_value()) { - axes_sym = axes_shape_or_data.data().value(); - } else { - axes_sym = axes_shape_or_data.shape(); - } - int axes_sym_size = axes_sym.size(); - - // GetUnsqueezeShape - int output_rank = x_dims_size + axes_sym_size; - std::vector result_sym_dims(output_rank, 0); - - int cur_output_rank = x_dims_size; - for (auto axis_expr : axes_sym) { - IR_ENFORCE(axis_expr.Has(), - "in UnsqueezeOpInferSymbolicShape, axes must be known int type, " - "but got: %s", - symbol::ToString(axis_expr)); - int axis = static_cast(axis_expr.Get()); - int cur = axis < 0 ? axis + cur_output_rank + 1 : axis; - - // Move old axis, and insert new axis - for (int i = cur_output_rank; i >= cur; --i) { - if (result_sym_dims[i] == 1) { - // Move axis - result_sym_dims[i + 1] = 1; - result_sym_dims[i] = 0; - } - } - result_sym_dims[cur] = 1; - // Add the output size. 
- cur_output_rank++; - } - - // Make output shape - for (int in_idx = 0, out_idx = 0; out_idx < output_rank; ++out_idx) { - if (result_sym_dims[out_idx] == 0) { - result_sym_dims[out_idx] = x_sym_shape[in_idx++]; - } - } - - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(result_sym_dims)}; - - pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - - return true; -} -bool Unsqueeze_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return UnsqueezeOpInferSymbolicShape(op, shape_analysis); -} - -bool TileOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_x = op->operand_source(0); - symbol::ShapeOrDataDimExprs x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_x); - pir::Value operand_repeat_times = op->operand_source(1); - symbol::ShapeOrDataDimExprs repeat_times_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_repeat_times); - - std::vector x_dimexpr; - if (x_shape_or_data.data().has_value()) { - x_dimexpr = x_shape_or_data.data().value(); - } else { - x_dimexpr = x_shape_or_data.shape(); - } - - std::vector repeat_times_dimexpr; - if (repeat_times_shape_or_data.data().has_value()) { - repeat_times_dimexpr = repeat_times_shape_or_data.data().value(); - } else { - repeat_times_dimexpr = repeat_times_shape_or_data.shape(); - } - if (repeat_times_dimexpr.empty()) { - repeat_times_dimexpr = std::vector(x_dimexpr.size(), 1); - } - - auto out_rank = std::max(static_cast(x_dimexpr.size()), - repeat_times_dimexpr.size()); - std::vector out_shape(out_rank); - if (x_dimexpr.size() > repeat_times_dimexpr.size()) { - auto diff = x_dimexpr.size() - repeat_times_dimexpr.size(); - repeat_times_dimexpr.insert(repeat_times_dimexpr.begin(), diff, 1); - } else { - auto diff = repeat_times_dimexpr.size() - x_dimexpr.size(); - x_dimexpr.insert(x_dimexpr.begin(), diff, 1); - } - - for (size_t i = 0; i < repeat_times_dimexpr.size(); ++i) { - out_shape[i] = x_dimexpr[i] * repeat_times_dimexpr[i]; - } - - symbol::ShapeOrDataDimExprs shape_data{ - symbol::TensorShapeOrDataDimExprs(out_shape)}; - - pir::Value res = op->result(0); - shape_analysis->SetShapeOrDataForValue(res, shape_data); - - return true; -} - -bool TransposeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - std::vector perm = - op->attributes().at("perm").dyn_cast().AsVector(); - if (perm.size() == 1) { - // perm must be [0], which means nothing to do with input, just copy the - // info from input - shape_analysis->SetShapeOrDataForValue( - op->result(0), - shape_analysis->GetShapeOrDataForValue(op->operand_source(0))); - return true; - } - const std::vector &x_dims = [&] { - std::vector dims; - const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - if (x_shape_or_data.data().has_value()) { - dims = x_shape_or_data.data().value(); - } else { - dims = x_shape_or_data.shape(); - } - return dims; - }(); - - int x_rank = x_dims.size(); - - const std::vector formated_axis = [op, x_rank, &perm] { - std::vector out(perm.size(), 0); - std::transform(perm.begin(), - perm.end(), - out.begin(), - [](pir::Attribute &p) -> int32_t { - return p.dyn_cast().data(); - }); - - // format the negtive axis - std::for_each(out.begin(), out.end(), [x_rank](int32_t &v) { - if (v < 0) { - v += x_rank; - } - }); - return out; - }(); - - int axis_size = static_cast(formated_axis.size()); 
- - std::vector out_dims(x_dims); - for (int i = 0; i < axis_size; ++i) { - out_dims[i] = x_dims[formated_axis[i]]; - } - - shape_analysis->SetShapeOrDataForValue(op->result(0), - ShapeOrData{TensorExprs(out_dims)}); - - return true; -} -bool Transpose_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return TransposeOpInferSymbolicShape(op, shape_analysis); -} - -bool ArangeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto &start_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - const auto &end_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - const auto &step_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(2)); - - const auto start = [&] { - symbol::DimExpr expr; - if (start_shape_or_data.data().has_value()) { - expr = start_shape_or_data.data().value()[0]; - } else { - expr = start_shape_or_data.shape()[0]; - } - return expr; - }(); - - const auto end = [&] { - symbol::DimExpr expr; - if (end_shape_or_data.data().has_value()) { - expr = end_shape_or_data.data().value()[0]; - } else { - expr = end_shape_or_data.shape()[0]; - } - return expr; - }(); - - const auto step = [&] { - symbol::DimExpr expr; - if (step_shape_or_data.data().has_value()) { - expr = step_shape_or_data.data().value()[0]; - } else { - expr = step_shape_or_data.shape()[0]; - } - return expr; - }(); - - const symbol::ShapeOrDataDimExprs &shape_data = [&] { - std::vector out_dims; - // TODO(lanxianghit, jiahy0825): here should be ceil((end - start) / step), - // but DimExpr doesn't support ceil and float now - out_dims.emplace_back((end - start) / step); - return symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(out_dims)}; - }(); - - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - - return true; -} - -bool EmbeddingOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - const auto weight_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - const std::vector &x_dims = [&] { - std::vector dims; - if (x_shape_or_data.data().has_value()) { - dims = x_shape_or_data.data().value(); - } else { - dims = x_shape_or_data.shape(); - } - return dims; - }(); - - const std::vector &weight_dims = [&] { - std::vector dims; - if (weight_shape_or_data.data().has_value()) { - dims = weight_shape_or_data.data().value(); - } else { - dims = weight_shape_or_data.shape(); - } - return dims; - }(); - - const symbol::ShapeOrDataDimExprs &shape_data = [&] { - std::vector out_dims = x_dims; - // no need to check validation of weight_dims index, since all checks have - // been done at corresponding InferMeta - out_dims.emplace_back(weight_dims[1]); - return symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(out_dims)}; - }(); - - shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); - - return true; -} - -bool SparseWeightEmbeddingOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} - -bool ExpandOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's 
InferSymbolicShape interface is NOT implemented now.")); - return true; -} - -bool MatmulOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - // x_dims can't be const or ref here, in case to be broadcasted - std::vector x_dims = [&] { - std::vector dims; - const auto &x_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - if (x_shape_or_data.data().has_value()) { - dims = x_shape_or_data.data().value(); - } else { - dims = x_shape_or_data.shape(); - } - return dims; - }(); - - // y_dims can't be const or ref here, in case to be broadcasted - std::vector y_dims = [&] { - std::vector dims; - const auto y_shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); - if (y_shape_or_data.data().has_value()) { - dims = y_shape_or_data.data().value(); - } else { - dims = y_shape_or_data.shape(); - } - return dims; - }(); - - size_t ndims_x = x_dims.size(); - size_t ndims_y = y_dims.size(); - - const bool x_broadcasted = [&] { - bool broadcasted = false; - if (ndims_x == 1) { - x_dims.insert(x_dims.begin(), 1); - ndims_x = 2; - broadcasted = true; - } - return broadcasted; - }(); - - const bool y_broadcasted = [&] { - bool broadcasted = false; - if (ndims_y == 1) { - y_dims.emplace_back(1); - ndims_y = 2; - broadcasted = true; - } - return broadcasted; - }(); - - std::vector out_dims; - if (ndims_x > ndims_y) { - out_dims.assign(x_dims.begin(), x_dims.end() - 2); - } else if (ndims_x < ndims_y) { - out_dims.assign(y_dims.begin(), y_dims.end() - 2); - } else { - symbol::DimExprBuilder builder{nullptr}; - for (size_t i = 0; i < ndims_x - 2; ++i) { - out_dims.emplace_back(builder.Broadcast(x_dims[i], y_dims[i])); - } - } - - symbol::DimExpr out_M = - op->attributes().at("transpose_x").dyn_cast().data() - ? x_dims[ndims_x - 1] - : x_dims[ndims_x - 2]; - symbol::DimExpr out_N = - op->attributes().at("transpose_y").dyn_cast().data() - ? y_dims[ndims_y - 2] - : y_dims[ndims_y - 1]; - if (!x_broadcasted) { - out_dims.emplace_back(out_M); - } - if (!y_broadcasted) { - out_dims.emplace_back(out_N); - } - - shape_analysis->SetShapeOrDataForValue(op->result(0), - ShapeOrData{TensorExprs(out_dims)}); - - return true; -} - -bool MaxOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - bool keepdim = - op->attributes().at("keepdim").dyn_cast().data(); - - const std::vector axis = [&] { - pir::Operation *axis_gen_op = op->operand_source(1).defining_op(); - std::vector axis_vec; - if (axis_gen_op->isa()) { - axis_vec = details::GetVectorAttr( - axis_gen_op->dyn_cast(), "value"); - } else { - // TODO(lanxianghit): there's other source: pir::VectorType, - // paddle::dialect::DenseTensorType, but after PRIM, maybe always - // FullIntArrayOp, to be confirmed - PADDLE_THROW( - phi::errors::Unimplemented("MaxOpInferSymbolicShape: 'axis' only " - "support FullIntArrayOp's result now.")); - } - return axis_vec; - }(); - - bool reduce_all = axis.size() == 0 ? 
true : false; - - return details::ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all); -} - -bool WhereOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - shape_analysis->SetShapeOrDataForValue( - op->result(0), - shape_analysis->GetShapeOrDataForValue(op->operand_source(0))); - return true; -} - -bool Where_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return WhereOpInferSymbolicShape(op, shape_analysis); -} - -bool FeedOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - const common::DDim &result_dims = - op->result(0).type().dyn_cast().dims(); - std::vector out_dims; - for (int i = 0; i < result_dims.size(); i++) { - if (result_dims[i] == -1) { - out_dims.emplace_back(shape_analysis->GetNextSymName()); - } else { - out_dims.emplace_back(result_dims[i]); - } - } - - shape_analysis->SetShapeOrDataForValue( - op->result(0), - symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); - - return true; -} - -bool TopPSamplingOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - - const auto &x_dims = [op, shape_analysis] { - const auto &shape_or_data = - shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); - if (shape_or_data.data().has_value()) { - return shape_or_data.data().value(); - } else { - return shape_or_data.shape(); - } - }(); - - // all the result have the same shape - for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { - const std::vector out_dims{x_dims[0], 1}; - shape_analysis->SetShapeOrDataForValue( - op->result(rst_idx), - symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(out_dims)}); - } - - return true; -} - -bool ExpandAsOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} - -bool SplitOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} - -// Not Impelmented Ops. 
-bool AcosOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Acos_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool AcoshOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Acosh_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool AngleOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ArgmaxOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ArgminOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ArgsortOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool AsComplexOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool AsRealOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool AsStridedOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool AsinOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Asin_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool AsinhOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Asinh_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - 
PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool AtanOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Atan_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool AtanhOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Atanh_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool BernoulliOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool BitwiseNotOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool BitwiseNot_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool BitwiseXorOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool BitwiseXor_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool CeilOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Ceil_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ComplexOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ConjOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool CosOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} 
-bool Cos_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool CoshOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Cosh_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool CummaxOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool CumminOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool CumprodOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Cumprod_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool CumsumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Cumsum_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool DiagEmbedOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool DiagonalOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool DirichletOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ErfOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Erf_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ErfinvOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - 
PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Erfinv_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Expm1OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Expm1_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool FlipOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool FloorOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Floor_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool FmaxOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool FminOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool GatherOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ImagOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool IsinfOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool IsinfSrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool IsnanOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool IsnanSrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool 
KronOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool KthvalueOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LgammaOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Lgamma_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Log1pOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Log1p_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LogcumsumexpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LogicalOrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LogicalOr_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LogicalXorOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LogicalXor_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LogitOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Logit_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool MaskedSelectOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool PoissonOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - 
PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool PutAlongAxisOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool PutAlongAxis_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool RealOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool RollOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool RoundOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Round_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ScatterNdAddOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool ScatterOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Scatter_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool SearchsortedOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool SignOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool SinOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Sin_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool SinhOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return 
true; -} -bool Sinh_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool TakeAlongAxisOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool TanOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Tan_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool TanhOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Tanh_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool TopkOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool UnbindOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool UniqueConsecutiveOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} - -bool EinsumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool EmptyOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool EqualOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Equal_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Exponential_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool GaussianOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - 
PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool GreaterEqualOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool GreaterEqual_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LessEqualOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LessEqual_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LinspaceOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LogspaceOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool LogsumexpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool MaximumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool MinOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool MinimumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool PadOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool RandintOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool RemainderOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool Remainder_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented 
now.")); - return true; -} -bool RepeatInterleaveOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool SplitWithNumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool TrilIndicesOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool TriuIndicesOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool UniformOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} -bool UniqueOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - PADDLE_THROW(phi::errors::Unimplemented( - op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); - return true; -} - -} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h deleted file mode 100644 index ee5bcacf63a1f..0000000000000 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h +++ /dev/null @@ -1,362 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" - -namespace paddle::dialect { - -bool DataOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ShapeOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ShapeSrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool StackOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool SumOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ReshapeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Reshape_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool FullIntArrayOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool SliceOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool FullOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ConcatOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool GatherNdOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool SqueezeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Squeeze_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool UnsqueezeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Unsqueeze_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool TileOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool TransposeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Transpose_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ProdOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ArangeOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool EmbeddingOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool SparseWeightEmbeddingOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ExpandOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool MatmulOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool MaxOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool TransposeOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool WhereOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool Where_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool FeedOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool TopPSamplingOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ExpandAsOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - 
-bool SplitOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -// Not Impelmented Ops. -bool AcosOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Acos_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool AcoshOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Acosh_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool AngleOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ArgmaxOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ArgminOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ArgsortOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool AsComplexOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool AsRealOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool AsStridedOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool AsinOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Asin_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool AsinhOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Asinh_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool AtanOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Atan_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool AtanhOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Atanh_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool BernoulliOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool BitwiseNotOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool BitwiseNot_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool BitwiseXorOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool BitwiseXor_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool CeilOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Ceil_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ComplexOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ConjOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool CosOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Cos_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool CoshOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Cosh_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool CummaxOpInferSymbolicShape(pir::Operation *op, - 
pir::ShapeConstraintIRAnalysis *shape_analysis); -bool CumminOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool CumprodOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Cumprod_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool CumsumOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Cumsum_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool DiagEmbedOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool DiagonalOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool DirichletOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ErfOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Erf_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ErfinvOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Erfinv_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Expm1OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Expm1_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool FlipOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool FloorOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Floor_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool FmaxOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool FminOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool GatherOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ImagOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool IsinfOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool IsinfSrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool IsnanOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool IsnanSrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool KronOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool KthvalueOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LgammaOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Lgamma_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Log1pOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Log1p_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LogcumsumexpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LogicalOrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LogicalOr_OpInferSymbolicShape( - 
pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LogicalXorOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LogicalXor_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LogitOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Logit_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool MaskedSelectOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool PoissonOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool PutAlongAxisOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool PutAlongAxis_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool RealOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool RollOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool RoundOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Round_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ScatterNdAddOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ScatterOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Scatter_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool SearchsortedOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool SignOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool SinOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Sin_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool SinhOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Sinh_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool TakeAlongAxisOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool TanOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Tan_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool TanhOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Tanh_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool TopkOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool UnbindOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool UniqueConsecutiveOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool EinsumOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool EmptyOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool EqualOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Equal_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool 
Exponential_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool GaussianOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool GreaterEqualOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool GreaterEqual_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LessEqualOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LessEqual_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LinspaceOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LogspaceOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool LogsumexpOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool MaximumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool MinOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool MinimumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool PadOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool RandintOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool RemainderOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Remainder_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool RepeatInterleaveOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool SplitWithNumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool TrilIndicesOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool TriuIndicesOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool UniformOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool UniqueOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.cc deleted file mode 100644 index 98a6d670869ca..0000000000000 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.cc +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h" - -bool SameOperandsAndResultShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - pir::Value operand_source = op->operand_source(0); - const symbol::ShapeOrDataDimExprs &operand_shape_or_data = - shape_analysis->GetShapeOrDataForValue(operand_source); - - shape_analysis->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); - return true; -} - -namespace paddle::dialect { - -bool AbsOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool Abs_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool AssignOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool Assign_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return AssignOpInferSymbolicShape(op, shape_analysis); -} - -bool CastOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool Cast_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool ExpOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool Exp_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool FetchOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - shape_analysis->SetShapeOrDataForValue( - op->result(0), - shape_analysis->GetShapeOrDataForValue(op->operand_source(0))); - - return true; -} - -bool IncrementOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool Increment_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return IncrementOpInferSymbolicShape(op, shape_analysis); -} - -bool LogOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool Log_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return LogOpInferSymbolicShape(op, shape_analysis); -} - -bool LogicalNotOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool LogicalNot_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return LogicalNotOpInferSymbolicShape(op, shape_analysis); -} - -bool FullWithTensorOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool PowOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} -bool Pow_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis 
*shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool ReluOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool Relu_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool RsqrtOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} -bool Rsqrt_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool ScaleOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} -bool Scale_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} -bool ScaleSrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} -bool ScaleSr_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool SubtractOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool Subtract_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool TrilOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return SameOperandsAndResultShape(op, shape_analysis); -} - -bool Tril_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis) { - return TrilOpInferSymbolicShape(op, shape_analysis); -} - -} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h deleted file mode 100644 index d96f4efe1f825..0000000000000 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_and_result.h +++ /dev/null @@ -1,104 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" - -namespace paddle::dialect { -bool AbsOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Abs_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool AssignOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool Assign_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool CastOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Cast_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ExpOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Exp_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool FetchOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool FullWithTensorOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool IncrementOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool Increment_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool LogOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool Log_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool LogicalNotOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool LogicalNot_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool PowOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Pow_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ReluOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Relu_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool RsqrtOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Rsqrt_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool ScaleOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Scale_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ScaleSrOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool ScaleSr_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool SubtractOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); -bool Subtract_OpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool TrilOpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -bool Tril_OpInferSymbolicShape(pir::Operation *op, - pir::ShapeConstraintIRAnalysis *shape_analysis); - -} // namespace paddle::dialect - -namespace cinn::dialect { -using paddle::dialect::ScaleOpInferSymbolicShape; -} diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc 
b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc new file mode 100644 index 0000000000000..04e5032098367 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc @@ -0,0 +1,175 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h" + +#define OP_SAME_OPERANDS_AND_RESULT(name) \ + bool name##OpInferSymbolicShape( \ + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { \ + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = \ + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); \ + shape_analysis->SetShapeOrDataForValue(op->result(0), \ + operand_shape_or_data); \ + return true; \ + } + +namespace paddle::dialect { + +OP_SAME_OPERANDS_AND_RESULT(Abs) +OP_SAME_OPERANDS_AND_RESULT(Abs_) +OP_SAME_OPERANDS_AND_RESULT(Acos) +OP_SAME_OPERANDS_AND_RESULT(Acos_) +OP_SAME_OPERANDS_AND_RESULT(Acosh) +OP_SAME_OPERANDS_AND_RESULT(Acosh_) +OP_SAME_OPERANDS_AND_RESULT(Angle) +OP_SAME_OPERANDS_AND_RESULT(Argsort) +OP_SAME_OPERANDS_AND_RESULT(Asin) +OP_SAME_OPERANDS_AND_RESULT(Asin_) +OP_SAME_OPERANDS_AND_RESULT(Asinh) +OP_SAME_OPERANDS_AND_RESULT(Asinh_) +OP_SAME_OPERANDS_AND_RESULT(Assign) +OP_SAME_OPERANDS_AND_RESULT(Assign_) +OP_SAME_OPERANDS_AND_RESULT(Atan) +OP_SAME_OPERANDS_AND_RESULT(Atan_) +OP_SAME_OPERANDS_AND_RESULT(Atanh) +OP_SAME_OPERANDS_AND_RESULT(Atanh_) +OP_SAME_OPERANDS_AND_RESULT(Bernoulli) +OP_SAME_OPERANDS_AND_RESULT(BitwiseNot) +OP_SAME_OPERANDS_AND_RESULT(BitwiseNot_) +OP_SAME_OPERANDS_AND_RESULT(Cast) +OP_SAME_OPERANDS_AND_RESULT(Cast_) +OP_SAME_OPERANDS_AND_RESULT(Ceil) +OP_SAME_OPERANDS_AND_RESULT(Ceil_) +OP_SAME_OPERANDS_AND_RESULT(Conj) +OP_SAME_OPERANDS_AND_RESULT(Cos) +OP_SAME_OPERANDS_AND_RESULT(Cos_) +OP_SAME_OPERANDS_AND_RESULT(Cosh) +OP_SAME_OPERANDS_AND_RESULT(Cosh_) +OP_SAME_OPERANDS_AND_RESULT(Digamma) +OP_SAME_OPERANDS_AND_RESULT(Digamma_) +OP_SAME_OPERANDS_AND_RESULT(Dirichlet) +OP_SAME_OPERANDS_AND_RESULT(Equal) +OP_SAME_OPERANDS_AND_RESULT(Equal_) +OP_SAME_OPERANDS_AND_RESULT(Erf) +OP_SAME_OPERANDS_AND_RESULT(Erf_) +OP_SAME_OPERANDS_AND_RESULT(Erfinv) +OP_SAME_OPERANDS_AND_RESULT(Erfinv_) +OP_SAME_OPERANDS_AND_RESULT(Exp) +OP_SAME_OPERANDS_AND_RESULT(Exp_) +OP_SAME_OPERANDS_AND_RESULT(Expm1) +OP_SAME_OPERANDS_AND_RESULT(Expm1_) +OP_SAME_OPERANDS_AND_RESULT(Exponential_) +OP_SAME_OPERANDS_AND_RESULT(Fetch) +OP_SAME_OPERANDS_AND_RESULT(Flip) +OP_SAME_OPERANDS_AND_RESULT(Floor) +OP_SAME_OPERANDS_AND_RESULT(Floor_) +OP_SAME_OPERANDS_AND_RESULT(Imag) +OP_SAME_OPERANDS_AND_RESULT(Increment) +OP_SAME_OPERANDS_AND_RESULT(Increment_) +OP_SAME_OPERANDS_AND_RESULT(Isinf) +OP_SAME_OPERANDS_AND_RESULT(IsinfSr) +OP_SAME_OPERANDS_AND_RESULT(Isnan) +OP_SAME_OPERANDS_AND_RESULT(IsnanSr) +OP_SAME_OPERANDS_AND_RESULT(Lgamma) +OP_SAME_OPERANDS_AND_RESULT(Lgamma_) 
+OP_SAME_OPERANDS_AND_RESULT(Log1p) +OP_SAME_OPERANDS_AND_RESULT(Log1p_) +OP_SAME_OPERANDS_AND_RESULT(Log) +OP_SAME_OPERANDS_AND_RESULT(Log_) +OP_SAME_OPERANDS_AND_RESULT(LogicalNot) +OP_SAME_OPERANDS_AND_RESULT(LogicalNot_) +OP_SAME_OPERANDS_AND_RESULT(Logit) +OP_SAME_OPERANDS_AND_RESULT(Logit_) +OP_SAME_OPERANDS_AND_RESULT(Pow) +OP_SAME_OPERANDS_AND_RESULT(Poisson) +OP_SAME_OPERANDS_AND_RESULT(Pow_) +OP_SAME_OPERANDS_AND_RESULT(Print) +OP_SAME_OPERANDS_AND_RESULT(PutAlongAxis) +OP_SAME_OPERANDS_AND_RESULT(PutAlongAxis_) +OP_SAME_OPERANDS_AND_RESULT(Real) +OP_SAME_OPERANDS_AND_RESULT(Relu) +OP_SAME_OPERANDS_AND_RESULT(Relu_) +OP_SAME_OPERANDS_AND_RESULT(Roll) +OP_SAME_OPERANDS_AND_RESULT(Round) +OP_SAME_OPERANDS_AND_RESULT(Round_) +OP_SAME_OPERANDS_AND_RESULT(Rsqrt) +OP_SAME_OPERANDS_AND_RESULT(Rsqrt_) +OP_SAME_OPERANDS_AND_RESULT(ScaleSr) +OP_SAME_OPERANDS_AND_RESULT(ScaleSr_) +OP_SAME_OPERANDS_AND_RESULT(Scale_) +OP_SAME_OPERANDS_AND_RESULT(ScatterNdAdd) +OP_SAME_OPERANDS_AND_RESULT(Scatter) +OP_SAME_OPERANDS_AND_RESULT(Scatter_) +OP_SAME_OPERANDS_AND_RESULT(Sign) +OP_SAME_OPERANDS_AND_RESULT(Sin) +OP_SAME_OPERANDS_AND_RESULT(Sin_) +OP_SAME_OPERANDS_AND_RESULT(Sinh) +OP_SAME_OPERANDS_AND_RESULT(Sinh_) +OP_SAME_OPERANDS_AND_RESULT(Softmax) +OP_SAME_OPERANDS_AND_RESULT(Softmax_) +OP_SAME_OPERANDS_AND_RESULT(Tan) +OP_SAME_OPERANDS_AND_RESULT(Tan_) +OP_SAME_OPERANDS_AND_RESULT(Tanh) +OP_SAME_OPERANDS_AND_RESULT(Tanh_) +OP_SAME_OPERANDS_AND_RESULT(Tril) +OP_SAME_OPERANDS_AND_RESULT(Tril_) +OP_SAME_OPERANDS_AND_RESULT(Triu) +OP_SAME_OPERANDS_AND_RESULT(Triu_) +OP_SAME_OPERANDS_AND_RESULT(Trunc) +OP_SAME_OPERANDS_AND_RESULT(Trunc_) + +bool ScaleOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + std::vector<symbol::DimExpr> shape(operand_shape_or_data.shape()); + + if (operand_shape_or_data.data()) { + const std::vector<symbol::DimExpr> data = [&] { + const symbol::DimExpr scale = [&]() -> symbol::DimExpr { + if (op->num_operands() == 2) { + return shape_analysis->GetShapeOrDataForValue(op->operand_source(1)) + .data() + ->at(0); + } + return static_cast<int64_t>( + op->attribute("scale").dyn_cast<pir::FloatAttribute>().data()); + }(); + int bias = op->attribute("bias").dyn_cast<pir::FloatAttribute>().data(); + + std::vector<symbol::DimExpr> data; + for (auto &val : *(operand_shape_or_data.data())) { + data.push_back(val * scale + bias); + } + return data; + }(); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), symbol::TensorShapeOrDataDimExprs(shape, data)); + } else { + shape_analysis->SetShapeOrDataForValue(op->result(0), + operand_shape_or_data); + } + + return true; +} + +} // namespace paddle::dialect + +namespace cinn::dialect { +using paddle::dialect::ScaleOpInferSymbolicShape; +} + +#undef OP_SAME_OPERANDS_AND_RESULT diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h new file mode 100644 index 0000000000000..41363fbe70604 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h @@ -0,0 +1,128 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace paddle::dialect { +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Abs) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Abs_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Acos) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Acos_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Acosh) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Acosh_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Angle) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argsort) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Asin) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Asin_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Asinh) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Asinh_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Assign) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Assign_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Atan) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Atan_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Atanh) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Atanh_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Bernoulli) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BitwiseNot) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BitwiseNot_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cast) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cast_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Ceil) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Ceil_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conj) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cos) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cos_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cosh) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cosh_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Digamma) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Digamma_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Dirichlet) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Equal) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Equal_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Erf) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Erf_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Erfinv) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Erfinv_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Exp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Exp_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Expm1) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Expm1_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Exponential_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Fetch) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Flip) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Floor) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Floor_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Imag) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Increment) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Increment_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Isinf) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(IsinfSr) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Isnan) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(IsnanSr) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lgamma) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lgamma_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Log1p) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Log1p_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Log) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Log_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalNot) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(LogicalNot_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logit) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logit_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Poisson) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pow_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Print) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(PutAlongAxis) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(PutAlongAxis_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Real) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu) 
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(Relu_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Roll) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Round) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Round_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rsqrt) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Rsqrt_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scale) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ScaleSr) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ScaleSr_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scale_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ScatterNdAdd) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Scatter_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sinh) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sinh_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Softmax) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Softmax_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tan) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tan_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tanh) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tanh_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tril) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tril_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Triu) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Triu_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Trunc) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Trunc_) + +} // namespace paddle::dialect + +namespace cinn::dialect { +using paddle::dialect::ScaleOpInferSymbolicShape; +} diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc new file mode 100644 index 0000000000000..cdbb016158b23 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -0,0 +1,1070 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_slice_utils.h" +#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_sym_utils.h" + +namespace paddle::dialect { + +bool ArgmaxOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + bool flatten = GetBoolAttr(op, "flatten"); + bool keepdims = GetBoolAttr(op, "keepdims"); + + const auto &input_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + + const auto &axis_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + int axis = + static_cast<int>(axis_shape_or_data.data().value()[0].Get<int64_t>()); + + const std::vector<symbol::DimExpr> &input_sym_shape = + input_shape_or_data.data().has_value() + ? 
input_shape_or_data.data().value() + : input_shape_or_data.shape(); + + int rank = input_sym_shape.size(); + if (axis < 0) axis += rank; + + const auto &out_sym_shape = [&] { + std::vector out_sym_shape; + if (flatten) { + if (keepdims) { + out_sym_shape.emplace_back(std::int64_t(rank)); + } else { + out_sym_shape.emplace_back(std::int64_t(0)); + } + } else { + for (int i = 0; i < axis; i++) { + out_sym_shape.emplace_back(input_sym_shape[i]); + } + if (keepdims) { + out_sym_shape.emplace_back(std::int64_t(1)); + } + + for (int i = axis + 1; i < rank; i++) { + out_sym_shape.emplace_back(input_sym_shape[i]); + } + } + return out_sym_shape; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} + +bool ArgminOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return ArgmaxOpInferSymbolicShape(op, shape_analysis); +} + +bool AsComplexOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + + const std::vector out_dims = [&] { + std::vector out_dims = operand_shape_or_data.shape(); + out_dims.pop_back(); + return out_dims; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} +bool AsRealOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + + const std::vector out_dims = [&] { + std::vector out_dims = operand_shape_or_data.shape(); + out_dims.push_back(symbol::DimExpr(2)); + return out_dims; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} + +bool CummaxOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + + shape_analysis->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); + shape_analysis->SetShapeOrDataForValue(op->result(1), operand_shape_or_data); + return true; +} +bool CumminOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return CummaxOpInferSymbolicShape(op, shape_analysis); +} +bool CumprodOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + shape_analysis->SetShapeOrDataForValue(op->result(0), operand_shape_or_data); + return true; +} +bool Cumprod_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return CumprodOpInferSymbolicShape(op, shape_analysis); +} +bool CumsumOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = 
op->operand_source(0); + + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + + bool flatten = GetBoolAttr(op, "flatten"); + if (flatten) { + symbol::DimExpr product{1}; + const auto &dim_exprs = operand_shape_or_data.shape(); + for (const auto &dim_expr : dim_exprs) { + product = product * dim_expr; + } + const std::vector out_dims = {product}; + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + + } else { + shape_analysis->SetShapeOrDataForValue(op->result(0), + operand_shape_or_data); + } + return true; +} +bool Cumsum_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return CumsumOpInferSymbolicShape(op, shape_analysis); +} + +bool DiagEmbedOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + const auto &attributes = op->attributes(); + int dim1 = attributes.at("dim1").dyn_cast().data(); + int dim2 = attributes.at("dim2").dyn_cast().data(); + int offset = attributes.at("offset").dyn_cast().data(); + + const auto &x_dims = operand_shape_or_data.shape(); + int dim1_ = dim1 < 0 ? x_dims.size() + dim1 + 1 : dim1; + int dim2_ = dim2 < 0 ? x_dims.size() + dim2 + 1 : dim2; + int64_t offset_ = static_cast(std::abs(offset)); + symbol::DimExpr new_dim_len = + symbol::DimExpr(offset_) + x_dims[x_dims.size() - 1]; + + const auto &out_dims = [&] { + std::vector out_dims = x_dims; + out_dims.pop_back(); + out_dims.insert(out_dims.begin() + std::min(dim1_, dim2_), new_dim_len); + out_dims.insert(out_dims.begin() + std::max(dim1_, dim2_), new_dim_len); + return out_dims; + }(); + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} +bool DiagonalOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + const auto &attributes = op->attributes(); + int axis1 = attributes.at("axis1").dyn_cast().data(); + int axis2 = attributes.at("axis2").dyn_cast().data(); + int offset = attributes.at("offset").dyn_cast().data(); + + const auto &x_dims = operand_shape_or_data.shape(); + int axis1_ = axis1 < 0 ? x_dims.size() + axis1 : axis1; + int axis2_ = axis2 < 0 ? x_dims.size() + axis2 : axis2; + + auto out_dims = x_dims; + auto axis1_size = out_dims[axis1_]; + auto axis2_size = out_dims[axis2_]; + out_dims.erase(out_dims.begin() + std::max(axis1_, axis2_)); + out_dims.erase(out_dims.begin() + std::min(axis1_, axis2_)); + + symbol::DimExprBuilder builder{nullptr}; + symbol::DimExpr zero{0}; + symbol::DimExpr res_shape; + symbol::DimExpr offset_sym{offset}; + if (offset == 0) { + res_shape = builder.Min(axis1_size, axis2_size); + } else if (offset > 0) { + if (axis2_size.isa()) { + res_shape = (axis2_size.dyn_cast() - offset) > 0 + ? builder.Min(axis1_size, axis2_size - offset_sym) + : zero; + } else { + res_shape = shape_analysis->GetNextSymName(); + } + } else { + if (axis1_size.isa()) { + res_shape = (axis1_size.dyn_cast() + offset) > 0 + ? 
builder.Min(axis1_size + offset_sym, axis2_size) + : zero; + } else { + res_shape = shape_analysis->GetNextSymName(); + } + } + out_dims.push_back(res_shape); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + return true; +} + +bool EinsumOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool KthvalueOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + const auto &attributes = op->attributes(); + int axis = attributes.at("axis").dyn_cast().data(); + bool keepdim = GetBoolAttr(op, "keepdim"); + + const auto &input_dims = operand_shape_or_data.shape(); + const int &dim_size = input_dims.size(); + if (axis < 0) axis += dim_size; + std::vector out_dims; + for (int i = 0; i < axis; i++) { + out_dims.emplace_back(input_dims[i]); + } + if (keepdim && dim_size > 0) { + out_dims.emplace_back(symbol::DimExpr(1)); + } + for (int i = axis + 1; i < dim_size; i++) { + out_dims.emplace_back(input_dims[i]); + } + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + shape_analysis->SetShapeOrDataForValue(op->result(1), shape_data); + return true; +} + +bool LogcumsumexpOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + // same as CumsumOpInferSymbolicShape + return CumsumOpInferSymbolicShape(op, shape_analysis); +} + +bool LogsumexpOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + bool keepdim = GetBoolAttr(op, "keepdim"); + std::vector axis = details::GetVectorAttr(op, "axis"); + bool reduce_all = axis.size() == 0 ? true : false; + return details::ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all); +} + +bool MaxOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + bool keepdim = GetBoolAttr(op, "keepdim"); + + const std::vector axis = [&] { + pir::Operation *axis_gen_op = op->operand_source(1).defining_op(); + std::vector axis_vec; + if (axis_gen_op->isa()) { + axis_vec = details::GetVectorAttr( + axis_gen_op->dyn_cast(), "value"); + } else { + // TODO(lanxianghit): there's other source: pir::VectorType, + // paddle::dialect::DenseTensorType, but after PRIM, maybe always + // FullIntArrayOp, to be confirmed + PADDLE_THROW( + phi::errors::Unimplemented("MaxOpInferSymbolicShape: 'axis' only " + "support FullIntArrayOp's result now.")); + } + return axis_vec; + }(); + + bool reduce_all = axis.size() == 0 ? 
true : false; + + return details::ReduceInferDim(op, shape_analysis, axis, keepdim, reduce_all); +} + +bool MinOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + return MaxOpInferSymbolicShape(op, shape_analysis); +} + +bool PadOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool ProdOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + bool keepdim = GetBoolAttr(op, "keep_dim"); + bool reduce_all = GetBoolAttr(op, "reduce_all"); + + auto axis_gen_op = op->operand_source(1).defining_op(); + if (axis_gen_op->isa()) { + std::vector axis = details::GetVectorAttr( + axis_gen_op->dyn_cast(), "value"); + return details::ReduceInferDim( + op, shape_analysis, axis, keepdim, reduce_all); + } else { + // TODO(lanxianghit): deal with other source: pir::VectorType, + // paddle::dialect::DenseTensorType + PADDLE_THROW( + phi::errors::Unimplemented("ProdOpInferSymbolicShape: 'axis' only " + "support FullIntArrayOp's result now.")); + } + + return true; +} + +bool RepeatInterleaveOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + + const auto &attributes = op->attributes(); + int repeats = attributes.at("repeats").dyn_cast().data(); + // what should I do if axis is null + int axis = attributes.at("axis").dyn_cast().data(); + + const std::vector &in_dims_sym = [&] { + std::vector dims; + if (operand_shape_or_data.data().has_value()) { + dims = operand_shape_or_data.data().value(); + } else { + dims = operand_shape_or_data.shape(); + } + return dims; + }(); + + int x_rank = in_dims_sym.size(); + if (axis < 0) axis += x_rank; + + const auto &out_sym_shape = [&] { + std::vector out_sym_shape; + for (int i = 0; i < x_rank; i++) { + if (i == axis) { + out_sym_shape.push_back(in_dims_sym[i] * repeats); + } else { + out_sym_shape.push_back(in_dims_sym[i]); + } + } + return out_sym_shape; + }(); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_sym_shape)}); + + return true; +} + +symbol::ShapeOrDataDimExprs CreateShapeOrDataForXShape( + const symbol::ShapeOrDataDimExprs &x_shape) { + const std::vector result = [&] { + std::vector new_x_dims; + new_x_dims.reserve(x_shape.shape().size() + 1); + new_x_dims.push_back(symbol::DimExpr{0}); + new_x_dims.insert( + new_x_dims.end(), x_shape.shape().begin(), x_shape.shape().end()); + return new_x_dims; + }(); + return symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(result)}; +} + +bool ReshapeOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + const symbol::ShapeOrDataDimExprs &x_dim_expr = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + const symbol::ShapeOrDataDimExprs &shape_dim_expr = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + if (x_dim_expr.data().has_value()) { + const auto &shape_data = details::GetExprVecFromData(shape_dim_expr); + auto IsOne = [](const symbol::DimExpr &expr) { + return expr.isa() && expr.dyn_cast() == 1; + }; + if (shape_data.size() == 1 && IsOne(shape_data.at(0))) { + 
shape_analysis->SetShapeOrDataForValue( + op->result(0), + symbol::TensorShapeOrDataDimExprs(shape_data, + x_dim_expr.data().value())); + return true; + } + } + + const auto &GetProduct = [&](const auto &dim_exprs, const auto &Filter) { + symbol::DimExpr product{1}; + for (const auto &dim_expr : dim_exprs) { + if (Filter(dim_expr)) { + product = product * dim_expr; + } + } + return product; + }; + + const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) { + if (dim_expr.isa()) { + return dim_expr.dyn_cast() != static_cast(-1); + } + return true; + }; + + const auto &IsZero = [&](const symbol::DimExpr &dim_expr) { + if (dim_expr.isa()) { + return dim_expr.dyn_cast() == static_cast(0); + } + return false; + }; + + const std::vector out_dims = [&] { + const auto &original_shape = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + + const auto &numel = + GetProduct(original_shape, [](const auto &) { return true; }); + + ExprVec target_shape = details::GetExprVecFromData(shape_dim_expr); + const auto &product_exclude_minus_one = + GetProduct(target_shape, IsNotMinusOne); + + const auto &input_dims = target_shape; + + std::vector out_dims; + out_dims.reserve(input_dims.size()); + for (size_t i = 0; i < input_dims.size(); ++i) { + auto out_dim_expr = IsNotMinusOne(input_dims[i]) + ? input_dims[i] + : (numel / product_exclude_minus_one); + out_dim_expr = IsZero(input_dims[i]) ? original_shape[i] : out_dim_expr; + out_dims.emplace_back(out_dim_expr); + } + + return out_dims; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_dims)}; + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + + const auto UNUSED &x_shape = [&] { + std::vector x_shape{symbol::DimExpr(0)}; + const auto &original_shape = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)).shape(); + for (const auto &dim : original_shape) { + x_shape.push_back(dim); + } + return x_shape; + }(); + shape_analysis->SetShapeOrDataForValue( + op->result(1), + CreateShapeOrDataForXShape( + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)))); + return true; +} + +bool Reshape_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return ReshapeOpInferSymbolicShape(op, shape_analysis); +} + +bool ShapeOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + const auto &out_data = operand_shape_or_data.shape(); + const std::vector shape{std::int64_t(out_data.size())}; + symbol::ShapeOrDataDimExprs shape_or_data{ + symbol::TensorShapeOrDataDimExprs(shape, out_data)}; + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_or_data); + return true; +} + +bool ShapeSrOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return ShapeOpInferSymbolicShape(op, shape_analysis); +} + +bool SliceOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_source = op->operand_source(0); + pir::Value operand_starts = op->operand_source(1); + pir::Value operand_ends = op->operand_source(2); + pir::Value res = op->result(0); + + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_source); + const symbol::ShapeOrDataDimExprs &starts_shape_data = + 
shape_analysis->GetShapeOrDataForValue(operand_starts); + const symbol::ShapeOrDataDimExprs &ends_shape_data = + shape_analysis->GetShapeOrDataForValue(operand_ends); + + std::vector axes_vec = details::GetVectorAttr(op, "axes"); + + // // Currently, we DO NOT support any element in `starts` is a Symbol. + ExprVec starts = slice_utils::GetExprVecFromData(starts_shape_data); + ExprVec ends = slice_utils::GetExprVecFromData(ends_shape_data); + + std::vector infer_flags = details::GetVectorAttr(op, "infer_flags"); + + const std::vector decrease_axis = + details::GetVectorAttr(op, "decrease_axis"); + + shape_analysis->SetShapeOrDataForValue( + res, + slice_utils::SliceRawInferSymbolicShape(operand_shape_or_data, + starts, + ends, + axes_vec, + infer_flags, + decrease_axis)); + + return true; +} + +bool SplitOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + // input + const auto &x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + PADDLE_ENFORCE_EQ(x_shape_or_data.data().has_value(), + false, + phi::errors::InvalidArgument( + "InferSymbolicShape of SplitOp only support input with " + "value now.")); + const auto &x_dims_sym = x_shape_or_data.shape(); + + // axis + CHECK(op->operand_source(2).defining_op()->isa()); + + int64_t axis = op->operand_source(2) + .defining_op() + .attributes() + .at("value") + .dyn_cast() + .data() + .to(); + + // sections + const std::vector §ions_sym = [&] { + const auto §ions_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + std::vector sections_sym; + if (sections_shape_or_data.data().has_value()) { + sections_sym = sections_shape_or_data.data().value(); + } else { + sections_sym = sections_shape_or_data.shape(); + } + return sections_sym; + }(); + + // output + const symbol::TensorListShapeOrDataDimExprs &output_shape_data_list = [&] { + const auto &GetSum = [&](const auto &dim_exprs, const auto &Filter) { + symbol::DimExpr sum{0}; + for (const auto &dim_expr : dim_exprs) { + if (Filter(dim_expr)) { + sum = sum + dim_expr; + } + } + return sum; + }; + const auto &All = [&](const auto &dim_exprs, const auto &Cond) { + for (const auto &dim_expr : dim_exprs) { + if (!Cond(dim_expr)) { + return false; + } + } + return true; + }; + const auto &IsNotMinusOne = [&](const symbol::DimExpr &dim_expr) { + if (dim_expr.isa()) { + return dim_expr.dyn_cast() != static_cast(-1); + } + return true; + }; + const auto &sum_exclude_minus_one = GetSum(sections_sym, IsNotMinusOne); + + const bool &all_sections_sym_not_minus_one = + All(sections_sym, IsNotMinusOne); + if (all_sections_sym_not_minus_one) { + shape_analysis->DimExprBuilder().CstrEq(x_dims_sym[axis], + sum_exclude_minus_one); + } + + symbol::TensorListShapeOrDataDimExprs shape_data_list; + std::vector output_dims_sym = x_dims_sym; + if (!all_sections_sym_not_minus_one && sections_sym.size() == 1) { + VLOG(3) << "[SplitOp]-1 is the only split section. The output shape is " + "identical to the input shape."; + shape_data_list.push_back( + symbol::TensorShapeOrDataDimExprs(output_dims_sym)); + return shape_data_list; + } + for (uint32_t idx = 0; idx < sections_sym.size(); idx++) { + const auto §ion_sym = sections_sym[idx]; + output_dims_sym[axis] = IsNotMinusOne(section_sym) + ? 
section_sym + : x_dims_sym[axis] - sum_exclude_minus_one; + + shape_data_list.push_back( + symbol::TensorShapeOrDataDimExprs(output_dims_sym)); + } + return shape_data_list; + }(); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), symbol::ShapeOrDataDimExprs{output_shape_data_list}); + + return true; +} + +bool SplitWithNumOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + int64_t axis = op->operand_source(1) + .defining_op() + .attributes() + .at("value") + .dyn_cast() + .data() + .to(); + const auto &attributes = op->attributes(); + int num = attributes.at("num").dyn_cast().data(); + const auto &x_s_or_d = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + int rank = x_s_or_d.shape().size(); + axis = axis < 0 ? axis + rank : axis; + + symbol::DimExpr input_axis_dim = x_s_or_d.shape().at(axis); + symbol::DimExpr axis_shape = input_axis_dim / symbol::DimExpr{num}; + + const auto &out_s_d = [&] { + std::vector out_s_d; + for (size_t i = 0; i < x_s_or_d.shape().size(); ++i) { + const auto &sym_dim = + axis == static_cast(i) ? axis_shape : x_s_or_d.shape()[i]; + out_s_d.push_back(sym_dim); + } + return symbol::TensorShapeOrDataDimExprs(out_s_d); + }(); + + symbol::TensorListShapeOrDataDimExprs outs_s_d(num, out_s_d); + shape_analysis->SetShapeOrDataForValue(op->result(0), + symbol::ShapeOrDataDimExprs{outs_s_d}); + return true; +} + +bool SumOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + bool keepdim = GetBoolAttr(op, "keepdim"); + bool reduce_all = false; + + auto axis_gen_op = op->operand_source(1).defining_op(); + if (axis_gen_op->isa()) { + std::vector axis = details::GetVectorAttr( + axis_gen_op->dyn_cast(), "value"); + if (axis.size() == 0) { + reduce_all = true; + } + return details::ReduceInferDim( + op, shape_analysis, axis, keepdim, reduce_all); + } else { + // TODO(lanxianghit): deal with other source: pir::VectorType, + // paddle::dialect::DenseTensorType + PADDLE_THROW( + phi::errors::Unimplemented("SumOpInferSymbolicShape: 'axis' only " + "support FullIntArrayOp's result now.")); + } + + return true; +} + +bool TileOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Value operand_x = op->operand_source(0); + symbol::ShapeOrDataDimExprs x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_x); + pir::Value operand_repeat_times = op->operand_source(1); + symbol::ShapeOrDataDimExprs repeat_times_shape_or_data = + shape_analysis->GetShapeOrDataForValue(operand_repeat_times); + + std::vector x_dimexpr; + if (x_shape_or_data.data().has_value()) { + x_dimexpr = x_shape_or_data.data().value(); + } else { + x_dimexpr = x_shape_or_data.shape(); + } + + std::vector repeat_times_dimexpr; + if (repeat_times_shape_or_data.data().has_value()) { + repeat_times_dimexpr = repeat_times_shape_or_data.data().value(); + } else { + repeat_times_dimexpr = repeat_times_shape_or_data.shape(); + } + if (repeat_times_dimexpr.empty()) { + repeat_times_dimexpr = std::vector(x_dimexpr.size(), 1); + } + + auto out_rank = std::max(static_cast(x_dimexpr.size()), + repeat_times_dimexpr.size()); + std::vector out_shape(out_rank); + if (x_dimexpr.size() > repeat_times_dimexpr.size()) { + auto diff = x_dimexpr.size() - repeat_times_dimexpr.size(); + repeat_times_dimexpr.insert(repeat_times_dimexpr.begin(), diff, 1); + } else { + auto diff = repeat_times_dimexpr.size() - x_dimexpr.size(); + x_dimexpr.insert(x_dimexpr.begin(), diff, 1); + } 
+ + for (size_t i = 0; i < repeat_times_dimexpr.size(); ++i) { + out_shape[i] = x_dimexpr[i] * repeat_times_dimexpr[i]; + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_shape)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + + return true; +} + +bool TopkOpInferSymbolicShape(pir::Operation *op, + pir::ShapeConstraintIRAnalysis *shape_analysis) { + symbol::ShapeOrDataDimExprs x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + symbol::ShapeOrDataDimExprs k_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + const auto &attributes = op->attributes(); + int axis = attributes.at("axis").dyn_cast().data(); + const std::vector &in_dims_sym = [&] { + std::vector dims; + if (x_shape_or_data.data().has_value()) { + dims = x_shape_or_data.data().value(); + } else { + dims = x_shape_or_data.shape(); + } + return dims; + }(); + + int x_rank = in_dims_sym.size(); + + int k = k_shape_or_data.data().value()[0].Get(); + + if (axis < 0) axis += x_rank; + const auto &out_sym_shape = [&] { + std::vector out_sym_shape; + for (int i = 0; i < x_rank; ++i) { + if (i == axis) { + out_sym_shape.push_back(symbol::DimExpr(k)); + } else { + out_sym_shape.push_back(in_dims_sym[i]); + } + } + return out_sym_shape; + }(); + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(out_sym_shape)}; + + shape_analysis->SetShapeOrDataForValue(op->result(0), shape_data); + shape_analysis->SetShapeOrDataForValue(op->result(1), shape_data); + + return true; +} + +bool TransposeOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + std::vector perm = + op->attributes().at("perm").dyn_cast().AsVector(); + if (perm.size() == 1) { + // perm must be [0], which means nothing to do with input, just copy the + // info from input + shape_analysis->SetShapeOrDataForValue( + op->result(0), + shape_analysis->GetShapeOrDataForValue(op->operand_source(0))); + return true; + } + const std::vector &x_dims = [&] { + std::vector dims; + const auto &x_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + if (x_shape_or_data.data().has_value()) { + dims = x_shape_or_data.data().value(); + } else { + dims = x_shape_or_data.shape(); + } + return dims; + }(); + + int x_rank = x_dims.size(); + + const std::vector formatted_axis = [x_rank, &perm] { + std::vector out(perm.size(), 0); + std::transform(perm.begin(), + perm.end(), + out.begin(), + [](pir::Attribute &p) -> int32_t { + return p.dyn_cast().data(); + }); + + // format the negative axis + std::for_each(out.begin(), out.end(), [x_rank](int32_t &v) { + if (v < 0) { + v += x_rank; + } + }); + return out; + }(); + + int axis_size = static_cast(formatted_axis.size()); + + std::vector out_dims(x_dims); + for (int i = 0; i < axis_size; ++i) { + out_dims[i] = x_dims[formatted_axis[i]]; + } + + shape_analysis->SetShapeOrDataForValue(op->result(0), + ShapeOrData{TensorExprs(out_dims)}); + + return true; +} + +bool Transpose_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return TransposeOpInferSymbolicShape(op, shape_analysis); +} + +bool SqueezeOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + IR_ENFORCE(op->num_operands() == 2, + "SqueezeOpInferSymbolicShape ONLY support num_operands() == 2 " + "now, but got %d operands", + op->num_operands()); + + auto x_shape_or_data 
= + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + auto axes_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + + std::vector in_dims_sym; + if (x_shape_or_data.data().has_value()) { + in_dims_sym = x_shape_or_data.data().value(); + } else { + in_dims_sym = x_shape_or_data.shape(); + } + + std::vector squeeze_dims_sym; + if (axes_shape_or_data.data().has_value()) { + squeeze_dims_sym = axes_shape_or_data.data().value(); + } else { + squeeze_dims_sym = axes_shape_or_data.shape(); + } + + std::vector squeeze_dims; + for (auto squeeze_dim : squeeze_dims_sym) { + IR_ENFORCE(squeeze_dim.Has(), + "in SqueezeOpInferSymbolicShape, axes must be known int type, " + "but got: %s", + symbol::ToString(squeeze_dim)); + squeeze_dims.emplace_back( + static_cast(squeeze_dim.Get())); + } + + // GetOutputSqueezeShape + size_t num_squeeze_dims = squeeze_dims.size(); + std::vector should_squeeze(in_dims_sym.size(), false); + // Mark dimensions need to be squeezed. + if (num_squeeze_dims == 0) { + for (size_t i = 0; i < in_dims_sym.size(); ++i) { + // TODO(lanxianghit): if symbol here, maybe we need the result of dim expr + // simplification + if (in_dims_sym[i] == 1) { + should_squeeze[i] = true; + } + } + } else { + for (size_t i = 0; i < num_squeeze_dims; ++i) { + if (in_dims_sym.size() == 0) { + continue; + } + int current = squeeze_dims[i] < 0 ? squeeze_dims[i] + in_dims_sym.size() + : squeeze_dims[i]; + + if (!should_squeeze[current]) { + // At compile time, dim of SYMBOL is allowed to squeeze? + if (in_dims_sym[current] == 1) { + should_squeeze[current] = true; + } else if (!in_dims_sym[current].Has()) { + should_squeeze[current] = true; + } else { + should_squeeze[current] = true; + } + } + } + } + + // Make output dimensions + std::vector output_shape_sym; + for (size_t i = 0; i < in_dims_sym.size(); ++i) { + if (!should_squeeze[i]) { + output_shape_sym.emplace_back(in_dims_sym[i]); + } + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(output_shape_sym)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + shape_analysis->SetShapeOrDataForValue( + op->result(1), CreateShapeOrDataForXShape(x_shape_or_data)); + + return true; +} +bool Squeeze_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return SqueezeOpInferSymbolicShape(op, shape_analysis); +} + +bool UnbindOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool UniqueOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool UniqueConsecutiveOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + PADDLE_THROW(phi::errors::Unimplemented( + op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); + return true; +} + +bool UnsqueezeOpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + IR_ENFORCE(op->num_operands() == 2, + "UnsqueezeOp InferSymbolicShape ONLY support num_operands() == 2 " + "now, but got %d operands", + op->num_operands()); + + auto x_shape_or_data = + 
shape_analysis->GetShapeOrDataForValue(op->operand_source(0)); + auto axes_shape_or_data = + shape_analysis->GetShapeOrDataForValue(op->operand_source(1)); + + std::vector x_sym_shape; + if (x_shape_or_data.data().has_value()) { + x_sym_shape = x_shape_or_data.data().value(); + } else { + x_sym_shape = x_shape_or_data.shape(); + } + int x_dims_size = x_sym_shape.size(); + + std::vector axes_sym; + if (axes_shape_or_data.data().has_value()) { + axes_sym = axes_shape_or_data.data().value(); + } else { + axes_sym = axes_shape_or_data.shape(); + } + int axes_sym_size = axes_sym.size(); + + // GetUnsqueezeShape + int output_rank = x_dims_size + axes_sym_size; + std::vector result_sym_dims(output_rank, 0); + + int cur_output_rank = x_dims_size; + for (auto axis_expr : axes_sym) { + IR_ENFORCE(axis_expr.Has(), + "in UnsqueezeOpInferSymbolicShape, axes must be known int type, " + "but got: %s", + symbol::ToString(axis_expr)); + int axis = static_cast(axis_expr.Get()); + int cur = axis < 0 ? axis + cur_output_rank + 1 : axis; + + // Move old axis, and insert new axis + for (int i = cur_output_rank; i >= cur; --i) { + if (result_sym_dims[i] == 1) { + // Move axis + result_sym_dims[i + 1] = 1; + result_sym_dims[i] = 0; + } + } + result_sym_dims[cur] = 1; + // Add the output size. + cur_output_rank++; + } + + // Make output shape + for (int in_idx = 0, out_idx = 0; out_idx < output_rank; ++out_idx) { + if (result_sym_dims[out_idx] == 0) { + result_sym_dims[out_idx] = x_sym_shape[in_idx++]; + } + } + + symbol::ShapeOrDataDimExprs shape_data{ + symbol::TensorShapeOrDataDimExprs(result_sym_dims)}; + + pir::Value res = op->result(0); + shape_analysis->SetShapeOrDataForValue(res, shape_data); + shape_analysis->SetShapeOrDataForValue( + op->result(1), CreateShapeOrDataForXShape(x_shape_or_data)); + + return true; +} +bool Unsqueeze_OpInferSymbolicShape( + pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + return UnsqueezeOpInferSymbolicShape(op, shape_analysis); +} + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h new file mode 100644 index 0000000000000..2b7cd2c3cf4f9 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -0,0 +1,61 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" + +namespace paddle::dialect { +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmax) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmin) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsComplex) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsReal) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cummax) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cummin) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumprod) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumprod_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumsum) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumsum_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(DiagEmbed) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Diagonal) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Einsum) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kthvalue) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Max) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Min) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Prod) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(RepeatInterleave) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Reshape_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Shape) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShapeSr) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Slice) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Split) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(SplitWithNum) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Squeeze_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sum) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Tile) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Topk) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Transpose) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Transpose_) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unbind) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unique) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(UniqueConsecutive) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unsqueeze) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Unsqueeze_) + +} // namespace paddle::dialect diff --git a/paddle/fluid/pir/dialect/operator/interface/infermeta.h b/paddle/fluid/pir/dialect/operator/interface/infermeta.h index bd6d1f7d42013..d5197af5be94f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infermeta.h +++ b/paddle/fluid/pir/dialect/operator/interface/infermeta.h @@ -25,13 +25,12 @@ class InferMetaInterface : public pir::OpInterfaceBase { struct Concept { explicit Concept(void (*infer_meta)(phi::InferMetaContext *), std::vector (*infer_meta_by_value)( - const std::vector &, - const pir::AttributeMap &)) + const std::vector &, pir::AttributeMap *)) : infer_meta_(infer_meta), infer_meta_by_value_(infer_meta_by_value) {} void (*infer_meta_)(phi::InferMetaContext *); std::vector (*infer_meta_by_value_)( - const std::vector &, const pir::AttributeMap &); + const std::vector &, pir::AttributeMap *); }; template @@ -41,8 +40,8 @@ class InferMetaInterface : public pir::OpInterfaceBase { } static inline std::vector InferMetaByValue( const std::vector &input_values, - const pir::AttributeMap &attributes) { - return ConcreteOp::InferMeta(input_values, attributes); + pir::AttributeMap *p_attributes) { + return ConcreteOp::InferMeta(input_values, p_attributes); } Model() : Concept(InferMeta, InferMetaByValue) {} }; @@ -56,8 +55,8 @@ class InferMetaInterface : public pir::OpInterfaceBase { } std::vector InferMeta(const std::vector &input_values, - const pir::AttributeMap &attributes) { - return impl_->infer_meta_by_value_(input_values, attributes); + pir::AttributeMap *p_attributes) { + return impl_->infer_meta_by_value_(input_values, p_attributes); } private: diff --git a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.cc 
b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.cc index 5469237524880..3ef55f41c264b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.cc +++ b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.cc @@ -32,6 +32,14 @@ KernelKeyTuple SaveCombineOpParseKernelKey(pir::Operation* op) { return {phi::DataType::FLOAT32, phi::Backend::UNDEFINED}; } +KernelKeyTuple NopOpParseKernelKey(pir::Operation* op) { + return {phi::DataType::FLOAT32, phi::Backend::UNDEFINED}; +} + +KernelKeyTuple Nop_OpParseKernelKey(pir::Operation* op) { + return {phi::DataType::FLOAT32, phi::Backend::UNDEFINED}; +} + } // namespace paddle::dialect IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ParseKernelKeyInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h index 7913893fdb7d7..0da0ea073486f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h +++ b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h @@ -59,6 +59,10 @@ KernelKeyTuple UniqueOpParseKernelKey(pir::Operation *op); KernelKeyTuple SaveCombineOpParseKernelKey(pir::Operation *op); +KernelKeyTuple NopOpParseKernelKey(pir::Operation *op); + +KernelKeyTuple Nop_OpParseKernelKey(pir::Operation *op); + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc index 7f490cdd24f8a..f674c35096018 100644 --- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -575,14 +575,6 @@ void WhileOp::VerifySig() { phi::errors::PreconditionNotMet( "Type validation failed for the 0th input, it should be a " "bool DenseTensorType.")); - } else if (auto cond_type = - operand_type(0).dyn_cast()) { - PADDLE_ENFORCE_EQ( - cond_type.dtype().isa(), - true, - phi::errors::PreconditionNotMet( - "Type validation failed for the 0th input, it should be a " - "bool DenseTensorType.")); } else { PADDLE_THROW(phi::errors::PreconditionNotMet( "Currently, the while op cond input only support bool dense_tensor " @@ -746,6 +738,46 @@ bool WhileOp::InferSymbolicShape( pir::InferSymExprForBlock(body(), shape_analysis); + // add constraints for args + const auto &body_args = block_args(); + for (size_t i = 0; i < body_args.size(); ++i) { + const auto &input_arg_shape = + shape_analysis->GetShapeOrDataForValue(body_args[i]).shape(); + const auto &yield_value_shape = + shape_analysis + ->GetShapeOrDataForValue(body().back().operand_source(i + 1)) + .shape(); + PADDLE_ENFORCE_EQ(input_arg_shape.size(), + yield_value_shape.size(), + phi::errors::InvalidArgument( + "while op's input[%d] rank should equal to " + "output[%d]'s rank, Now the rank of input is %d," + "the rank of output is %d.", + i, + i + 1, + input_arg_shape.size(), + yield_value_shape.size())); + const auto &original_input_shape = + shape_analysis->GetShapeOrDataForValue(operand_source(i + 1)).shape(); + for (size_t j = 0; j < input_arg_shape.size(); ++j) { + if (input_arg_shape[j].isa()) { + continue; + } + if (input_arg_shape[j] == + yield_value_shape[j]) { // Dim isn't changed in while + shape_analysis->DimExprBuilder().CstrEq(original_input_shape[j], + input_arg_shape[j]); + continue; + } + if (original_input_shape.size() == yield_value_shape.size() && + original_input_shape[j] == yield_value_shape[j]) { + shape_analysis->DimExprBuilder().CstrEq(original_input_shape[j], + 
input_arg_shape[j]); + continue; + } + } + } + const auto &last_op = body().back(); for (size_t i = 1; i < last_op.operands_source().size(); ++i) { shape_analysis->SetShapeOrDataForValue( @@ -765,11 +797,11 @@ std::vector> TuplePushOpVjpInterfaceModel::Vjp( PADDLE_ENFORCE_EQ( inputs.size() >= 1u, true, - phi::errors::InvalidArgument( - "tupe_push op's inputs' size should be greater_equal than 1, and the " - "inputs[i] should be non-empty. " - "Now the inputs's size is %d.", - inputs.size())); + phi::errors::InvalidArgument("tuple_push op's inputs' size should be " + "greater_equal than 1, and the " + "inputs[i] should be non-empty. " + "Now the inputs's size is %d.", + inputs.size())); auto pop_op = ApiBuilder::Instance().GetBuilder()->Build( TuplePushOp::dyn_cast(op).outlet()); std::vector> res{inputs.size()}; @@ -803,8 +835,7 @@ void HasElementsOp::VerifySig() { // Verify outputs: IR_ENFORCE(num_results() == 1u, "The size of outputs must be equal to 1."); - IR_ENFORCE((*this)->result_type(0).isa() || - (*this)->result_type(0).isa(), + IR_ENFORCE((*this)->result_type(0).isa(), "The type of cf.has_elements' output is not correct."); } @@ -874,8 +905,7 @@ void AssertOp::VerifySig() { (*this)->operand(1).type().dyn_cast()) { for (size_t i = 0; i < vec_type.size(); ++i) { IR_ENFORCE(vec_type[i].isa() || - vec_type[i].isa() || - vec_type[i].isa(), + vec_type[i].isa(), "Type validation failed for the 1th input."); } } else { @@ -885,7 +915,6 @@ void AssertOp::VerifySig() { ->operand(1) .type() .isa(), - (*this)->operand(1).type().isa(), "Type validation failed for the 1th input."); } } @@ -999,19 +1028,20 @@ bool SelectInputOp::InferSymbolicShape( const auto &input1_dims = GetSymExprForValue(operand_source(0)); const auto &input2_dims = GetSymExprForValue(operand_source(1)); + // for compatibility, we just return second_shape. + if (input1_dims.size() != input2_dims.size()) { + shape_analysis->SetShapeOrDataForValue( + result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(input2_dims)}); + return true; + } + std::vector out_dims = input1_dims; // merge shape for input1 and input2, since we don't know which will be // selected in compile time, the strategy is same with IfOp, see IfOp's // comments for details and examples if (input2_dims.size() != 0) { - // now only support input1 and input2 have same rank. 
- PADDLE_ENFORCE_EQ(input1_dims.size(), - input2_dims.size(), - phi::errors::PreconditionNotMet( - "The true and false block should have same rank, " - "but got true_rank(%d) and false_rank(%d)", - input1_dims.size(), - input2_dims.size())); for (size_t i = 0; i < input1_dims.size(); i++) { if (input1_dims[i] != input2_dims[i]) { out_dims[i] = symbol::DimExpr{shape_analysis->GetNextSymName()}; diff --git a/paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h b/paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h index 37000c86b5b65..856ddb2f7542c 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h +++ b/paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h @@ -14,6 +14,8 @@ #pragma once +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" @@ -87,5 +89,14 @@ class IrSelectedRows size_t offset_{0}; }; +inline SelectedRowsType CvtToSelectedRowsType(const IrSelectedRows& ir_tensor) { + return SelectedRowsType::get(pir::IrContext::Instance(), + TransToIrDataType(ir_tensor.dtype()), + ir_tensor.dims(), + ir_tensor.layout(), + ir_tensor.lod(), + ir_tensor.offset()); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/ir_tensor.h b/paddle/fluid/pir/dialect/operator/ir/ir_tensor.h index e2c3229b04df0..45847d3080387 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ir_tensor.h +++ b/paddle/fluid/pir/dialect/operator/ir/ir_tensor.h @@ -14,9 +14,11 @@ #pragma once +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/phi/core/allocator.h" #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/pir/include/core/builtin_type.h" namespace paddle { namespace dialect { @@ -81,10 +83,19 @@ class IrTensor : public phi::TensorBase, private: phi::DDim dims_; phi::DataType dtype_{phi::DataType::FLOAT32}; - phi::DataLayout layout_{phi::DataLayout::ANY}; + phi::DataLayout layout_{phi::DataLayout::NCHW}; LoD lod_; size_t offset_{0}; }; +inline pir::DenseTensorType CvtToDenseTensorType(const IrTensor& ir_tensor) { + return pir::DenseTensorType::get(pir::IrContext::Instance(), + TransToIrDataType(ir_tensor.dtype()), + ir_tensor.dims(), + ir_tensor.layout(), + ir_tensor.lod(), + ir_tensor.offset()); +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc index 3dedf0b14da3f..9228c85c13011 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_api.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_api.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "paddle/fluid/pir/dialect/operator/ir/manual_api.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_api.h" @@ -63,8 +64,17 @@ void set_parameter(const pir::Value& parameter, const std::string& name) { } void shadow_output(const pir::Value& persist_value, const std::string& name) { - ApiBuilder::Instance().GetBuilder()->Build(persist_value, - name); + auto& builder = ApiBuilder::Instance().GetBuilder(); + auto op = builder->Build(persist_value, name); + if (auto dist_interface = + persist_value.type().dyn_cast()) { + op->set_attribute( + kAttrOpDistAttr, + OperationDistAttribute::get(builder->ir_context(), + dist_interface.process_mesh_attr(), + {dist_interface.tensor_dist_attr()}, + {})); + } } pir::Value embedding_grad(const pir::Value& x, diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc index 352677f0047c8..4e4b7f46b382c 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.cc @@ -18,7 +18,6 @@ paddle::onednn::dialect::ExpandOp #include "paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" -#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/paddle_op_infer_sym.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_meta_tensor.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h" #include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h" @@ -114,7 +113,7 @@ void ExpandOp::Build(pir::Builder& builder, argument_attributes.insert({"mkldnn_data_type", attr_mkldnn_data_type}); std::vector argument_outputs = - ExpandOp::InferMeta(argument_inputs, argument_attributes); + ExpandOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); } @@ -157,7 +156,7 @@ void ExpandOp::Build(pir::Builder& builder, argument_attributes.insert({"mkldnn_data_type", attr_mkldnn_data_type}); std::vector argument_outputs = - ExpandOp::InferMeta(argument_inputs, argument_attributes); + ExpandOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); } @@ -181,7 +180,7 @@ void ExpandOp::Build(pir::Builder& builder, argument_attributes.insert({"mkldnn_data_type", attr_mkldnn_data_type}); std::vector argument_outputs = - ExpandOp::InferMeta(argument_inputs, argument_attributes); + ExpandOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); } @@ -244,7 +243,11 @@ void ExpandOp::InferMeta(phi::InferMetaContext* infer_meta) { std::vector ExpandOp::InferMeta( const std::vector& input_values, - const pir::AttributeMap& attributes) { + pir::AttributeMap* p_attributes) { + PADDLE_ENFORCE_NOT_NULL( + p_attributes, + common::errors::Fatal( + "AttrtibueMap pointer in InferMeta function is nullptr.")); IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", input_values.size()); @@ -256,15 +259,6 @@ std::vector ExpandOp::InferMeta( paddle::dialect::DenseTensorType x; if (x_.type().isa()) { x = x_.type().dyn_cast(); - } else 
if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_x = - x_.type().dyn_cast(); - x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_x.dtype(), - allocated_x.dims(), - allocated_x.data_layout(), - allocated_x.lod(), - allocated_x.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -273,22 +267,22 @@ std::vector ExpandOp::InferMeta( phi::IntArray shape; if (shape_.defining_op()->isa()) { - shape = std::move(phi::IntArray(paddle::dialect::GetInt64Vector( + shape = phi::IntArray(paddle::dialect::GetInt64Vector( shape_.defining_op() ->dyn_cast() - .attribute("value")))); + .attribute("value"))); } else if (shape_.type().isa()) { size_t shape_size = shape_.type().dyn_cast().size(); // In ExpandInferMeta use -2 to represent the element in expand_shape is a // var. - shape = std::move(phi::IntArray(std::vector(shape_size, -2))); + shape = phi::IntArray(std::vector(shape_size, -2)); shape.SetFromTensor(true); } else if (shape_.type().isa()) { size_t shape_size = common::product( shape_.type().dyn_cast().dims()); // In ExpandInferMeta use -2 to represent the element in expand_shape is a // var. - shape = std::move(phi::IntArray(std::vector(shape_size, -2))); + shape = phi::IntArray(std::vector(shape_size, -2)); shape.SetFromTensor(true); } else { PADDLE_THROW(phi::errors::Unimplemented( @@ -334,8 +328,9 @@ phi::DataType ExpandOp::GetKernelTypeForVar( bool ExpandOp::InferSymbolicShape( pir::ShapeConstraintIRAnalysis* shape_analysis) { VLOG(4) << "Infer symbolic shape for op: ExpandOp"; - return paddle::dialect::ExpandOpInferSymbolicShape(this->operation(), - shape_analysis); + PADDLE_THROW(phi::errors::Unimplemented( + " ExpandOp's InferSymbolicShape interface is NOT implemented now.")); + return true; } } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h index 3c8050480ade9..58f15f5582e65 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_onednn_op.h @@ -84,7 +84,7 @@ class ExpandOp : public pir::Op InferMeta( const std::vector& input_values, - const pir::AttributeMap& attributes); + pir::AttributeMap* p_attributes); // NOLINT }; } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 1f645b0a29d66..c5dc4457b737e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -13,8 +13,7 @@ // limitations under the License. 
#ifdef GET_OP_LIST #undef GET_OP_LIST -paddle::dialect::AddNOp, paddle::dialect::AddN_Op, - paddle::dialect::AddNWithKernelOp, paddle::dialect::AddNArrayOp, +paddle::dialect::AddNOp, paddle::dialect::AddN_Op, paddle::dialect::AddNArrayOp, paddle::dialect::FusedGemmEpilogueOp, paddle::dialect::AssignOut_Op, paddle::dialect::FusedGemmEpilogueGradOp, paddle::dialect::SplitGradOp, paddle::dialect::ExpandOp, paddle::dialect::CreateArrayOp, @@ -134,7 +133,7 @@ void AddNOp::Build(pir::Builder &builder, // NOLINT VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - AddNOp::InferMeta(argument_inputs, argument_attributes); + AddNOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -147,7 +146,7 @@ void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AddNOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta AddNOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -167,16 +166,6 @@ std::vector AddNOp::InferMeta( x[i].dyn_cast().data_layout(), x[i].dyn_cast().lod(), x[i].dyn_cast().offset())); - } else if (x[i].isa()) { - vec_dense_x.push_back(paddle::dialect::IrTensor( - TransToPhiDataType( - x[i].dyn_cast() - .dtype()), - x[i].dyn_cast().dims(), - x[i].dyn_cast() - .data_layout(), - x[i].dyn_cast().lod(), - x[i].dyn_cast().offset())); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -196,7 +185,7 @@ std::vector AddNOp::InferMeta( paddle::dialect::IrTensor dense_out; paddle::dialect::IrMetaTensor meta_out(&dense_out); - phi::AddNInferMeta(meta_x, &meta_out); + phi::AddNInferMeta(meta_x, &meta_out, phi::MetaConfig(false, false)); std::vector argument_outputs; pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( @@ -240,7 +229,7 @@ void AddN_Op::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - AddN_Op::InferMeta(argument_inputs, argument_attributes); + AddN_Op::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } @@ -303,7 +292,7 @@ void AddN_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AddN_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta AddN_Op"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -322,22 +311,6 @@ std::vector AddN_Op::InferMeta( inputs[i].dyn_cast().data_layout(), inputs[i].dyn_cast().lod(), inputs[i].dyn_cast().offset())); - } else if (inputs[i].isa()) { - vec_dense_inputs.push_back(paddle::dialect::IrTensor( - TransToPhiDataType( - inputs[i] - .dyn_cast() - .dtype()), - inputs[i] - .dyn_cast() - .dims(), - inputs[i] - .dyn_cast() - .data_layout(), - inputs[i].dyn_cast().lod(), - inputs[i] - .dyn_cast() - .offset())); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -358,197 +331,7 @@ std::vector AddN_Op::InferMeta( paddle::dialect::IrTensor dense_out; paddle::dialect::IrMetaTensor meta_out(&dense_out); - phi::AddNInferMeta(meta_inputs, &meta_out); - - std::vector argument_outputs; 
- pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( - pir::IrContext::Instance(), - paddle::dialect::TransToIrDataType(dense_out.dtype()), - dense_out.dims(), - dense_out.layout(), - dense_out.lod(), - dense_out.offset()); - argument_outputs.push_back(out_dense_tensor_type); - return argument_outputs; -} - -OpInfoTuple AddNWithKernelOp::GetOpInfo() { - std::vector inputs = { - paddle::dialect::OpInputInfo( - "inputs", - "pir::VectorType", - false, - false, - false, - true)}; - std::vector attributes = {}; - std::vector outputs = { - paddle::dialect::OpOutputInfo( - "out", "paddle::dialect::DenseTensorType", false, false)}; - paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo( - "AddNInferMeta", {"inputs"}, "add_n", {"inputs"}, {}, {}, {}, {}); - return std::make_tuple( - inputs, attributes, outputs, run_time_info, "add_n_with_kernel"); -} - -void AddNWithKernelOp::Build(pir::Builder &builder, - pir::OperationArgument &argument, - pir::Value inputs_) { - VLOG(4) << "Start build AddNWithKernelOp"; - - VLOG(4) << "Builder construction inputs"; - std::vector argument_inputs = {inputs_}; - argument.AddInput(inputs_); - - VLOG(4) << "Builder construction attributes"; - pir::AttributeMap argument_attributes = {}; - std::vector argument_outputs = - AddNWithKernelOp::InferMeta(argument_inputs, argument_attributes); - - argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); -} - -void AddNWithKernelOp::VerifySig() { - VLOG(4) << "Start Verifying inputs, outputs and attributes for: " - "AddNWithKernelOp."; - VLOG(4) << "Verifying inputs:"; - { - auto input_size = num_operands(); - PADDLE_ENFORCE_EQ( - input_size, - 1u, - phi::errors::PreconditionNotMet( - "The size %d of inputs must be equal to 1.", input_size)); - if (auto vec_type = - (*this)->operand_source(0).type().dyn_cast()) { - for (size_t i = 0; i < vec_type.size(); ++i) { - PADDLE_ENFORCE(vec_type[i].isa() || - vec_type[i].isa(), - phi::errors::PreconditionNotMet( - "Type validation failed for the 0th input.")); - } - } else { - PADDLE_ENFORCE((*this)->operand_source(0) - .type() - .isa() || - (*this) - ->operand_source(0) - .type() - .isa(), - phi::errors::PreconditionNotMet( - "Type validation failed for the 0th input.")); - } - } - VLOG(4) << "Verifying attributes:"; - { - // Attributes num is 0, not need to check attributes type. 
- } - VLOG(4) << "Verifying outputs:"; - { - auto output_size = num_results(); - PADDLE_ENFORCE_EQ( - output_size, - 1u, - phi::errors::PreconditionNotMet( - "The size %d of outputs must be equal to 1.", output_size)); - PADDLE_ENFORCE( - (*this)->result(0).type().isa() || - (*this)->result(0).type().isa(), - phi::errors::PreconditionNotMet( - "Type validation failed for the 0th output.")); - } - VLOG(4) << "End Verifying for: AddNWithKernelOp."; -} - -void AddNWithKernelOp::InferMeta(phi::InferMetaContext *infer_meta) { - auto fn = PD_INFER_META(phi::AddNInferMeta); - fn(infer_meta); -} - -std::vector AddNWithKernelOp::InferMeta( - const std::vector &input_values, - const pir::AttributeMap &attributes) { - VLOG(4) << "Start infermeta AddNWithKernelOp"; - IR_ENFORCE(input_values.size() == 1, - "Num of inputs is expected to be 1 but got %d.", - input_values.size()); - pir::Value inputs_ = input_values[0]; - - VLOG(4) << "Builder construction outputs"; - pir::VectorType inputs = inputs_.type().dyn_cast(); - std::vector vec_dense_inputs; - for (size_t i = 0; i < static_cast(inputs.size()); i++) { - if (inputs[i].isa()) { - vec_dense_inputs.push_back(paddle::dialect::IrTensor( - paddle::dialect::TransToPhiDataType( - inputs[i].dyn_cast().dtype()), - inputs[i].dyn_cast().dims(), - inputs[i].dyn_cast().data_layout(), - inputs[i].dyn_cast().lod(), - inputs[i].dyn_cast().offset())); - } else if (inputs[i].isa()) { - vec_dense_inputs.push_back(paddle::dialect::IrTensor( - TransToPhiDataType( - inputs[i] - .dyn_cast() - .dtype()), - inputs[i] - .dyn_cast() - .dims(), - inputs[i] - .dyn_cast() - .data_layout(), - inputs[i].dyn_cast().lod(), - inputs[i] - .dyn_cast() - .offset())); - } else if (inputs[i].isa()) { - vec_dense_inputs.push_back(paddle::dialect::IrTensor( - paddle::dialect::TransToPhiDataType( - inputs[i].dyn_cast().dtype()), - inputs[i].dyn_cast().dims(), - inputs[i].dyn_cast().data_layout(), - inputs[i].dyn_cast().lod(), - inputs[i].dyn_cast().offset())); - } else if (inputs[i].isa()) { - vec_dense_inputs.push_back(paddle::dialect::IrTensor( - TransToPhiDataType( - inputs[i] - .dyn_cast() - .dtype()), - inputs[i] - .dyn_cast() - .dims(), - inputs[i] - .dyn_cast() - .data_layout(), - inputs[i] - .dyn_cast() - .lod(), - inputs[i] - .dyn_cast() - .offset())); - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Only support DenseTensorType or AllocatedDenseTensorType or " - "SelectedRowsType or AllocatedSelectedRowsType")); - } - } - - std::vector vec_meta_inputs; - for (size_t i = 0; i < vec_dense_inputs.size(); i++) { - vec_meta_inputs.push_back( - paddle::dialect::IrMetaTensor(&vec_dense_inputs[i])); - } - - std::vector meta_inputs; - for (size_t i = 0; i < static_cast(vec_meta_inputs.size()); i++) { - meta_inputs.push_back(&vec_meta_inputs[i]); - } - paddle::dialect::IrTensor dense_out; - paddle::dialect::IrMetaTensor meta_out(&dense_out); - - phi::AddNInferMeta(meta_inputs, &meta_out); + phi::AddNInferMeta(meta_inputs, &meta_out, phi::MetaConfig(false, false)); std::vector argument_outputs; pir::Type out_dense_tensor_type = paddle::dialect::DenseTensorType::get( @@ -645,9 +428,10 @@ void AddNArrayOp::Build(pir::Builder &builder, // NOLINT VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - AddNArrayOp::InferMeta(argument_inputs, argument_attributes); + AddNArrayOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + 
argument.AddAttributes(argument_attributes); ::pir::PassStopGradientsDefaultly(argument); } @@ -658,7 +442,7 @@ void AddNArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AddNArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta AddNArrayOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -680,18 +464,6 @@ std::vector AddNArrayOp::InferMeta( .dyn_cast() .data_layout(), {})); - } else if (inputs[i] - .isa()) { - vec_dense_inputs.push_back(paddle::dialect::IrTensor( - TransToPhiDataType( - inputs[i] - .dyn_cast() - .dtype()), - inputs[i].dyn_cast().dims(), - inputs[i] - .dyn_cast() - .data_layout(), - {})); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -726,8 +498,10 @@ std::vector AddNArrayOp::InferMeta( return argument_outputs; } -const char *FusedGemmEpilogueOp::attributes_name[3] = { - "trans_x", "trans_y", "activation"}; +const char *FusedGemmEpilogueOp::attributes_name[3] = { // NOLINT + "trans_x", + "trans_y", + "activation"}; OpInfoTuple FusedGemmEpilogueOp::GetOpInfo() { std::vector inputs = { @@ -810,7 +584,7 @@ void FusedGemmEpilogueOp::Build(pir::Builder &builder, argument.AddAttribute("activation", attr_activation); argument_attributes.insert({"activation", attr_activation}); std::vector argument_outputs = - FusedGemmEpilogueOp::InferMeta(argument_inputs, argument_attributes); + FusedGemmEpilogueOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } @@ -889,7 +663,12 @@ void FusedGemmEpilogueOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector FusedGemmEpilogueOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { + PADDLE_ENFORCE_NOT_NULL( + p_attributes, + common::errors::Fatal( + "AttrtibueMap pointer in InferMeta function is nullptr.")); + auto &attributes = *p_attributes; VLOG(4) << "Start infermeta FusedGemmEpilogueOp"; IR_ENFORCE(input_values.size() == 3, "Num of inputs is expected to be 3 but got %d.", @@ -921,15 +700,6 @@ std::vector FusedGemmEpilogueOp::InferMeta( paddle::dialect::DenseTensorType x; if (x_.type().isa()) { x = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_x = - x_.type().dyn_cast(); - x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_x.dtype(), - allocated_x.dims(), - allocated_x.data_layout(), - allocated_x.lod(), - allocated_x.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -939,15 +709,6 @@ std::vector FusedGemmEpilogueOp::InferMeta( paddle::dialect::DenseTensorType y; if (y_.type().isa()) { y = y_.type().dyn_cast(); - } else if (y_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_y = - y_.type().dyn_cast(); - y = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_y.dtype(), - allocated_y.dims(), - allocated_y.data_layout(), - allocated_y.lod(), - allocated_y.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -957,15 +718,6 @@ std::vector FusedGemmEpilogueOp::InferMeta( paddle::dialect::DenseTensorType bias; if (bias_.type().isa()) { bias = bias_.type().dyn_cast(); - } else if (bias_.type().isa()) { - 
paddle::dialect::AllocatedDenseTensorType allocated_bias = - bias_.type().dyn_cast(); - bias = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_bias.dtype(), - allocated_bias.dims(), - allocated_bias.data_layout(), - allocated_bias.lod(), - allocated_bias.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -1040,8 +792,10 @@ std::vector FusedGemmEpilogueOp::InferMeta( return argument_outputs; } -const char *FusedGemmEpilogueGradOp::attributes_name[3] = { - "trans_x", "trans_y", "activation_grad"}; +const char *FusedGemmEpilogueGradOp::attributes_name[3] = { // NOLINT + "trans_x", + "trans_y", + "activation_grad"}; OpInfoTuple FusedGemmEpilogueGradOp::GetOpInfo() { std::vector inputs = { @@ -1145,7 +899,7 @@ void FusedGemmEpilogueGradOp::Build(pir::Builder &builder, argument.AddAttribute("activation_grad", attr_activation_grad); argument_attributes.insert({"activation_grad", attr_activation_grad}); std::vector argument_outputs = - FusedGemmEpilogueGradOp::InferMeta(argument_inputs, argument_attributes); + FusedGemmEpilogueGradOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } @@ -1159,7 +913,12 @@ void FusedGemmEpilogueGradOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector FusedGemmEpilogueGradOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { + PADDLE_ENFORCE_NOT_NULL( + p_attributes, + common::errors::Fatal( + "AttrtibueMap pointer in InferMeta function is nullptr.")); + auto &attributes = *p_attributes; IR_ENFORCE(input_values.size() == 4, "Num of inputs is expected to be 4 but got %d.", input_values.size()); @@ -1193,15 +952,6 @@ std::vector FusedGemmEpilogueGradOp::InferMeta( paddle::dialect::DenseTensorType x; if (x_.type().isa()) { x = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_x = - x_.type().dyn_cast(); - x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_x.dtype(), - allocated_x.dims(), - allocated_x.data_layout(), - allocated_x.lod(), - allocated_x.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -1211,15 +961,6 @@ std::vector FusedGemmEpilogueGradOp::InferMeta( paddle::dialect::DenseTensorType y; if (y_.type().isa()) { y = y_.type().dyn_cast(); - } else if (y_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_y = - y_.type().dyn_cast(); - y = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_y.dtype(), - allocated_y.dims(), - allocated_y.data_layout(), - allocated_y.lod(), - allocated_y.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -1231,18 +972,6 @@ std::vector FusedGemmEpilogueGradOp::InferMeta( if (reserve_space_.type().isa()) { reserve_space = reserve_space_.type().dyn_cast(); - } else if (reserve_space_.type() - .isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_reserve_space = - reserve_space_.type() - .dyn_cast(); - reserve_space = paddle::dialect::DenseTensorType::get( - pir::IrContext::Instance(), - allocated_reserve_space.dtype(), - allocated_reserve_space.dims(), - allocated_reserve_space.data_layout(), - allocated_reserve_space.lod(), - allocated_reserve_space.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only 
support paddle::dialect::DenseTensorType or " @@ -1255,17 +984,6 @@ std::vector FusedGemmEpilogueGradOp::InferMeta( paddle::dialect::DenseTensorType out_grad; if (out_grad_.type().isa()) { out_grad = out_grad_.type().dyn_cast(); - } else if (out_grad_.type() - .isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_out_grad = - out_grad_.type().dyn_cast(); - out_grad = - paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_out_grad.dtype(), - allocated_out_grad.dims(), - allocated_out_grad.data_layout(), - allocated_out_grad.lod(), - allocated_out_grad.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -1362,7 +1080,7 @@ std::vector FusedGemmEpilogueGradOp::InferMeta( return argument_outputs; } -const char *SplitGradOp::attributes_name[1] = {"axis"}; +const char *SplitGradOp::attributes_name[1] = {"axis"}; // NOLINT OpInfoTuple SplitGradOp::GetOpInfo() { std::vector inputs = { @@ -1413,7 +1131,7 @@ void SplitGradOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - SplitGradOp::InferMeta(argument_inputs, argument_attributes); + SplitGradOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -1432,7 +1150,7 @@ void SplitGradOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - SplitGradOp::InferMeta(argument_inputs, argument_attributes); + SplitGradOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -1497,7 +1215,7 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector SplitGradOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta SplitGradOp"; IR_ENFORCE(input_values.size() == 2, @@ -1551,7 +1269,7 @@ std::vector SplitGradOp::InferMeta( return argument_outputs; } -const char *CreateArrayOp::attributes_name[1] = {"dtype"}; +const char *CreateArrayOp::attributes_name[1] = {"dtype"}; // NOLINT OpInfoTuple CreateArrayOp::GetOpInfo() { std::vector inputs = {}; @@ -1590,7 +1308,7 @@ void CreateArrayOp::Build(pir::Builder &builder, argument.AddAttribute("dtype", attr_dtype); argument_attributes.insert({"dtype", attr_dtype}); std::vector argument_outputs = - CreateArrayOp::InferMeta(argument_inputs, argument_attributes); + CreateArrayOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -1636,7 +1354,12 @@ void CreateArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector CreateArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { + PADDLE_ENFORCE_NOT_NULL( + p_attributes, + common::errors::Fatal( + "AttrtibueMap pointer in InferMeta function is nullptr.")); + auto &attributes = *p_attributes; VLOG(4) << "Start infermeta CreateArrayOp"; PADDLE_ENFORCE( @@ -1708,7 +1431,7 @@ void CreateArrayLikeOp::Build(pir::Builder &builder, // NOLINT argument.AddAttribute("val", attr_val); argument_attributes.insert({"val", attr_val}); std::vector 
argument_outputs = - CreateArrayLikeOp::InferMeta(argument_inputs, argument_attributes); + CreateArrayLikeOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -1754,7 +1477,7 @@ void CreateArrayLikeOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector CreateArrayLikeOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta CreateArrayLikeOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -1766,16 +1489,6 @@ std::vector CreateArrayLikeOp::InferMeta( if (input_.type().isa()) { input_type = input_.type().dyn_cast(); - } else if (input_.type() - .isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - input_.type() - .dyn_cast(); - input_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -1837,7 +1550,7 @@ void ArrayLengthOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - ArrayLengthOp::InferMeta(argument_inputs, argument_attributes); + ArrayLengthOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); } @@ -1885,7 +1598,7 @@ void ArrayLengthOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayLengthOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta ArrayLengthOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -1895,14 +1608,6 @@ std::vector ArrayLengthOp::InferMeta( paddle::dialect::DenseTensorArrayType x_type; if (x_.type().isa()) { x_type = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - x_.type().dyn_cast(); - x_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -1977,7 +1682,7 @@ void ArrayReadOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - ArrayReadOp::InferMeta(argument_inputs, argument_attributes); + ArrayReadOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -1994,7 +1699,7 @@ void ArrayReadOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - ArrayReadOp::InferMeta(argument_inputs, argument_attributes); + ArrayReadOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -2049,7 +1754,7 @@ void ArrayReadOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayReadOp::InferMeta( const std::vector 
&input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta ArrayLengthOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -2062,16 +1767,6 @@ std::vector ArrayReadOp::InferMeta( if (array_.type().isa()) { array_type = array_.type().dyn_cast(); - } else if (array_.type() - .isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - array_.type() - .dyn_cast(); - array_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -2087,15 +1782,14 @@ std::vector ArrayReadOp::InferMeta( phi::Scalar i_scalar; if (i_.isa() && i_.defining_op()->isa()) { - i_scalar = - std::move(phi::Scalar(i_.defining_op() - ->dyn_cast() - .attribute("value") - .dyn_cast() - .data() - .to())); + i_scalar = phi::Scalar(i_.defining_op() + ->dyn_cast() + .attribute("value") + .dyn_cast() + .data() + .to()); } else { - i_scalar = std::move(phi::Scalar(-1)); + i_scalar = phi::Scalar(-1); i_scalar.SetFromTensor(true); } @@ -2160,7 +1854,7 @@ void ArrayWrite_Op::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - ArrayWrite_Op::InferMeta(argument_inputs, argument_attributes); + ArrayWrite_Op::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); constexpr char kStopGradientAttrName[] = "stop_gradient"; @@ -2228,7 +1922,7 @@ void ArrayWrite_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayWrite_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta ArrayWrite_Op"; IR_ENFORCE(input_values.size() == 3, "Num of inputs is expected to be 3 but got %d.", @@ -2241,16 +1935,6 @@ std::vector ArrayWrite_Op::InferMeta( if (array_.type().isa()) { array_type = array_.type().dyn_cast(); - } else if (array_.type() - .isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - array_.type() - .dyn_cast(); - array_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -2268,17 +1952,6 @@ std::vector ArrayWrite_Op::InferMeta( phi::Place place = phi::CPUPlace(); if (x_.type().isa()) { x_type = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_input = - x_.type().dyn_cast(); - place = allocated_input.place(), - x_type = - paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout(), - allocated_input.lod(), - allocated_input.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -2306,20 +1979,19 @@ std::vector ArrayWrite_Op::InferMeta( dense_array.layout()); // update array's dims as x's dims. 
// TOOD(chenxi67) Do not change if dim is set by custom - if (array_.type().isa()) { - array_.set_type( - paddle::dialect::DenseTensorArrayType::get(pir::IrContext::Instance(), - array_type.dtype(), - x_type.dims(), - array_type.data_layout())); - } else if (array_.type() - .isa()) { + if (array_.type().isa()) { array_.set_type(paddle::dialect::AllocatedDenseTensorArrayType::get( pir::IrContext::Instance(), place, array_type.dtype(), x_type.dims(), array_type.data_layout())); + } else if (array_.type().isa()) { + array_.set_type( + paddle::dialect::DenseTensorArrayType::get(pir::IrContext::Instance(), + array_type.dtype(), + x_type.dims(), + array_type.data_layout())); } argument_outputs.push_back(out_type); @@ -2381,7 +2053,7 @@ void ArrayToTensorOp::Build(pir::Builder &builder, // NOLINT argument.AddAttribute("use_stack", attr_use_stack); argument_attributes.insert({"use_stack", attr_use_stack}); std::vector argument_outputs = - ArrayToTensorOp::InferMeta(argument_inputs, argument_attributes); + ArrayToTensorOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -2442,7 +2114,12 @@ void ArrayToTensorOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayToTensorOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { + PADDLE_ENFORCE_NOT_NULL( + p_attributes, + common::errors::Fatal( + "AttrtibueMap pointer in InferMeta function is nullptr.")); + auto &attributes = *p_attributes; VLOG(4) << "Start infermeta ArrayToTensorOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -2462,14 +2139,6 @@ std::vector ArrayToTensorOp::InferMeta( paddle::dialect::DenseTensorArrayType x_type; if (x_.type().isa()) { x_type = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - x_.type().dyn_cast(); - x_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -2576,7 +2245,7 @@ void TensorToArrayOp::Build(pir::Builder &builder, // NOLINT argument.AddAttribute("use_stack", attr_use_stack); argument_attributes.insert({"use_stack", attr_use_stack}); std::vector argument_outputs = - TensorToArrayOp::InferMeta(argument_inputs, argument_attributes); + TensorToArrayOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -2639,7 +2308,12 @@ void TensorToArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector TensorToArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { + PADDLE_ENFORCE_NOT_NULL( + p_attributes, + common::errors::Fatal( + "AttrtibueMap pointer in InferMeta function is nullptr.")); + auto &attributes = *p_attributes; VLOG(4) << "Start infermeta TensorToArrayOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -2664,14 +2338,6 @@ std::vector TensorToArrayOp::InferMeta( if (x_.type().isa()) { x = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - 
x_.type().dyn_cast(); - x = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -2687,17 +2353,6 @@ std::vector TensorToArrayOp::InferMeta( paddle::dialect::DenseTensorType out_grad; if (out_grad_.type().isa()) { out_grad = out_grad_.type().dyn_cast(); - } else if (out_grad_.type() - .isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_input = - out_grad_.type().dyn_cast(); - out_grad = - paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout(), - allocated_input.lod(), - allocated_input.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -2815,16 +2470,15 @@ void SliceArrayOp::VerifySig() { phi::IntArray CalcSliceBoundsFromValue(pir::Value starts_or_ends) { phi::IntArray starts_or_ends_list; if (starts_or_ends.defining_op()->isa()) { - starts_or_ends_list = - std::move(phi::IntArray(paddle::dialect::GetInt64Vector( - starts_or_ends.defining_op() - ->dyn_cast() - .attribute("value")))); + starts_or_ends_list = phi::IntArray(paddle::dialect::GetInt64Vector( + starts_or_ends.defining_op() + ->dyn_cast() + .attribute("value"))); } else if (starts_or_ends.type().isa()) { size_t starts_or_ends_size = starts_or_ends.type().dyn_cast().size(); starts_or_ends_list = - std::move(phi::IntArray(std::vector(starts_or_ends_size, -1))); + phi::IntArray(std::vector(starts_or_ends_size, -1)); starts_or_ends_list.SetFromTensor(true); } else if (starts_or_ends.type().isa()) { common::DDim starts_or_ends_dim = @@ -2836,20 +2490,7 @@ phi::IntArray CalcSliceBoundsFromValue(pir::Value starts_or_ends) { starts_or_ends_size = 1; } starts_or_ends_list = - std::move(phi::IntArray(std::vector(starts_or_ends_size, -1))); - starts_or_ends_list.SetFromTensor(true); - } else if (starts_or_ends.type() - .isa()) { - common::DDim starts_or_ends_dim = - starts_or_ends.type() - .dyn_cast() - .dims(); - size_t starts_or_ends_size = common::product(starts_or_ends_dim); - if (common::contain_unknown_dim(starts_or_ends_dim)) { - starts_or_ends_size = 1; - } - starts_or_ends_list = - std::move(phi::IntArray(std::vector(starts_or_ends_size, -1))); + phi::IntArray(std::vector(starts_or_ends_size, -1)); starts_or_ends_list.SetFromTensor(true); } else { PADDLE_THROW( @@ -2872,7 +2513,7 @@ void SliceArrayOp::Build(pir::Builder &builder, // NOLINT pir::AttributeMap argument_attributes = {}; VLOG(4) << "Builder construction outputs"; std::vector argument_outputs = - SliceArrayOp::InferMeta(argument_inputs, argument_attributes); + SliceArrayOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); } @@ -2884,7 +2525,7 @@ void SliceArrayOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector SliceArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta SliceArrayOp"; IR_ENFORCE(input_values.size() == 3, "Num of inputs is expected to be 3 but got %d.", @@ -2897,15 +2538,6 @@ std::vector SliceArrayOp::InferMeta( paddle::dialect::DenseTensorArrayType input_type; if (input.type().isa()) { input_type = input.type().dyn_cast(); - } else if 
(input.type() - .isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - input.type().dyn_cast(); - input_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::AllocatedDenseTensorArrayType or " @@ -3031,7 +2663,7 @@ void SliceArrayDenseOp::Build(pir::Builder &builder, // NOLINT VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - SliceArrayDenseOp::InferMeta(argument_inputs, argument_attributes); + SliceArrayDenseOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -3044,7 +2676,7 @@ void SliceArrayDenseOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector SliceArrayDenseOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta SliceArrayDenseOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -3056,15 +2688,6 @@ std::vector SliceArrayDenseOp::InferMeta( paddle::dialect::DenseTensorArrayType input_type; if (input.type().isa()) { input_type = input.type().dyn_cast(); - } else if (input.type() - .isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - input.type().dyn_cast(); - input_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -3138,7 +2761,7 @@ void AssignArrayOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction outputs"; std::vector argument_outputs = - AssignArrayOp::InferMeta(argument_inputs, argument_attributes); + AssignArrayOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); } @@ -3192,7 +2815,7 @@ phi::DataType AssignArrayOp::GetKernelTypeForVar( std::vector AssignArrayOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta AssignArrayOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -3203,14 +2826,6 @@ std::vector AssignArrayOp::InferMeta( paddle::dialect::DenseTensorArrayType x_type; if (x_.type().isa()) { x_type = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - x_.type().dyn_cast(); - x_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -3301,7 +2916,7 @@ void AssignArray_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AssignArray_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta AssignArray_Op"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -3312,14 +2927,6 @@ 
std::vector AssignArray_Op::InferMeta( paddle::dialect::DenseTensorArrayType x_type; if (x_.type().isa()) { x_type = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - x_.type().dyn_cast(); - x_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -3403,7 +3010,7 @@ void ExpandOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - ExpandOp::InferMeta(argument_inputs, argument_attributes); + ExpandOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -3436,7 +3043,7 @@ void ExpandOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - ExpandOp::InferMeta(argument_inputs, argument_attributes); + ExpandOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -3455,7 +3062,7 @@ void ExpandOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - ExpandOp::InferMeta(argument_inputs, argument_attributes); + ExpandOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -3463,8 +3070,8 @@ void ExpandOp::Build(pir::Builder &builder, bool ExpandOp::InferSymbolicShape( pir::ShapeConstraintIRAnalysis *shape_analysis) { - const auto x_shape_or_data = shape_analysis->GetShapeOrDataForValue(x()); - const auto expand_shape_shape_or_data = + const auto &x_shape_or_data = shape_analysis->GetShapeOrDataForValue(x()); + const auto &expand_shape_shape_or_data = shape_analysis->GetShapeOrDataForValue(shape()); const std::vector &x_dims = [&] { @@ -3479,12 +3086,23 @@ bool ExpandOp::InferSymbolicShape( const std::vector &expand_shape = [&] { std::vector dims; - if (expand_shape_shape_or_data.data().has_value()) { - dims = expand_shape_shape_or_data.data().value(); + + if (expand_shape_shape_or_data + .isa()) { + const auto &dims_list = + expand_shape_shape_or_data + .dyn_cast(); + for (const auto &shape_data : dims_list) { + const auto &dim_expr = shape_data.data().has_value() + ? shape_data.data().value()[0] + : shape_data.shape()[0]; + dims.emplace_back(dim_expr); + } } else { - dims = expand_shape_shape_or_data.shape(); + dims = expand_shape_shape_or_data.data().has_value() + ? 
expand_shape_shape_or_data.data().value() + : expand_shape_shape_or_data.shape(); } - if (dims.empty()) { dims = std::vector(x_dims.size(), -1); } @@ -3564,7 +3182,7 @@ void ExpandOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ExpandOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta ExpandOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -3577,15 +3195,6 @@ std::vector ExpandOp::InferMeta( paddle::dialect::DenseTensorType x; if (x_.type().isa()) { x = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_input = - x_.type().dyn_cast(); - x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout(), - allocated_input.lod(), - allocated_input.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -3633,17 +3242,6 @@ std::vector ExpandOp::InferMeta( } vec_shape = std::vector(shape_size, -2); *is_from_tensor = true; - } else if (shape.type().isa()) { - common::DDim shape_dim = - shape.type() - .dyn_cast() - .dims(); - size_t shape_size = common::product(shape_dim); - if (common::contain_unknown_dim(shape_dim)) { - shape_size = 1; - } - vec_shape = std::vector(shape_size, -2); - *is_from_tensor = true; } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support VectorType or DenseTensorType " @@ -3653,8 +3251,7 @@ std::vector ExpandOp::InferMeta( }; is_from_tensor = false; - phi::IntArray shape = - std::move(phi::IntArray(ParseValueShape(shape_, &is_from_tensor))); + phi::IntArray shape = phi::IntArray(ParseValueShape(shape_, &is_from_tensor)); if (is_from_tensor) shape.SetFromTensor(true); VLOG(4) << "Builder construction dense_x"; @@ -3732,7 +3329,7 @@ void IncrementOp::Build(pir::Builder &builder, argument.AddAttribute("value", attr_value); argument_attributes.insert({"value", attr_value}); std::vector argument_outputs = - IncrementOp::InferMeta(argument_inputs, argument_attributes); + IncrementOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -3759,7 +3356,7 @@ void IncrementOp::Build(pir::Builder &builder, argument.AddAttribute("value", attr_value); argument_attributes.insert({"value", attr_value}); std::vector argument_outputs = - IncrementOp::InferMeta(argument_inputs, argument_attributes); + IncrementOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -3807,7 +3404,12 @@ void IncrementOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector IncrementOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { + PADDLE_ENFORCE_NOT_NULL( + p_attributes, + common::errors::Fatal( + "AttrtibueMap pointer in InferMeta function is nullptr.")); + auto &attributes = *p_attributes; VLOG(4) << "Start infermeta IncrementOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -3822,15 +3424,6 @@ std::vector IncrementOp::InferMeta( paddle::dialect::DenseTensorType x; if (x_.type().isa()) { x = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - 
paddle::dialect::AllocatedDenseTensorType allocated_input = - x_.type().dyn_cast(); - x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout(), - allocated_input.lod(), - allocated_input.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -3872,6 +3465,14 @@ phi::DataType IncrementOp::GetKernelTypeForVar( return expected_kernel_dtype; } +bool IncrementOp::InferSymbolicShape( + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(x()); + shape_analysis->SetShapeOrDataForValue(out(), operand_shape_or_data); + return true; +} + const char *Increment_Op::attributes_name[1] = {"value"}; OpInfoTuple Increment_Op::GetOpInfo() { @@ -3913,7 +3514,7 @@ void Increment_Op::Build(pir::Builder &builder, argument.AddAttribute("value", attr_value); argument_attributes.insert({"value", attr_value}); std::vector argument_outputs = - Increment_Op::InferMeta(argument_inputs, argument_attributes); + Increment_Op::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -3940,7 +3541,7 @@ void Increment_Op::Build(pir::Builder &builder, argument.AddAttribute("value", attr_value); argument_attributes.insert({"value", attr_value}); std::vector argument_outputs = - Increment_Op::InferMeta(argument_inputs, argument_attributes); + Increment_Op::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -3989,7 +3590,12 @@ void Increment_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector Increment_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { + PADDLE_ENFORCE_NOT_NULL( + p_attributes, + common::errors::Fatal( + "AttrtibueMap pointer in InferMeta function is nullptr.")); + auto &attributes = *p_attributes; VLOG(4) << "Start infermeta Increment_Op"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -4004,15 +3610,6 @@ std::vector Increment_Op::InferMeta( paddle::dialect::DenseTensorType x; if (x_.type().isa()) { x = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_input = - x_.type().dyn_cast(); - x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout(), - allocated_input.lod(), - allocated_input.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -4054,6 +3651,14 @@ phi::DataType Increment_Op::GetKernelTypeForVar( return expected_kernel_dtype; } +bool Increment_Op::InferSymbolicShape( + pir::ShapeConstraintIRAnalysis *shape_analysis) { + const symbol::ShapeOrDataDimExprs &operand_shape_or_data = + shape_analysis->GetShapeOrDataForValue(x()); + shape_analysis->SetShapeOrDataForValue(out(), operand_shape_or_data); + return true; +} + OpInfoTuple AssignOut_Op::GetOpInfo() { std::vector inputs = { paddle::dialect::OpInputInfo( @@ -4095,7 +3700,7 @@ void AssignOut_Op::Build(pir::Builder &builder, pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - 
AssignOut_Op::InferMeta(argument_inputs, argument_attributes); + AssignOut_Op::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); constexpr char kStopGradientAttrName[] = "stop_gradient"; auto stop_gradient0 = @@ -4150,7 +3755,7 @@ void AssignOut_Op::InferMeta(phi::InferMetaContext *infer_meta) { std::vector AssignOut_Op::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", input_values.size()); @@ -4161,15 +3766,6 @@ std::vector AssignOut_Op::InferMeta( paddle::dialect::DenseTensorType x; if (x_.type().isa()) { x = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_x = - x_.type().dyn_cast(); - x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_x.dtype(), - allocated_x.dims(), - allocated_x.data_layout(), - allocated_x.lod(), - allocated_x.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -4225,7 +3821,7 @@ void ShapeBroadcastOp::Build(pir::Builder &builder, VLOG(4) << "Builder construction attributes"; pir::AttributeMap argument_attributes = {}; std::vector argument_outputs = - ShapeBroadcastOp::InferMeta(argument_inputs, argument_attributes); + ShapeBroadcastOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -4238,7 +3834,7 @@ void ShapeBroadcastOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ShapeBroadcastOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { VLOG(4) << "Start infermeta ShapeBroadcastOp"; IR_ENFORCE(input_values.size() == 2, "Num of inputs is expected to be 2 but got %d.", @@ -4250,15 +3846,6 @@ std::vector ShapeBroadcastOp::InferMeta( paddle::dialect::DenseTensorType x; if (x_.type().isa()) { x = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_x = - x_.type().dyn_cast(); - x = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_x.dtype(), - allocated_x.dims(), - allocated_x.data_layout(), - allocated_x.lod(), - allocated_x.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -4268,15 +3855,6 @@ std::vector ShapeBroadcastOp::InferMeta( paddle::dialect::DenseTensorType y; if (y_.type().isa()) { y = y_.type().dyn_cast(); - } else if (y_.type().isa()) { - paddle::dialect::AllocatedDenseTensorType allocated_x = - y_.type().dyn_cast(); - y = paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(), - allocated_x.dtype(), - allocated_x.dims(), - allocated_x.data_layout(), - allocated_x.lod(), - allocated_x.offset()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorType or " @@ -4335,7 +3913,7 @@ symbol::DimExpr GetBroadcastDimExpr(const symbol::DimExpr &lhs, return symbol::Broadcast{ symbol::List{lhs, rhs}}; } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code")); } } // namespace @@ -4466,7 +4044,7 @@ void MemcpyD2hMultiIoOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector MemcpyD2hMultiIoOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap 
&attributes) { + pir::AttributeMap *p_attributes) { IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", input_values.size()); @@ -4476,14 +4054,6 @@ std::vector MemcpyD2hMultiIoOp::InferMeta( paddle::dialect::DenseTensorArrayType x_type; if (x_.type().isa()) { x_type = x_.type().dyn_cast(); - } else if (x_.type().isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - x_.type().dyn_cast(); - x_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -4608,7 +4178,7 @@ void ArrayPopOp::Build(pir::Builder &builder, // NOLINT argument.AddAttribute("index", attr_index); argument_attributes.insert({"index", attr_index}); std::vector argument_outputs = - ArrayPopOp::InferMeta(argument_inputs, argument_attributes); + ArrayPopOp::InferMeta(argument_inputs, &argument_attributes); argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); @@ -4621,7 +4191,12 @@ void ArrayPopOp::InferMeta(phi::InferMetaContext *infer_meta) { std::vector ArrayPopOp::InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes) { + pir::AttributeMap *p_attributes) { + PADDLE_ENFORCE_NOT_NULL( + p_attributes, + common::errors::Fatal( + "AttrtibueMap pointer in InferMeta function is nullptr.")); + auto &attributes = *p_attributes; VLOG(4) << "Start infermeta ArrayPopOp"; IR_ENFORCE(input_values.size() == 1, "Num of inputs is expected to be 1 but got %d.", @@ -4632,15 +4207,6 @@ std::vector ArrayPopOp::InferMeta( paddle::dialect::DenseTensorArrayType input_type; if (input.type().isa()) { input_type = input.type().dyn_cast(); - } else if (input.type() - .isa()) { - paddle::dialect::AllocatedDenseTensorArrayType allocated_input = - input.type().dyn_cast(); - input_type = paddle::dialect::DenseTensorArrayType::get( - pir::IrContext::Instance(), - allocated_input.dtype(), - allocated_input.dims(), - allocated_input.data_layout()); } else { PADDLE_THROW(phi::errors::Unimplemented( "Only support paddle::dialect::DenseTensorArrayType or " @@ -4701,7 +4267,6 @@ phi::DataType ArrayPopOp::GetKernelTypeForVar( IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) -IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNArrayOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AssignOut_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index ea836f68a4959..8d13c11d06a59 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -55,7 +55,7 @@ class AddNOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); static std::vector> Vjp( pir::Operation *op, @@ -87,30 +87,7 @@ class AddN_Op : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); -}; - -class AddNWithKernelOp : public pir::Op { - public: - using Op::Op; - static const char *name() { return "pd_op.add_n_with_kernel"; } - static constexpr const char 
**attributes_name = nullptr; - static constexpr uint32_t attributes_num = 0; - static OpInfoTuple GetOpInfo(); - static void Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - pir::Value inputs_); - - void VerifySig(); - pir::Value inputs() { return operand_source(0); } - pir::Value out() { return result(0); } - - static void InferMeta(phi::InferMetaContext *infer_meta); - static std::vector InferMeta( - const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class AddNArrayOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class FusedGemmEpilogueOp @@ -163,7 +140,7 @@ class FusedGemmEpilogueOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class FusedGemmEpilogueGradOp @@ -196,7 +173,7 @@ class FusedGemmEpilogueGradOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class SplitGradOp : public pir::Op { @@ -222,7 +199,7 @@ class SplitGradOp : public pir::Op { static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class CreateArrayOp @@ -241,7 +218,7 @@ class CreateArrayOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class CreateArrayLikeOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class ArrayLengthOp @@ -283,7 +260,7 @@ class ArrayLengthOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class ArrayReadOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -344,7 +321,7 @@ class ArrayWrite_Op : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -375,7 +352,7 @@ class ArrayToTensorOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -405,7 +382,7 @@ class TensorToArrayOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class SliceArrayOp @@ -439,7 +416,7 @@ class SliceArrayOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class SliceArrayDenseOp @@ -471,7 +448,7 @@ class SliceArrayDenseOp static 
void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class AssignArrayOp @@ -502,7 +479,7 @@ class AssignArrayOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class AssignArray_Op @@ -530,7 +507,7 @@ class AssignArray_Op static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class ExpandOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -588,6 +565,7 @@ class IncrementOp : public pir::Op { public: @@ -619,19 +597,21 @@ class IncrementOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); }; class Increment_Op : public pir::Op { @@ -664,13 +644,14 @@ class Increment_Op static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, const std::vector> &outputs, const std::vector> &out_grads, const std::vector> &stop_gradients); + bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); }; class AssignOut_Op @@ -705,7 +686,7 @@ class AssignOut_Op static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); static std::vector> Vjp( pir::Operation *op, const std::vector> &inputs_, @@ -748,7 +729,7 @@ class MemcpyD2hMultiIoOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; class IR_API ShapeBroadcastOp @@ -774,7 +755,7 @@ class IR_API ShapeBroadcastOp static void InferMeta(phi::InferMetaContext *infer_meta); static std::vector InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis); }; @@ -809,7 +790,7 @@ class ArrayPopOp : public pir::Op InferMeta( const std::vector &input_values, - const pir::AttributeMap &attributes); + pir::AttributeMap *p_attributes); }; } // namespace dialect @@ -818,7 +799,6 @@ class ArrayPopOp : public pir::Op StringToDataType{ - {"bool", phi::DataType::BOOL}, - {"uint8", phi::DataType::UINT8}, - {"int8", phi::DataType::INT8}, - {"uint16", phi::DataType::UINT16}, - {"int16", phi::DataType::INT16}, - {"uint32", phi::DataType::UINT32}, - {"int32", phi::DataType::INT32}, - {"uint64", phi::DataType::UINT64}, - {"int64", phi::DataType::INT64}, - {"float32", 
phi::DataType::FLOAT32}, - {"complex64", phi::DataType::COMPLEX64}, - {"complex128", phi::DataType::COMPLEX128}, - {"Undefined", phi::DataType::UNDEFINED}, - {"psting", phi::DataType::PSTRING}, - {"float16", phi::DataType::FLOAT16}, - {"bfloat16", phi::DataType::BFLOAT16}, - {"float64", phi::DataType::FLOAT64}}; std::string datatype_token_val = parser.ConsumeToken().val_; - IR_ENFORCE(StringToDataType.count(datatype_token_val) > 0, - datatype_token_val + " is not defined in DataType." + - parser.GetErrorLocationInfo()); + PADDLE_ENFORCE_EQ(StringToDataTypeMap().count(datatype_token_val) > 0, + true, + common::errors::InvalidArgument( + datatype_token_val + " is not defined in DataType." + + parser.GetErrorLocationInfo())); return DataTypeAttribute::get(parser.ctx, - StringToDataType[datatype_token_val]); + StringToDataTypeMap().at(datatype_token_val)); } // Parse a PlaceAttribute // PlaceAttribute := Place(cpu)|Place(gpu:0)|Place(gpu_pinned) // |Place(xpu:0)|Place(ipu:0)|Place(:0)|undefined PlaceAttribute PlaceAttribute::Parse(pir::IrParser &parser) { // NOLINT - std::unordered_map StringToPlace{ - {"cpu", phi::CPUPlace{}}, - {"gpu", phi::GPUPlace{}}, - {"gpu_pinned", phi::GPUPinnedPlace{}}, - {"xpu", phi::XPUPlace{}}, - {"ipu", phi::IPUPlace{}}, - {":", phi::CustomPlace{}}, - {"undefined", phi::Place{}}}; parser.ConsumeAToken("Place"); parser.ConsumeAToken("("); std::string place_token_val = parser.ConsumeToken().val_; - IR_ENFORCE(StringToPlace.count(place_token_val) > 0, - place_token_val + " is not defined in Place." + - parser.GetErrorLocationInfo()); + PADDLE_ENFORCE_EQ(StringToPlaceMap().count(place_token_val) > 0, + true, + common::errors::InvalidArgument( + place_token_val + " is not defined in Place." + + parser.GetErrorLocationInfo())); if (parser.PeekToken().val_ == ":") { parser.ConsumeAToken(":"); parser.ConsumeToken(); @@ -124,7 +104,8 @@ PlaceAttribute PlaceAttribute::Parse(pir::IrParser &parser) { // NOLINT parser.ConsumeToken(); } parser.ConsumeAToken(")"); - return PlaceAttribute::get(parser.ctx, StringToPlace[place_token_val]); + return PlaceAttribute::get(parser.ctx, + StringToPlaceMap().at(place_token_val)); } // Parse a DataLayoutAttribute @@ -133,28 +114,20 @@ PlaceAttribute PlaceAttribute::Parse(pir::IrParser &parser) { // NOLINT // |NCDHW|PSTRING_UNION|STRIDED DataLayoutAttribute DataLayoutAttribute::Parse( pir::IrParser &parser) { // NOLINT - std::unordered_map StringToDataLayout{ - {"NHWC", phi::DataLayout::kNHWC}, - {"NCHW", phi::DataLayout::kNCHW}, - {"Undefined", phi::DataLayout::kAnyLayout}, - {"ONEDNN", phi::DataLayout::ONEDNN}, - {"SPARSE_COO", phi::DataLayout::SPARSE_COO}, - {"SPARSE_CSR", phi::DataLayout::SPARSE_CSR}, - {"NDHWC", phi::DataLayout::kNDHWC}, - {"NCDHW", phi::DataLayout::kNCDHW}, - {"PSTRING_UNION", phi::DataLayout::PSTRING_UNION}, - {"STRIDED", phi::DataLayout::STRIDED}}; std::string datalayout_token_val = parser.ConsumeToken().val_; - IR_ENFORCE(StringToDataLayout.count(datalayout_token_val) > 0, - datalayout_token_val + " is not defined in DataLayout." + - parser.GetErrorLocationInfo()); + PADDLE_ENFORCE_EQ( + StringToDataLayoutMap().count(datalayout_token_val) > 0, + true, + common::errors::InvalidArgument(datalayout_token_val + + " is not defined in DataLayout." 
+ + parser.GetErrorLocationInfo())); if (datalayout_token_val == "Undefined") { parser.ConsumeAToken("("); parser.ConsumeAToken("AnyLayout"); parser.ConsumeAToken(")"); } - return DataLayoutAttribute::get(parser.ctx, - StringToDataLayout[datalayout_token_val]); + return DataLayoutAttribute::get( + parser.ctx, StringToDataLayoutMap().at(datalayout_token_val)); } } // namespace dialect diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index 6816d64a05467..f60bdd115cf36 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -24,6 +24,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/transforms/param_to_variable.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/pir/include/core/builtin_type_interfaces.h" #include "paddle/pir/include/core/interface_value.h" #include "paddle/pir/include/core/ir_printer.h" @@ -38,17 +39,6 @@ namespace paddle { namespace dialect { -static std::unordered_map kCustomTypeMap = { - {"bool", "pir::BoolAttribute"}, - {"int", "pir::Int32Attribute"}, - {"float", "pir::FloatAttribute"}, - {"int64_t", "pir::Int64Attribute"}, - {"std::string", "pir::StrAttribute"}, - {"std::vector", "pir::ArrayAttribute"}, - {"std::vector", "pir::ArrayAttribute"}, - {"std::vector", "pir::ArrayAttribute"}, - {"std::vector", "pir::ArrayAttribute"}}; - struct CombineOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( @@ -141,6 +131,17 @@ struct ParameterOpInferSymbolicShapeInterfaceModel : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} }; +struct SetParameterOpInferSymbolicShapeInterfaceModel + : public InferSymbolicShapeInterface::Concept { + static inline bool InferSymbolicShape( + pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + return true; + } + + SetParameterOpInferSymbolicShapeInterfaceModel() + : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} +}; + struct ShadowOutputOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( @@ -159,6 +160,52 @@ struct ShadowOutputOpInferSymbolicShapeInterfaceModel : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} }; +struct SliceOpInferSymbolicShapeInterfaceModel + : public InferSymbolicShapeInterface::Concept { + static inline bool InferSymbolicShape( + pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + const auto index = + op->attributes().at("index").dyn_cast().data(); + const auto output_value = + (op->operand(0).type().dyn_cast())[index] + .dyn_cast(); + + shape_analysis->SetShapeOrDataForValue( + op->result(0), shape_analysis->GetShapeOrDataForValue(output_value)); + + return true; + } + + SliceOpInferSymbolicShapeInterfaceModel() + : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} +}; + +struct SplitOpInferSymbolicShapeInterfaceModel + : public InferSymbolicShapeInterface::Concept { + static inline bool InferSymbolicShape( + pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis) { + const auto& shape_data_list = + shape_analysis->GetShapeOrDataForValue(op->operand_source(0)) + .dyn_cast(); + + for (uint32_t rst_idx = 0; rst_idx < op->num_results(); rst_idx++) { + PADDLE_ENFORCE_EQ( + 
shape_data_list[rst_idx].data().has_value(), + false, + paddle::platform::errors::InvalidArgument( + "Currently InferSymbolicShape of SplitOp only support " + "input without value.")); + shape_analysis->SetShapeOrDataForValue( + op->result(rst_idx), + symbol::ShapeOrDataDimExprs{shape_data_list[rst_idx]}); + } + return true; + } + + SplitOpInferSymbolicShapeInterfaceModel() + : InferSymbolicShapeInterface::Concept(InferSymbolicShape) {} +}; + struct YieldOpInferSymbolicShapeInterfaceModel : public InferSymbolicShapeInterface::Concept { static inline bool InferSymbolicShape( @@ -177,43 +224,49 @@ OperatorDialect::OperatorDialect(pir::IrContext* ctx) ctx->GetOrRegisterDialect<::pir::ControlFlowDialect>(); auto info = ctx->GetRegisteredOpInfo(pir::TuplePushOp::name()); - info.AttachInterface(std::move( - pir::InterfaceValue::Get())); + info.AttachInterface( + pir::InterfaceValue::Get()); info = ctx->GetRegisteredOpInfo(pir::CombineOp::name()); - info.AttachInterface(std::move( + info.AttachInterface( pir::InterfaceValue::Get())); + CombineOpInferSymbolicShapeInterfaceModel>()); info = ctx->GetRegisteredOpInfo(pir::ParameterOp::name()); - info.AttachInterface(std::move( + info.AttachInterface( pir::InterfaceValue::Get())); + ParameterOpInferSymbolicShapeInterfaceModel>()); info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name()); + info.AttachInterface(pir::InterfaceValue::Get< + InferSymbolicShapeInterface, + ShadowOutputOpInferSymbolicShapeInterfaceModel>()); + + info = ctx->GetRegisteredOpInfo(pir::SplitOp::name()); info.AttachInterface( - std::move(pir::InterfaceValue::Get< - InferSymbolicShapeInterface, - ShadowOutputOpInferSymbolicShapeInterfaceModel>())); + pir::InterfaceValue::Get()); info = ctx->GetRegisteredOpInfo(pir::YieldOp::name()); - info.AttachInterface(std::move( + info.AttachInterface( pir::InterfaceValue::Get())); + YieldOpInferSymbolicShapeInterfaceModel>()); + + info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name()); + info.AttachInterface(pir::InterfaceValue::Get< + InferSymbolicShapeInterface, + SetParameterOpInferSymbolicShapeInterfaceModel>()); + + info = ctx->GetRegisteredOpInfo(pir::SliceOp::name()); + info.AttachInterface( + pir::InterfaceValue::Get()); } void PrintTypeImpl(pir::Type type, std::ostream& os) { os << type.dialect().name(); os << '.'; - if (auto tensor_type = type.dyn_cast()) { - os << "tensor<"; - for (auto d : common::vectorize(tensor_type.dims())) { - os << d; - os << "x"; - } - tensor_type.dtype().Print(os); - os << ">"; - } else if (auto selected_rows_type = type.dyn_cast()) { + if (auto selected_rows_type = type.dyn_cast()) { os << "selectedrows<"; for (auto d : common::vectorize(selected_rows_type.dims())) { os << d; @@ -266,8 +319,9 @@ void PrintOperationImpl(pir::Operation* op, } void OperatorDialect::initialize() { - RegisterTypes(); RegisterAttributes dim{}; - Token dim_token = parser.PeekToken(); - while (dim_token.token_type_ == DIGIT) { - dim_token = parser.ConsumeToken(); - dim.push_back(atoi(dim_token.val_.c_str())); - std::string peek_token_val = parser.PeekToken().val_; - if (peek_token_val[0] != 'x') { - break; - } - parser.ConsumeToken(); - parser.lexer->Unget(static_cast(peek_token_val.size() - 1)); - if (parser.PeekToken().token_type_ != DIGIT) { - break; - } - } - phi::DDim ddim = common::make_ddim(dim); - pir::Type dtype = parser.ParseType(); - std::vector> lod; - std::vector lodv; - lodv.push_back(0); - lod.push_back(lodv); - parser.ConsumeAToken(">"); - return DenseTensorType::get( - parser.ctx, dtype, ddim, 
phi::DataLayout::UNDEFINED, lod, 0); -} - pir::Attribute OperatorDialect::ParseAttribute( pir::IrParser& parser) { // NOLINT std::string type_name = parser.ConsumeToken().val_; @@ -473,8 +498,10 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept { auto& grad_op_output_names = OpMetaInfoHelper::GetOutputs(*grad_op_meta_ptr); bool is_double_grad_op = - (grad_op_name.find("_grad_grad") != grad_op_name.npos) ? true - : false; + (grad_op_name.find(paddle::framework::kDoubleGradSuffix) != + grad_op_name.npos) + ? true + : false; for (auto& grad_op_output_name : grad_op_output_names) { auto fwd_input_name = paddle::framework::detail::NoGrad( grad_op_output_name, is_double_grad_op); @@ -500,7 +527,7 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept { auto attr_name = attr_name_and_type[0]; auto attr_type_str = attr_name_and_type[1]; param_names.push_back(attr_name); - if (kCustomTypeMap.find(attr_type_str) == kCustomTypeMap.end()) { + if (CppTypeToAttrTypeMap().count(attr_type_str) == 0) { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported `%s` type value as custom attribute now. " "Supported data types include `bool`, `int`, `float`, " @@ -510,9 +537,8 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept { "the attribute data type and data type string are matched.", attr_type_str)); } - std::string attr_pir_type = kCustomTypeMap[attr_type_str]; - attributes_info.push_back( - paddle::dialect::OpAttributeInfo{attr_name, attr_pir_type, ""}); + std::string attr_pir_type = CppTypeToAttrTypeMap().at(attr_type_str); + attributes_info.emplace_back(attr_name, attr_pir_type, ""); } // translate output info @@ -537,8 +563,8 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept { } std::vector> vec_inplace; - for (auto inplace_map : inplace_maps) { - vec_inplace.push_back(inplace_map); + for (const auto& inplace_map : inplace_maps) { + vec_inplace.emplace_back(inplace_map); } // we only need kernel params name in run_time_info @@ -556,7 +582,7 @@ struct CustomOpInfoInterfaceModel : public OpYamlInfoInterface::Concept { struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { static std::vector> CustomOpVjp( pir::Operation* op, - const std::vector>& inputs_, + const std::vector>& inputs, const std::vector>& outputs, const std::vector>& out_grads, const std::vector>& stop_gradients) { @@ -593,13 +619,13 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { auto infershape_func = OpMetaInfoHelper::GetInferShapeFn(bwd_op_meta_info); auto inferdtype_func = OpMetaInfoHelper::GetInferDtypeFn(bwd_op_meta_info); PADDLE_ENFORCE_EQ( - inputs_.size(), + inputs.size(), fwd_inputs_name.size(), paddle::platform::errors::InvalidArgument( "Custom op: %s inputs size should be %d, but now is %d.", pir_op_name, fwd_inputs_name.size(), - inputs_.size())); + inputs.size())); PADDLE_ENFORCE_EQ( outputs.size(), fwd_outputs_name.size(), @@ -617,9 +643,11 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { pir_op_name, fwd_outputs_name.size(), out_grads.size())); - bool is_double_grad_op = - (bwd_pir_op_name.find("_grad_grad") != pir_op_name.npos) ? true : false; + (bwd_pir_op_name.find(paddle::framework::kDoubleGradSuffix) != + bwd_pir_op_name.npos) + ? 
true + : false; pir::IrContext* ctx = pir::IrContext::Instance(); pir::OpInfo pir_info = ctx->GetRegisteredOpInfo(bwd_pir_op_name); pir::OperationArgument argument(pir_info); @@ -671,7 +699,6 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { grad_op_input_name)); } }; - // Construct custom grad op inputs int input_index = 0; int vec_input_index = 0; @@ -680,8 +707,8 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { const auto input_location = GetInputLocation(bwd_input_name); std::vector input_values; if (input_location.first == 0) { - // grad op input is in inputs_ - input_values = inputs_[input_location.second]; + // grad op input is in inputs + input_values = inputs[input_location.second]; } else if (input_location.first == 1) { // grad op input is in outputs input_values = outputs[input_location.second]; @@ -689,32 +716,43 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { // grad op input is in out_grads input_values = out_grads[input_location.second]; } - - if (input_values.size() > 1) { + if (paddle::framework::detail::IsDuplicableVar(bwd_input_name)) { std::vector> tmp_input_shapes; std::vector tmp_input_dtypes; + pir::Value input_value; vec_input_name2id_map[bwd_input_name] = vec_input_index; vec_input_index++; - for (auto& input_value : input_values) { - paddle::dialect::DenseTensorType input_tensor = - input_value.type().dyn_cast(); - tmp_input_shapes.push_back(phi::vectorize(input_tensor.dims())); - tmp_input_dtypes.push_back( - paddle::dialect::TransToPhiDataType(input_tensor.dtype())); + bool is_optional = + (input_values.size() == 1 && input_values[0].impl() == nullptr); + if (!is_optional) { + for (auto& input_value : input_values) { + paddle::dialect::DenseTensorType input_tensor = + input_value.type().dyn_cast(); + tmp_input_shapes.push_back(phi::vectorize(input_tensor.dims())); + tmp_input_dtypes.push_back( + paddle::dialect::TransToPhiDataType(input_tensor.dtype())); + } + input_value = paddle::dialect::builtin_combine(input_values); } vec_input_shapes.push_back(tmp_input_shapes); vec_input_dtypes.push_back(tmp_input_dtypes); - auto input_value = paddle::dialect::builtin_combine(input_values); argument_inputs.push_back(input_value); } else { + std::vector tmp_input_shape; + phi::DataType tmp_input_dtype = DataType::UNDEFINED; input_name2id_map[bwd_input_name] = input_index; input_index++; pir::Value input_value = input_values[0]; // NOLINT - paddle::dialect::DenseTensorType input_tensor = - input_value.type().dyn_cast(); - input_shapes.push_back(phi::vectorize(input_tensor.dims())); - input_dtypes.push_back( - paddle::dialect::TransToPhiDataType(input_tensor.dtype())); + if (input_value.impl() != nullptr) { + paddle::dialect::DenseTensorType input_tensor = + input_value.type().dyn_cast(); + tmp_input_shape = phi::vectorize(input_tensor.dims()); + tmp_input_dtype = + paddle::dialect::TransToPhiDataType(input_tensor.dtype()); + } + input_shapes.push_back(tmp_input_shape); + input_dtypes.push_back(tmp_input_dtype); + argument_inputs.push_back(input_value); } } @@ -729,7 +767,6 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { custom_attrs.push_back(paddle::dialect::TransAttrToAny(fwd_op_attr)); argument.AddAttribute(fwd_attr_name, fwd_op_attr); } - // Run Compile InferMeta std::vector> output_shapes = paddle::framework::RunInferShape(infershape_func, @@ -752,18 +789,23 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { std::unordered_map output_name2value_num; for (size_t i = 0; i 
< bwd_outputs_name.size(); ++i) { const auto& bwd_output_name = bwd_outputs_name.at(i); + const auto& bwd_input = + paddle::framework::detail::NoGrad(bwd_output_name, is_double_grad_op); + if (paddle::framework::detail::IsDuplicableVar(bwd_output_name)) { - const auto& bwd_input = paddle::framework::detail::NoGrad( - bwd_output_name, is_double_grad_op); auto index = vec_input_name2id_map[bwd_input]; - auto& input_shapes = vec_input_shapes[index]; - output_name2value_num[bwd_output_name] = input_shapes.size(); - all_values_num += input_shapes.size(); + auto& vec_input_shape = vec_input_shapes[index]; + output_name2value_num[bwd_output_name] = vec_input_shape.size(); } else { - output_name2value_num[bwd_output_name] = 1; - all_values_num++; + auto index = input_name2id_map[bwd_input]; + // input_shapes[index] is dim of tensor, if the dim doesn't have + // element, it must be a optional tensor that is None in custom operator + output_name2value_num[bwd_output_name] = + input_shapes[index].size() == 0 ? 0 : 1; } + all_values_num += output_name2value_num[bwd_output_name]; } + PADDLE_ENFORCE_EQ( output_shapes.size(), all_values_num, @@ -785,13 +827,18 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { "Tensors' dtype", all_values_num, output_dtypes.size())); - // Construct custom grad op outputs size_t value_index = 0; for (size_t i = 0; i < bwd_outputs_name.size(); ++i) { const auto& bwd_output_name = bwd_outputs_name.at(i); + auto value_num = output_name2value_num[bwd_output_name]; + if (value_num == 0) { + // Optional value condition + pir::Type out_type; + argument_outputs.push_back(out_type); + continue; + } if (paddle::framework::detail::IsDuplicableVar(bwd_output_name)) { - auto value_num = output_name2value_num[bwd_output_name]; std::vector out_types; for (size_t j = 0; j < value_num; ++j) { auto ddims = phi::make_ddim(output_shapes[value_index]); @@ -827,6 +874,7 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { } } argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); + // Build Operation std::vector op_results; pir::Operation* bwd_op = @@ -839,6 +887,42 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { for (size_t i = 0; i < stop_gradients.size(); ++i) { res[i].resize(stop_gradients[i].size()); } + + auto GetInputGradientIndex = [&](const std::string& bwd_output_name, + bool is_double_grad_op) -> size_t { + /* + This function is used to get the index of input that need calculate + gradient in forward op. For example: forward inputs : TensorA, TensorB, + TensorC, TensorD backward outputs: TensorC@Grad, TensorA@Grad So, we + only need to calculate gradient of TensorA and TensorC and store them in + res; In this example, the res size is 2, and the first element of res + should store TensorA@Grad, and the second element of res should store + TensorC@Grad. + + So, This function will return 1 if we pass TensorC@Grad and return 0 if + we pass TensorA@Grad. 
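    In other words, the index returned below equals the number of forward inputs that appear before fwd_input in fwd_inputs_name and that also have a matching gradient output in bwd_outputs_name. Using the example above: TensorA has no earlier forward input with a gradient, so its slot is 0; TensorC is preceded by TensorA (which does have a gradient), so its slot is 1; TensorB and TensorD have no gradient outputs and therefore occupy no slot in res.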
+ */ + size_t gradient_vec_index = 0; + const auto& fwd_input = + paddle::framework::detail::NoGrad(bwd_output_name, is_double_grad_op); + auto fwd_inputs_name_iter = + std::find(fwd_inputs_name.begin(), fwd_inputs_name.end(), fwd_input); + size_t input_index = + std::distance(fwd_inputs_name.begin(), fwd_inputs_name_iter); + for (size_t i = 0; i < input_index; ++i) { + for (size_t j = 0; j < bwd_outputs_name.size(); j++) { + const auto& fwd_input_name_tmp = paddle::framework::detail::NoGrad( + bwd_outputs_name[j], is_double_grad_op); + if (fwd_input_name_tmp == fwd_inputs_name[i]) { + // find forward input that need calculate gradient + gradient_vec_index++; + break; + } + } + } + return gradient_vec_index; + }; + // Build result and apply stop gradients for (size_t i = 0; i < bwd_outputs_name.size(); ++i) { const auto& bwd_output_name = bwd_outputs_name.at(i); @@ -855,16 +939,20 @@ struct CustomOpVjpInterfaceModel : public VjpInterface::Concept { "forward input that need calculate gradients.", pir_op_name, bwd_output_name)); - int index = - std::distance(fwd_inputs_name.begin(), fwd_inputs_name_iter); - auto split_op = - ApiBuilder::Instance().GetBuilder()->Build( - bwd_op->result(i)); - res[index] = split_op.outputs(); + int index = GetInputGradientIndex(bwd_output_name, is_double_grad_op); + if (bwd_op->result(i).type().dyn_cast()) { + auto split_op = + ApiBuilder::Instance().GetBuilder()->Build( + bwd_op->result(i)); + res[index] = split_op.outputs(); + } else { + // optional output condition + pir::Value empty_value; + res[index][0] = empty_value; + } } else { if (fwd_inputs_name_iter != fwd_inputs_name.end()) { - int index = - std::distance(fwd_inputs_name.begin(), fwd_inputs_name_iter); + int index = GetInputGradientIndex(bwd_output_name, is_double_grad_op); res[index][0] = bwd_op->result(i); } else { // Situation that has only one input and only one output. 
If not meet diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.h b/paddle/fluid/pir/dialect/operator/ir/op_dialect.h index ae7dc883f8911..deda7b3ddcdd0 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.h @@ -29,7 +29,6 @@ class TEST_API OperatorDialect : public pir::Dialect { static const char* name() { return "pd_op"; } - pir::Type ParseType(pir::IrParser& parser) override; // NOLINT pir::Attribute ParseAttribute(pir::IrParser& parser) override; // NOLINT void PrintType(pir::Type type, std::ostream& os) const override; diff --git a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc index 5b7323264c626..8ea9f0a7ce02f 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc @@ -68,15 +68,7 @@ void OneDNNOperatorDialect::initialize() { void OneDNNOperatorDialect::PrintType(pir::Type type, std::ostream &os) const { os << type.dialect().name(); os << '.'; - if (auto tensor_type = type.dyn_cast()) { - os << "tensor<"; - for (auto d : common::vectorize(tensor_type.dims())) { - os << d; - os << "x"; - } - tensor_type.dtype().Print(os); - os << ">"; - } else if (auto selected_rows_type = type.dyn_cast()) { + if (auto selected_rows_type = type.dyn_cast()) { os << "selectedrows<"; for (auto d : common::vectorize(selected_rows_type.dims())) { os << d; @@ -117,35 +109,6 @@ void OneDNNOperatorDialect::PrintAttribute(pir::Attribute attr, } } -pir::Type OneDNNOperatorDialect::ParseType(pir::IrParser &parser) { // NOLINT - parser.ConsumeAToken("pd_op.tensor"); - parser.ConsumeAToken("<"); - std::vector dim{}; - Token dim_token = parser.PeekToken(); - while (dim_token.token_type_ == DIGIT) { - dim_token = parser.ConsumeToken(); - dim.push_back(atoi(dim_token.val_.c_str())); - std::string peek_token_val = parser.PeekToken().val_; - if (peek_token_val[0] != 'x') { - break; - } - parser.ConsumeToken(); - parser.lexer->Unget(static_cast(peek_token_val.size() - 1)); - if (parser.PeekToken().token_type_ != DIGIT) { - break; - } - } - phi::DDim ddim = common::make_ddim(dim); - pir::Type dtype = parser.ParseType(); - std::vector> lod; - std::vector lodv; - lodv.push_back(0); - lod.push_back(lodv); - parser.ConsumeAToken(">"); - return DenseTensorType::get( - parser.ctx, dtype, ddim, phi::DataLayout::UNDEFINED, lod, 0); -} - pir::Attribute OneDNNOperatorDialect::ParseAttribute( pir::IrParser &parser) { // NOLINT std::string type_name = parser.ConsumeToken().val_; diff --git a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h index 405c9346e2fa8..6ef33672c9c96 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h @@ -25,7 +25,6 @@ class OneDNNOperatorDialect : public pir::Dialect { static const char* name() { return "onednn_op"; } - pir::Type ParseType(pir::IrParser& parser) override; // NOLINT pir::Attribute ParseAttribute(pir::IrParser& parser) override; // NOLINT void PrintType(pir::Type type, std::ostream& os) const override; diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.cc b/paddle/fluid/pir/dialect/operator/ir/op_type.cc index 2765352759969..2edb4a29cdc0e 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.cc @@ -28,6 +28,26 @@ const phi::LoD& SelectedRowsType::lod() const { 
return storage()->lod_; } const size_t& SelectedRowsType::offset() const { return storage()->offset_; } +bool SelectedRowsType::classof(Type type) { + if (type) { + if (type.type_id() == type_id()) return true; + if (auto wrap_type = type.dyn_cast()) { + return classof(wrap_type.prim_type()); + } + } + return false; +} + +SelectedRowsType SelectedRowsType::dyn_cast_impl(Type type) { + if (type) { + if (type.type_id() == type_id()) return SelectedRowsType(type.storage()); + if (auto wrap_type = type.dyn_cast()) { + return dyn_cast_impl(wrap_type.prim_type()); + } + } + return nullptr; +} + const pir::Type& DenseTensorArrayType::dtype() const { return storage()->dtype_; } @@ -37,8 +57,112 @@ const phi::DataLayout& DenseTensorArrayType::data_layout() const { return storage()->layout_; } +bool DenseTensorArrayType::classof(Type type) { + if (type) { + if (type.type_id() == type_id()) return true; + if (auto wrap_type = type.dyn_cast()) { + return classof(wrap_type.prim_type()); + } + } + return false; +} + +DenseTensorArrayType DenseTensorArrayType::dyn_cast_impl(Type type) { + if (type) { + if (type.type_id() == type_id()) + return DenseTensorArrayType(type.storage()); + if (auto wrap_type = type.dyn_cast()) { + return dyn_cast_impl(wrap_type.prim_type()); + } + } + return nullptr; +} + +pir::Type SparseCooTensorType::dtype() const { return storage()->dtype_; } + +const common::DDim& SparseCooTensorType::dims() const { + return storage()->dims_; +} + +const common::DDim& SparseCooTensorType::non_zero_dims() const { + return storage()->non_zero_dims_; +} + +common::DataLayout SparseCooTensorType::data_layout() const { + return storage()->layout_; +} + +pir::DenseTensorType SparseCooTensorType::non_zero_indices() const { + return storage()->non_zero_indices_; +} + +pir::DenseTensorType SparseCooTensorType::non_zero_elements() const { + return storage()->non_zero_elements_; +} + +bool SparseCooTensorType::coalesced() const { return storage()->coalesced_; } + +bool SparseCooTensorType::classof(Type type) { + if (type) { + if (type.type_id() == type_id()) { + return true; + } + } + return false; +} + +SparseCooTensorType SparseCooTensorType::dyn_cast_impl(Type type) { + if (type) { + if (type.type_id() == type_id()) { + return SparseCooTensorType(type.storage()); + } + } + return nullptr; +} + +pir::Type SparseCsrTensorType::dtype() const { return storage()->dtype_; } + +const common::DDim& SparseCsrTensorType::dims() const { + return storage()->dims_; +} + +common::DataLayout SparseCsrTensorType::data_layout() const { + return storage()->layout_; +} + +pir::DenseTensorType SparseCsrTensorType::non_zero_crows() const { + return storage()->non_zero_crows_; +} + +pir::DenseTensorType SparseCsrTensorType::non_zero_cols() const { + return storage()->non_zero_cols_; +} + +pir::DenseTensorType SparseCsrTensorType::non_zero_elements() const { + return storage()->non_zero_elements_; +} + +bool SparseCsrTensorType::classof(Type type) { + if (type) { + if (type.type_id() == type_id()) { + return true; + } + } + return false; +} + +SparseCsrTensorType SparseCsrTensorType::dyn_cast_impl(Type type) { + if (type) { + if (type.type_id() == type_id()) { + return SparseCsrTensorType(type.storage()); + } + } + return nullptr; +} } // namespace dialect } // namespace paddle IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorArrayType) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType) 
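// A minimal usage sketch (hypothetical helper, not part of this patch): with
// the classof / dyn_cast_impl overloads above in place, downstream code can
// branch on the new sparse types just as it already does for DenseTensorType
// or SelectedRowsType. Assumes the headers already included by op_type.cc
// plus <ostream>.
static void PrintSparseTensorDims(pir::Type type, std::ostream& os) {
  if (auto coo_type = type.dyn_cast<paddle::dialect::SparseCooTensorType>()) {
    // dyn_cast succeeds only when classof() matches SparseCooTensorType's id.
    for (auto d : common::vectorize(coo_type.dims())) {
      os << d << "x";
    }
    coo_type.dtype().Print(os);
  } else if (auto csr_type =
                 type.dyn_cast<paddle::dialect::SparseCsrTensorType>()) {
    for (auto d : common::vectorize(csr_type.dims())) {
      os << d << "x";
    }
    csr_type.dtype().Print(os);
  }
}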
+IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCsrTensorType) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_type.h b/paddle/fluid/pir/dialect/operator/ir/op_type.h index b06940d5b34d7..f2c078b016dd7 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_type.h +++ b/paddle/fluid/pir/dialect/operator/ir/op_type.h @@ -42,6 +42,14 @@ class TEST_API SelectedRowsType const phi::LoD &lod() const; const size_t &offset() const; + + /// + /// \brief Implementation of 'classof' that compares the type id of + /// the provided value with the concrete type id. + /// + static bool classof(Type type); + + static SelectedRowsType dyn_cast_impl(Type type); }; class DenseTensorArrayType @@ -56,6 +64,93 @@ class DenseTensorArrayType const phi::DDim &dims() const; const phi::DataLayout &data_layout() const; + + /// + /// \brief Implementation of 'classof' that compares the type id of + /// the provided value with the concrete type id. + /// + static bool classof(Type type); + + static DenseTensorArrayType dyn_cast_impl(Type type); +}; + +class IR_API SparseCooTensorType + : public pir::Type:: + TypeBase { + public: + using Base::Base; + + pir::Type dtype() const; + const common::DDim &dims() const; + const common::DDim &non_zero_dims() const; + common::DataLayout data_layout() const; + pir::DenseTensorType non_zero_indices() const; + pir::DenseTensorType non_zero_elements() const; + bool coalesced() const; + + /// + /// \brief Implementation of 'classof' that compares the type id of + /// the provided value with the concrete type id. + /// + static bool classof(pir::Type type); + + static SparseCooTensorType dyn_cast_impl(pir::Type type); + + static SparseCooTensorType get(pir::IrContext *ctx, + pir::Type dtype, + const common::DDim &dims, + const common::DDim &non_zero_dims, + common::DataLayout layout, + pir::DenseTensorType non_zero_indices, + pir::DenseTensorType non_zero_elements, + bool coalesced = false) { + return Base::get(ctx, + dtype, + dims, + non_zero_dims, + layout, + non_zero_indices, + non_zero_elements, + coalesced); + } +}; + +class IR_API SparseCsrTensorType + : public pir::Type:: + TypeBase { + public: + using Base::Base; + + pir::Type dtype() const; + const common::DDim &dims() const; + common::DataLayout data_layout() const; + pir::DenseTensorType non_zero_crows() const; + pir::DenseTensorType non_zero_cols() const; + pir::DenseTensorType non_zero_elements() const; + + /// + /// \brief Implementation of 'classof' that compares the type id of + /// the provided value with the concrete type id. 
+ /// + static bool classof(pir::Type type); + + static SparseCsrTensorType dyn_cast_impl(pir::Type type); + + static SparseCsrTensorType get(pir::IrContext *ctx, + pir::Type dtype, + const common::DDim &dims, + common::DataLayout layout, + pir::DenseTensorType non_zero_crows, + pir::DenseTensorType non_zero_cols, + pir::DenseTensorType non_zero_elements) { + return Base::get(ctx, + dtype, + dims, + layout, + non_zero_crows, + non_zero_cols, + non_zero_elements); + } }; } // namespace dialect @@ -63,3 +158,5 @@ class DenseTensorArrayType IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SelectedRowsType) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DenseTensorArrayType) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCooTensorType) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SparseCsrTensorType) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 5c163637450c3..4da4f54c3ac90 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -28,23 +28,8 @@ support_trans_dtype : x, y interfaces : paddle::dialect::InferSymbolicShapeInterface +# this add_n is only for ops_api_gen.py and onednn - op : add_n - args : (Tensor[] inputs) - output : Tensor - invoke : add_n_impl(inputs) - backward : add_n_grad - -- op : add_n_ - args : (Tensor[] inputs) - output : Tensor(out) - infer_meta: - func: AddNInferMeta - param: [inputs] - kernel: - func: add_n - param: [inputs] - -- op : add_n_with_kernel args : (Tensor[] inputs) output : Tensor(out) infer_meta: @@ -62,6 +47,17 @@ kernel : func : all +- op : all_reduce + args : (Tensor x, int ring_id = 0, int reduce_type = 0) + output : Tensor(out) + infer_meta : + func : AllReduceInferMeta + param: [x] + kernel : + func : all_reduce + param: [x, reduce_type] + inplace : (x -> out) + - op : amax args : (Tensor x, int64_t[] axis={}, bool keepdim=false) output : Tensor(out) @@ -122,6 +118,7 @@ param : [shape, dtype, values] backend: place> data_type : dtype + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : assign_value_ args : (Tensor output, int[] shape, DataType dtype, Scalar[] values, Place place = {}) @@ -135,6 +132,22 @@ param : [shape, dtype, values] data_type : dtype backend : place > output + interfaces : paddle::dialect::InferSymbolicShapeInterface + +- op : barrier + args : (Tensor x, int ring_id=0) + output : Tensor(out) + kernel : + func : barrier + +- op : batch_fc + args : (Tensor input, Tensor w, Tensor bias) + output : Tensor(out) + infer_meta: + func : BatchFCInferMeta + kernel : + func : batch_fc + data_type: input - op : batch_norm args : (Tensor x, Tensor mean, Tensor variance, Tensor scale, Tensor bias, bool is_test, float momentum, float epsilon, str data_format, bool use_global_stats, bool trainable_statistics) @@ -157,6 +170,16 @@ kernel : func : c_allgather +- op : c_allreduce_avg + args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel) + output : Tensor(out) + infer_meta : + func : AllReduceInferMeta + param : [x] + kernel : + func : c_allreduce_avg + inplace : (x -> out) + - op : c_allreduce_max args : (Tensor x, int ring_id, bool use_calc_stream, bool use_model_parallel) output : Tensor(out) @@ -237,6 +260,26 @@ func : c_identity inplace : (x -> out) +- op : c_reduce_avg + args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) + output : Tensor(out) + infer_meta : + func : DistReduceInferMeta + param : [x] + kernel : + func : c_reduce_avg + inplace : (x -> out) + +- op : 
c_reduce_max + args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) + output : Tensor(out) + infer_meta : + func : DistReduceInferMeta + param : [x] + kernel : + func : c_reduce_max + inplace : (x -> out) + - op : c_reduce_min args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) output : Tensor(out) @@ -247,6 +290,16 @@ func : c_reduce_min inplace : (x -> out) +- op : c_reduce_prod + args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) + output : Tensor(out) + infer_meta : + func : DistReduceInferMeta + param : [x] + kernel : + func : c_reduce_prod + inplace : (x -> out) + - op : c_reduce_sum args : (Tensor x, int ring_id, int root_id, bool use_calc_stream) output : Tensor(out) @@ -267,6 +320,24 @@ func : reduce_scatter param: [x, nranks] +- op : c_scatter + args : (Tensor x, int ring_id = 0, int root = 0, int nranks = 0, bool use_calc_stream = false) + output : Tensor(out) + infer_meta : + func : CScatterInferMeta + param : [x, nranks] + kernel : + func : c_scatter + +- op : c_split + args : (Tensor x, int rank = 0, int nranks = 1, int ring_id = 0, bool use_calc_stream = false, bool use_model_parallel = true) + output : Tensor(out) + infer_meta : + func : CSplitInferMeta + param : [x, nranks] + kernel : + func : c_split + - op : c_sync_calc_stream args : (Tensor x) output : Tensor(out) @@ -310,6 +381,16 @@ func : channel_shuffle backward : channel_shuffle_grad +- op : coalesce_tensor_ + args : (Tensor[] input, DataType dtype, bool copy_data = false, bool set_constant = false, bool persist_output = false, float constant = 0.0, bool use_align = true, int align_size = -1, int size_of_dtype = -1, int64_t[] concated_shapes = {}, int64_t[] concated_ranks = {}) + output : Tensor[](output){input.size()}, Tensor(fused_output) + infer_meta : + func : CoalesceTensorInferMeta + kernel : + func : coalesce_tensor + data_type : dtype + inplace: (input -> output) + - op : conv2d_transpose args : (Tensor x, Tensor filter, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") output : Tensor(out) @@ -320,6 +401,16 @@ data_type : x backward : conv2d_transpose_grad +- op : conv2d_transpose_bias + args : (Tensor x, Tensor filter, Tensor bias, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") + output : Tensor(out) + infer_meta : + func : Conv2dTransposeInferMeta + param: [x, filter, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format] + kernel : + func : conv2d_transpose_bias + data_type : x + - op : copy_to args : (Tensor x, Place place, bool blocking) output : Tensor(out) @@ -400,6 +491,16 @@ data_type : fpn_rois optional : rois_num, multi_level_rois_num +- op : distributed_fused_lamb_init + args : (Tensor[] param, Tensor[] grad, float beta1, float beta2, int[] apply_weight_decay, int alignment, int rank, int nranks) + output : Tensor(fp32_fused_param), Tensor(fp32_fused_grad), Tensor(fp16_fused_param), Tensor(fp16_fused_grad), Tensor(moment1), Tensor(moment2), Tensor(beta1_pow), Tensor(beta2_pow), Tensor(fused_param_offsets), Tensor(fp32_shard_fused_param_offsets), Tensor(fp16_shard_fused_param_offsets), Tensor(param_info), Tensor(param_order), Tensor[](param_out){param.size()}, Tensor[](master_param_out){param.size()}, 
Tensor[](grad_out){grad.size()}, Tensor(global_scale), Tensor(step) + infer_meta : + func : DistributedFusedLambInitInferMeta + kernel : + func : distributed_fused_lamb_init + optional : fp32_fused_param, fp32_fused_grad, fp16_fused_param, fp16_fused_grad + inplace: (param -> param_out), (grad -> grad_out) + - op : distributed_lookup_table args : (Tensor[] ids, Tensor w, int table_id = 0, bool is_distributed = false, str lookup_table_version = "lookup_table", int64_t padding_idx = -1, DataType dtype = DataType::FLOAT32, bool is_test = false) output : Tensor[](outputs){ids.size()} @@ -409,6 +510,15 @@ func : distributed_lookup_table data_type : dtype +- op : distributed_push_sparse + args : (Tensor[] ids, Tensor[] shows, Tensor[] clicks, int table_id = 0, int size = 8, bool is_distributed = false, str push_sparse_version = "push_sparse", int64_t padding_idx = -1, DataType dtype=DataType::FLOAT32, bool is_test = false, bool use_cvm_op = false) + output : Tensor[](output){ids.size()} + infer_meta : + func : DistributedPushSparseInferMeta + kernel : + func: distributed_push_sparse + data_type : dtype + - op : divide args : (Tensor x, Tensor y) output : Tensor(out) @@ -628,6 +738,7 @@ infer_meta : func : CreateLikeInferMeta param : [x, dtype] + spmd_rule : FullLikeInferSpmd kernel : func : full_like param : [x, value, dtype] @@ -655,7 +766,7 @@ kernel : func : fused_adam data_type : params - optional : skip_update, master_params + optional : skip_update, master_params, master_params_out inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) - op : fused_batch_norm_act @@ -682,6 +793,16 @@ view : (mean -> mean_out), (variance -> variance_out) backward : fused_bn_add_activation_grad +- op : fused_multi_transformer + args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw =true, int ring_id = -1) + optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs + output : Tensor[](cache_kv_outs){out_linear_weights.size()}, Tensor(out) + infer_meta : + func : FusedMultiTransformerInferMeta + kernel : + func : fused_multi_transformer + data_type : x + - op : fused_softmax_mask args : (Tensor x, Tensor mask) output : Tensor(out) @@ -701,6 +822,14 @@ func : fused_softmax_mask_upper_triangle backward: fused_softmax_mask_upper_triangle_grad +- op : fused_token_prune + args : (Tensor attn, Tensor x, Tensor mask, Tensor new_mask, bool keep_first_token = true, bool keep_order = false) + output : Tensor(slimmed_x), Tensor(cls_inds) + infer_meta : + func : FusedTokenPruneInferMeta + kernel: + func : fused_token_prune + - op : gaussian args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={}) output: Tensor(out) @@ -722,6 +851,15 @@ kernel: func: get_tensor_from_selected_rows {selected_rows -> 
dense} +- op : global_scatter + args : (Tensor x, Tensor local_count, Tensor global_count, int ring_id=0, bool use_calc_stream=false) + output : Tensor(out) + infer_meta : + func : GlobalScatterInferMeta + kernel : + func : global_scatter + data_type : x + - op : greater_equal args : (Tensor x, Tensor y) output : Tensor(out) @@ -801,6 +939,15 @@ inplace: (x -> out) interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : limit_by_capacity + args : (Tensor expert_count, Tensor capacity, int n_worker) + output : Tensor(out) + infer_meta : + func : LimitByCapacityInferMeta + kernel : + func : limit_by_capacity + data_type : expert_count + - op : linspace args : (Tensor start, Tensor stop, Tensor number, DataType dtype, Place place) output : Tensor(out) @@ -1008,6 +1155,16 @@ backward : multiply_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : nop + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : nop + inplace: (x -> out) + interfaces : paddle::dialect::ParseKernelKeyInterface + - op : norm args : (Tensor x, int axis, float epsilon, bool is_test) output : Tensor(out), Tensor(norm) @@ -1058,6 +1215,44 @@ backward : pad_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : partial_allgather + args : (Tensor x, int nranks, int rank, int ring_id = 0, bool use_calc_stream = false) + output : Tensor(out) + infer_meta : + func: PartialAllgatherInferMeta + kernel : + func : partial_allgather + inplace : (x -> out) + +- op : partial_concat + args : (Tensor[] x, int start_index = 0, int length = -1) + output : Tensor(out) + infer_meta : + func : PartialConcatInferMeta + kernel : + func : partial_concat + data_type : x + backward : partial_concat_grad + +- op : partial_recv + args : (int ring_id = 0, int peer = 0, DataType dtype=DataType::FLOAT32, int[] out_shape= {}, bool use_calc_stream = false, int num = 1, int id = 0) + output : Tensor(out) + infer_meta : + func: PartialRecvInferMeta + kernel : + func : partial_recv + data_type : dtype + +- op : partial_sum + args : (Tensor[] x, int start_index = 0, int length = -1) + output : Tensor(out) + infer_meta : + func : PartialSumInferMeta + kernel : + func : partial_sum + data_type : x + backward : partial_sum_grad + - op : pool2d args : (Tensor x, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) output : Tensor(out) @@ -1089,6 +1284,7 @@ kernel : func : print_kernel param: [in, first_n, message, summarize, print_tensor_name, print_tensor_type, print_tensor_shape, print_tensor_layout, print_tensor_lod, print_phase, is_forward] + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : prod args : (Tensor x, IntArray dims, bool keep_dim, bool reduce_all) @@ -1100,6 +1296,25 @@ backward : prod_grad interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : prune_gate_by_capacity + args : (Tensor gate_idx, Tensor expert_count, int64_t n_expert, int64_t n_worker) + output : Tensor(new_gate_idx) + infer_meta : + func : PruneGateByCapacityInferMeta + kernel : + func : prune_gate_by_capacity + data_type : gate_idx + +- op : push_dense + args : (Tensor[] ids, int table_id = -1, float scale_data_norm = -1.0f, str[] input_names = {}) + output : + infer_meta : + func : PushDenseInferMeta + param : [ids, table_id, scale_data_norm, input_names] + kernel : + func : push_dense + data_type : DataType::FLOAT32 + - op : 
push_sparse_v2 args : (Tensor[] ids, Tensor[] w, Tensor[] out_grad_in, int embeddingdim = 11, int tableid = 0, str accessorclass = "", str ctrlabelname = "", int paddingid = 0, bool scalesparsegrad = true, str[] inputnames = {}, bool is_distributed = true) output : Tensor[](out_grad_out){out_grad_in.size()} @@ -1137,6 +1352,15 @@ backend : place interfaces : paddle::dialect::InferSymbolicShapeInterface +- op : random_routing + args : (Tensor prob, Tensor topk_value, Tensor topk_idx) + output : Tensor(out) + infer_meta : + func : RandomRoutingInferMeta + kernel : + func : random_routing + data_type : dtype + - op : randperm args : (int n, DataType dtype, Place place={}) output : Tensor(out) @@ -1149,6 +1373,17 @@ data_type : dtype backend : place +- op : rank_attention + args : (Tensor x, Tensor rank_offset, Tensor rank_param, int max_rank = 3, int max_size = 0) + output : Tensor(input_help), Tensor(out), Tensor(ins_rank) + infer_meta : + func : RankAttentionInferMeta + kernel : + func : rank_attention + data_type : x + backward : rank_attention_grad + optional : ins_rank, input_help + - op : read_file args : (str filename = "", DataType dtype=DataType::UINT8, Place place=CPUPlace()) output : Tensor(out) @@ -1312,6 +1547,16 @@ func: shadow_feed param: [x] +- op : shadow_feed_tensors + args : (Tensor[] x) + output : Tensor[](out){x.size()} + infer_meta: + func: UnchangedVectorInferMeta + param: [x] + kernel: + func: shadow_feed_tensors + param: [x] + - op : share_data args : (Tensor x) output : Tensor(out) @@ -1362,6 +1607,7 @@ func : softmax inplace : (x -> out) backward : softmax_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : split args : (Tensor x, IntArray sections, Scalar(int) axis) @@ -1517,6 +1763,7 @@ func : triu inplace: (x -> out) backward : triu_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : triu_indices args : (int row, int col, int offset, DataType dtype, Place place={}) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml index 7b3068a8ab6c9..9ab68a7e52eb6 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml @@ -81,6 +81,17 @@ func : assign inplace : (out_grad -> x_grad) +- backward_op : batch_fc_grad + forward : batch_fc (Tensor input, Tensor w, Tensor bias) -> Tensor(out) + args : (Tensor input, Tensor w, Tensor bias, Tensor out_grad) + output : Tensor(input_grad), Tensor(w_grad), Tensor(bias_grad) + infer_meta : + func : BatchFCGradInferMeta + kernel : + func : batch_fc_grad + data_type : out_grad + no_need_buffer : bias + - backward_op : batch_norm_double_grad forward : batch_norm_grad (Tensor x, Tensor scale, Tensor bias, Tensor out_mean, Tensor out_variance, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor grad_out, float momentum, float epsilon, str data_format, bool is_test, bool use_global_stats, bool trainable_statistics) -> Tensor(grad_x), Tensor(grad_scale), Tensor(grad_bias) args : (Tensor x, Tensor scale, Tensor out_mean, Tensor out_variance, Tensor saved_mean, Tensor saved_variance, Tensor grad_out, Tensor grad_x_grad, Tensor grad_scale_grad, Tensor grad_bias_grad, float momentum, float epsilon, str data_format, bool is_test, bool use_global_stats, bool trainable_statistics) @@ -190,15 +201,15 @@ - backward_op : divide_double_grad forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), 
Tensor(grad_y) - args : (Tensor y, Tensor out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1) + args : (Tensor y, Tensor out, Tensor grad_out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1) output : Tensor(y_grad), Tensor(out_grad), Tensor(grad_out_grad) infer_meta : func : GeneralTernaryGradInferMeta - param : [y, grad_x, grad_x] + param : [y, out, out] kernel : func : divide_double_grad data_type : out - optional : grad_x_grad, grad_y_grad + optional : grad_x, grad_x_grad, grad_y_grad inplace : (grad_x_grad -> grad_out_grad) - backward_op : divide_grad @@ -580,6 +591,26 @@ composite : pad_grad(x, out_grad, paddings, pad_value, x_grad) backward : pad_double_grad +- backward_op : partial_concat_grad + forward : partial_concat (Tensor[] x, int start_index = 0, int length = -1) -> Tensor(out) + args : (Tensor[] x, Tensor out_grad, int start_index, int length) + output : Tensor[](x_grad){x.size()} + infer_meta : + func : PartialConcatGradInferMeta + param : [x] + kernel : + func : partial_concat_grad + +- backward_op : partial_sum_grad + forward : partial_sum (Tensor[] x, int start_index = 0, int length = -1) -> Tensor(out) + args : (Tensor[] x, Tensor out_grad, int start_index, int length) + output : Tensor[](x_grad){x.size()} + infer_meta : + func : PartialSumGradInferMeta + param : [x] + kernel : + func : partial_sum_grad + - backward_op : pool2d_double_grad forward : pool2d_grad(Tensor x, Tensor out, Tensor grad_out, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) -> Tensor(grad_x) args : (Tensor x, Tensor grad_x_grad, IntArray kernel_size, int[] strides, int[] paddings, bool ceil_mode, bool exclusive, str data_format, str pooling_type, bool global_pooling, bool adaptive, str padding_algorithm) @@ -626,6 +657,16 @@ func : prod_grad composite: prod_grad(x, out, out_grad, dims, keep_dim, reduce_all, x_grad) +- backward_op : rank_attention_grad + forward : rank_attention (Tensor x, Tensor rank_offset, Tensor rank_param, int max_rank = 3, int max_size = 0) -> Tensor(input_help), Tensor(out), Tensor(ins_rank) + args : (Tensor x, Tensor rank_offset, Tensor rank_param, Tensor input_help, Tensor ins_rank, Tensor out_grad, int max_rank = 3, int max_size = 0) + output : Tensor(rank_param_grad) + infer_meta : + func : RankAttentionGradInferMeta + kernel : + func : rank_attention_grad + data_type : out_grad + - backward_op : repeat_interleave_grad forward : repeat_interleave(Tensor x, int repeats, int axis) -> Tensor(out) args : (Tensor x, Tensor out_grad, int repeats, int axis) diff --git a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml index 5af2b7e13d0d8..f13b066d335be 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml @@ -15,7 +15,8 @@ - op : abs_grad -# - op : add_n +- op : add_n + extra_args : str mkldnn_data_type="float32" - op : batch_norm extra_args : bool fuse_with_relu=false @@ -51,6 +52,14 @@ extra_args : bool is_test=false data_format_tensors : input, out_grad +- op : conv2d_transpose + extra_args : bool is_test=false + data_format_tensors : x + +- op : conv2d_transpose_bias + extra_args : bool is_test=false, bool force_fp32_output = false, str mkldnn_data_type = "float32", bool fuse_relu = false, str fuse_activation = "", float fuse_alpha = 0.0f, float 
fuse_beta = 0.0f + data_format_tensors : x + - op : conv3d extra_args : bool is_test=false data_format_tensors : input @@ -61,9 +70,11 @@ - op : depthwise_conv2d extra_args : bool is_test=false, bool fuse_relu_before_depthwise_conv=false, bool use_quantizer=false, str mkldnn_data_type="float32", bool fuse_relu=false, str fuse_activation="", float fuse_alpha=0.0, float fuse_beta=0.0, bool use_addto=false, bool fuse_residual_connection=false, float scale_in=1.0, float scale_out=1.0, float scale_in_eltwise=1.0, float[] scale_weights={1.0f}, bool force_fp32_output=false + data_format_tensors : input - op : depthwise_conv2d_grad extra_args : bool is_test=false, bool fuse_relu_before_depthwise_conv=false, bool use_quantizer=false, str mkldnn_data_type="float32", bool fuse_relu=false, str fuse_activation="", float fuse_alpha=0.0, float fuse_beta=0.0, bool use_addto=false, bool fuse_residual_connection=false, float scale_in=1.0, float scale_out=1.0, float scale_in_eltwise=1.0, float[] scale_weights={1.0f}, bool force_fp32_output=false + data_format_tensors : input, out_grad - op : divide @@ -110,16 +121,19 @@ - op : fused_elementwise_sub -# - op : fused_matmul +- op : fused_matmul -# - op : fused_softplus +- op : fused_softplus -# - op : fused_transpose +- op : fused_transpose + extra_args : str data_format="AnyLayout", str mkldnn_data_type="float32" + data_format_tensors : x - op : fusion_gru extra_args : str mkldnn_data_type="float32", float scale_data=1.0, float shift_data=0.0, float[] scale_weights={1.0f} -# - op : fusion_lstm +- op : fusion_lstm + extra_args : str mkldnn_data_type="float32" - op : gaussian @@ -187,6 +201,7 @@ - op : multiply_grad - op : nearest_interp + data_format_tensors : x - op : pad @@ -234,9 +249,7 @@ - op : scale -- op : sgd - -# - op : sgd_dense_param_sparse_grad +- op : sgd_ - op : shape extra_args : str mkldnn_data_type="float32" @@ -247,9 +260,11 @@ - op : sigmoid_grad -# - op : slice +- op : slice + extra_args : str mkldnn_data_type="float32" -# - op : slice_grad +- op : slice_grad + extra_args : str mkldnn_data_type="float32" - op : softmax extra_args : str data_format="AnyLayout", str mkldnn_data_type="float32", bool is_test=false @@ -261,9 +276,10 @@ - op : softplus -# - op : split +- op : split + extra_args : str mkldnn_data_type="float32" -# - op : split_with_num +- op : split_with_num - op : sqrt @@ -275,7 +291,7 @@ - op : squeeze_grad extra_args : str mkldnn_data_type="float32" -# - op : stack +- op : stack - op : subtract @@ -297,6 +313,10 @@ - op : tanh_grad -# - op : transpose +- op : transpose + extra_args : str data_format="AnyLayout", str mkldnn_data_type="float32" + data_format_tensors : x -# - op : transpose_grad +- op : transpose_grad + extra_args : str data_format="AnyLayout", str mkldnn_data_type="float32" + data_format_tensors : out_grad diff --git a/paddle/fluid/pir/dialect/operator/ir/type_storage.h b/paddle/fluid/pir/dialect/operator/ir/type_storage.h index 375bef9799d6c..95b68a3370714 100644 --- a/paddle/fluid/pir/dialect/operator/ir/type_storage.h +++ b/paddle/fluid/pir/dialect/operator/ir/type_storage.h @@ -17,6 +17,7 @@ #include #include "paddle/phi/core/tensor_meta.h" +#include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/builtin_type_storage.h" #include "paddle/pir/include/core/type.h" #include "paddle/pir/include/core/type_base.h" @@ -166,5 +167,239 @@ struct DenseTensorArrayTypeStorage : public pir::TypeStorage { phi::DataLayout layout_; }; +struct SparseCooTensorTypeStorage : public pir::TypeStorage { + 
/// + /// \brief Declare ParamKey according to parameter type. + /// + using ParamKey = std::tuple; + SparseCooTensorTypeStorage(pir::Type dtype, + common::DDim dims, + common::DDim non_zero_dims, + common::DataLayout layout, + pir::DenseTensorType non_zero_indices, + pir::DenseTensorType non_zero_elements, + bool coalesced = false) + : dtype_(dtype), + dims_(dims), + non_zero_dims_(non_zero_dims), + layout_(layout), + non_zero_indices_(non_zero_indices), + non_zero_elements_(non_zero_elements), + coalesced_(coalesced) {} + + /// + /// \brief Each derived TypeStorage must define a Construct method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static SparseCooTensorTypeStorage* Construct(const ParamKey& key) { + return new SparseCooTensorTypeStorage(std::get<0>(key), + std::get<1>(key), + std::get<2>(key), + std::get<3>(key), + std::get<4>(key), + std::get<5>(key), + std::get<6>(key)); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. + /// + static std::size_t HashValue(const ParamKey& key) { + std::size_t hash_value = 0; + // hash dtype + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<0>(key))); + // hash dims + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<1>(key))); + // hash non_zero_dims + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<2>(key))); + // hash layout + hash_value = pir::detail::hash_combine( + hash_value, + std::hash::type>()( + static_cast::type>( + std::get<3>(key)))); + // hash DenseTensorType + auto tuple1 = std::make_tuple(std::get<4>(key).dtype(), + std::get<4>(key).dims(), + std::get<4>(key).data_layout(), + std::get<4>(key).lod(), + std::get<4>(key).offset()); + hash_value = pir::detail::hash_combine( + hash_value, DenseTensorTypeStorage::HashValue(tuple1)); + // hash DenseTensorType + auto tuple2 = std::make_tuple(std::get<5>(key).dtype(), + std::get<5>(key).dims(), + std::get<5>(key).data_layout(), + std::get<5>(key).lod(), + std::get<5>(key).offset()); + hash_value = pir::detail::hash_combine( + hash_value, DenseTensorTypeStorage::HashValue(tuple2)); + // hash coalesced + hash_value = pir::detail::hash_combine(hash_value, + std::hash()(std::get<6>(key))); + + return hash_value; + } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey& key) const { + return ParamKey(dtype_, + dims_, + non_zero_dims_, + layout_, + non_zero_indices_, + non_zero_elements_, + coalesced_) == key; + } + + ParamKey GetAsKey() const { + return ParamKey(dtype_, + dims_, + non_zero_dims_, + layout_, + non_zero_indices_, + non_zero_elements_, + coalesced_); + } + + /// + /// \brief SparseCooTensorTypeStorage include six parameters: dims, dtype, + /// layout, non_zero_indices_, non_zero_elements_,coalesced_. + /// + + pir::Type dtype_; + common::DDim dims_; + common::DDim non_zero_dims_; + common::DataLayout layout_{DataLayout::NCHW}; + pir::DenseTensorType non_zero_indices_; + pir::DenseTensorType non_zero_elements_; + bool coalesced_ = false; +}; + +struct SparseCsrTensorTypeStorage : public pir::TypeStorage { + /// + /// \brief Declare ParamKey according to parameter type. 
+ /// + using ParamKey = std::tuple; + SparseCsrTensorTypeStorage(pir::Type dtype, + common::DDim dims, + common::DataLayout layout, + pir::DenseTensorType non_zero_crows, + pir::DenseTensorType non_zero_cols, + pir::DenseTensorType non_zero_elements) + : dtype_(dtype), + dims_(dims), + layout_(layout), + non_zero_crows_(non_zero_crows), + non_zero_cols_(non_zero_cols), + non_zero_elements_(non_zero_elements) {} + + /// + /// \brief Each derived TypeStorage must define a Construct method, which + /// StorageManager uses to construct a derived TypeStorage. + /// + static SparseCsrTensorTypeStorage* Construct(const ParamKey& key) { + return new SparseCsrTensorTypeStorage(std::get<0>(key), + std::get<1>(key), + std::get<2>(key), + std::get<3>(key), + std::get<4>(key), + std::get<5>(key)); + } + + /// + /// \brief Each derived TypeStorage must provide a HashValue method. + /// + static std::size_t HashValue(const ParamKey& key) { + std::size_t hash_value = 0; + // hash dtype + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<0>(key))); + // hash dims + hash_value = pir::detail::hash_combine( + hash_value, std::hash()(std::get<1>(key))); + // hash layout + hash_value = pir::detail::hash_combine( + hash_value, + std::hash::type>()( + static_cast::type>( + std::get<2>(key)))); + // hash DenseTensorType + auto tuple1 = std::make_tuple(std::get<3>(key).dtype(), + std::get<3>(key).dims(), + std::get<3>(key).data_layout(), + std::get<3>(key).lod(), + std::get<3>(key).offset()); + hash_value = pir::detail::hash_combine( + hash_value, DenseTensorTypeStorage::HashValue(tuple1)); + // hash DenseTensorType + auto tuple2 = std::make_tuple(std::get<4>(key).dtype(), + std::get<4>(key).dims(), + std::get<4>(key).data_layout(), + std::get<4>(key).lod(), + std::get<4>(key).offset()); + hash_value = pir::detail::hash_combine( + hash_value, DenseTensorTypeStorage::HashValue(tuple2)); + // hash DenseTensorType + auto tuple3 = std::make_tuple(std::get<5>(key).dtype(), + std::get<5>(key).dims(), + std::get<5>(key).data_layout(), + std::get<5>(key).lod(), + std::get<5>(key).offset()); + hash_value = pir::detail::hash_combine( + hash_value, DenseTensorTypeStorage::HashValue(tuple3)); + return hash_value; + } + + /// + /// \brief Each derived TypeStorage needs to overload operator==. + /// + bool operator==(const ParamKey& key) const { + return ParamKey(dtype_, + dims_, + layout_, + non_zero_crows_, + non_zero_cols_, + non_zero_elements_) == key; + } + + ParamKey GetAsKey() const { + return ParamKey(dtype_, + dims_, + layout_, + non_zero_crows_, + non_zero_cols_, + non_zero_elements_); + } + + /// + /// \brief SparseCsrTensorTypeStorage include six parameters: dims, dtype, + /// layout, non_zero_crows_,non_zero_cols_,non_zero_elements_. 
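/// Every one of these fields also appears in ParamKey, HashValue and
/// operator== above; that is what lets the StorageManager unique the storage,
/// so two SparseCsrTensorTypes built from identical parameters resolve to the
/// same interned storage instance.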
+ /// + + pir::Type dtype_; + common::DDim dims_; + common::DataLayout layout_; + pir::DenseTensorType non_zero_crows_; + pir::DenseTensorType non_zero_cols_; + pir::DenseTensorType non_zero_elements_; +}; + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc index 7f84eac85bdb8..aeecd67bcf920 100644 --- a/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc +++ b/paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.cc @@ -153,8 +153,8 @@ std::unordered_map OpYamlInfoParser::GetInplaceIdMap() bool OpYamlInfoParser::HasView(const std::string& out_name) const { auto& view_info = std::get<3>(op_info_tuple_).view; - for (size_t i = 0; i < view_info.size(); i++) { - if (out_name == view_info[i].first) { + for (const auto& i : view_info) { + if (out_name == i.first) { return true; } } @@ -164,9 +164,9 @@ bool OpYamlInfoParser::HasView(const std::string& out_name) const { const std::string& OpYamlInfoParser::ViewName( const std::string& out_name) const { auto& view_info = std::get<3>(op_info_tuple_).view; - for (size_t i = 0; i < view_info.size(); i++) { - if (out_name == view_info[i].first) { - return view_info[i].second; + for (const auto& i : view_info) { + if (out_name == i.first) { + return i.second; } } PADDLE_THROW(phi::errors::PreconditionNotMet( @@ -232,7 +232,7 @@ int OpYamlInfoParser::GetTensorParamIndexByArgsName( kernel_fn_tensor_params_.end(), args_name); if (iter != kernel_fn_tensor_params_.end()) { - return std::distance(kernel_fn_tensor_params_.begin(), iter); + return std::distance(kernel_fn_tensor_params_.begin(), iter); // NOLINT } else { return -1; } diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.cc b/paddle/fluid/pir/dialect/operator/utils/utils.cc index 9b450977814b6..f9b6658e4c716 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.cc +++ b/paddle/fluid/pir/dialect/operator/utils/utils.cc @@ -37,19 +37,25 @@ namespace dialect { const std::unordered_set LegacyOpList = { LoadCombineOp::name(), + BatchFcOp::name(), + BatchFcGradOp::name(), CConcatOp::name(), CBroadcast_Op::name(), CSyncCalcStream_Op::name(), CSyncCommStream_Op::name(), + DistributedPushSparseOp::name(), FtrlOp::name(), FusedElemwiseAddActivationOp::name(), FusedElemwiseAddActivationGradOp::name(), + FusedTokenPruneOp::name(), DpsgdOp::name(), SendV2Op::name(), RecvV2Op::name(), CAllreduceProd_Op::name(), CAllreduceSumOp::name(), CAllreduceSum_Op::name(), + CAllreduceAvgOp::name(), + CAllreduceAvg_Op::name(), CReduceSumOp::name(), CReduceSum_Op::name(), CAllreduceMax_Op::name(), @@ -57,19 +63,27 @@ const std::unordered_set LegacyOpList = { CAllgatherOp::name(), CSoftmaxWithCrossEntropyOp::name(), CSoftmaxWithCrossEntropyGradOp::name(), + CSplitOp::name(), + PushDenseOp::name(), SeedOp::name(), ShareDataOp::name(), SparseMomentumOp::name(), GetTensorFromSelectedRowsOp::name(), TdmSamplerOp::name(), + RankAttentionOp::name(), + RankAttentionGradOp::name(), RowConvOp::name(), RowConvGradOp::name(), SoftReluOp::name(), SoftReluGradOp::name(), MatchMatrixTensorOp::name(), MatchMatrixTensorGradOp::name(), + PartialConcatOp::name(), + PartialConcatGradOp::name(), NceOp::name(), NceGradOp::name(), + PartialSumOp::name(), + PartialSumGradOp::name(), LrnOp::name(), LrnGradOp::name(), MovingAverageAbsMaxScaleOp::name(), @@ -84,10 +98,17 @@ const std::unordered_set LegacyOpList = { paddle::onednn::dialect::QuantizeOp::name(), 
paddle::onednn::dialect::RequantizeOp::name(), paddle::onednn::dialect::MultiGruOp::name(), + paddle::onednn::dialect::FusionLstmOp::name(), #endif + CReduceAvgOp::name(), + CReduceAvg_Op::name(), + CReduceMaxOp::name(), CReduceMinOp::name(), + CReduceProdOp::name(), + CScatterOp::name(), PushSparseV2Op::name(), - PartialSendOp::name()}; + PartialSendOp::name(), + PartialRecvOp::name()}; enum class AttrType { UNDEFINED = 0, @@ -139,123 +160,124 @@ static inline AttrType GetAttributeType(const pir::Attribute& attr) { } } -static std::unordered_map< - AttrType, - std::function> - kAttrCastMap = { - {AttrType::BOOL, - [](const pir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; - }}, - {AttrType::FLOAT, - [](const pir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; - }}, - {AttrType::DOUBLE, - [](const pir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; - }}, - {AttrType::INT32, - [](const pir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; - }}, - {AttrType::INT64, - [](const pir::Attribute& attr) { - return VariantType{attr.dyn_cast().data()}; - }}, - {AttrType::INT_ARRAY, - [](const pir::Attribute& attr) { - return VariantType{ - attr.dyn_cast() - .data() - .GetData()}; - }}, - {AttrType::STRING, - [](const pir::Attribute& attr) { - return VariantType{attr.dyn_cast().AsString()}; - }}, - {AttrType::DATA_TYPE, - [](const pir::Attribute& attr) { - return VariantType{ - attr.dyn_cast().data()}; - }}, - {AttrType::PLACE, - [](const pir::Attribute& attr) { - return VariantType{ - attr.dyn_cast().data()}; - }}, - {AttrType::ARRAY, - [](const pir::Attribute& attr) { - auto attr_vec = attr.dyn_cast().AsVector(); - if (attr_vec.empty()) { - return VariantType{std::vector()}; - } - AttrType element_type = GetAttributeType(attr_vec[0]); - - if (element_type == AttrType::BOOL) { - std::vector vec_bools; - vec_bools.reserve(attr_vec.size()); - for (auto vec_element : attr_vec) { - vec_bools.push_back( - vec_element.dyn_cast().data()); +template +static std::function GetAttrCast( + AttrType attr_type) { + std::unordered_map> + kAttrCastMap = { + {AttrType::BOOL, + [](const pir::Attribute& attr) { + return T{attr.dyn_cast().data()}; + }}, + {AttrType::FLOAT, + [](const pir::Attribute& attr) { + return T{attr.dyn_cast().data()}; + }}, + {AttrType::DOUBLE, + [](const pir::Attribute& attr) { + return T{attr.dyn_cast().data()}; + }}, + {AttrType::INT32, + [](const pir::Attribute& attr) { + return T{attr.dyn_cast().data()}; + }}, + {AttrType::INT64, + [](const pir::Attribute& attr) { + return T{attr.dyn_cast().data()}; + }}, + {AttrType::INT_ARRAY, + [](const pir::Attribute& attr) { + return T{attr.dyn_cast() + .data() + .GetData()}; + }}, + {AttrType::STRING, + [](const pir::Attribute& attr) { + return T{attr.dyn_cast().AsString()}; + }}, + {AttrType::DATA_TYPE, + [](const pir::Attribute& attr) { + return T{ + attr.dyn_cast().data()}; + }}, + {AttrType::PLACE, + [](const pir::Attribute& attr) { + return T{attr.dyn_cast().data()}; + }}, + {AttrType::ARRAY, + [](const pir::Attribute& attr) { + auto attr_vec = attr.dyn_cast().AsVector(); + if (attr_vec.empty()) { + return T{std::vector()}; } - return VariantType{vec_bools}; - } else if (element_type == AttrType::INT32) { - std::vector vec_int32; - vec_int32.reserve(attr_vec.size()); - for (auto vec_element : attr_vec) { - vec_int32.push_back( - vec_element.dyn_cast().data()); + AttrType element_type = GetAttributeType(attr_vec[0]); + + if (element_type == AttrType::BOOL) { + std::vector 
vec_bools; + vec_bools.reserve(attr_vec.size()); + for (auto vec_element : attr_vec) { + vec_bools.push_back( + vec_element.dyn_cast().data()); + } + return T{vec_bools}; + } else if (element_type == AttrType::INT32) { + std::vector vec_int32; + vec_int32.reserve(attr_vec.size()); + for (auto vec_element : attr_vec) { + vec_int32.push_back( + vec_element.dyn_cast().data()); + } + return T{vec_int32}; + } else if (element_type == AttrType::INT64) { + std::vector vec_int64; + vec_int64.reserve(attr_vec.size()); + for (auto vec_element : attr_vec) { + vec_int64.push_back( + vec_element.dyn_cast().data()); + } + return T{vec_int64}; + } else if (element_type == AttrType::FLOAT) { + std::vector vec_float; + vec_float.reserve(attr_vec.size()); + for (auto vec_element : attr_vec) { + vec_float.push_back( + vec_element.dyn_cast().data()); + } + return T{vec_float}; + } else if (element_type == AttrType::DOUBLE) { + std::vector vec_double; + vec_double.reserve(attr_vec.size()); + for (auto vec_element : attr_vec) { + vec_double.push_back( + vec_element.dyn_cast().data()); + } + return T{vec_double}; + } else if (element_type == AttrType::STRING) { + std::vector vec_string; + vec_string.reserve(attr_vec.size()); + for (auto vec_element : attr_vec) { + vec_string.push_back( + vec_element.dyn_cast().AsString()); + } + return T{vec_string}; + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Unsupported ir Attribute type when casting it into " + "vector.")); } - return VariantType{vec_int32}; - } else if (element_type == AttrType::INT64) { - std::vector vec_int64; - vec_int64.reserve(attr_vec.size()); - for (auto vec_element : attr_vec) { - vec_int64.push_back( - vec_element.dyn_cast().data()); - } - return VariantType{vec_int64}; - } else if (element_type == AttrType::FLOAT) { - std::vector vec_float; - vec_float.reserve(attr_vec.size()); - for (auto vec_element : attr_vec) { - vec_float.push_back( - vec_element.dyn_cast().data()); - } - return VariantType{vec_float}; - } else if (element_type == AttrType::DOUBLE) { - std::vector vec_double; - vec_double.reserve(attr_vec.size()); - for (auto vec_element : attr_vec) { - vec_double.push_back( - vec_element.dyn_cast().data()); - } - return VariantType{vec_double}; - } else if (element_type == AttrType::STRING) { - std::vector vec_string; - vec_string.reserve(attr_vec.size()); - for (auto vec_element : attr_vec) { - vec_string.push_back( - vec_element.dyn_cast().AsString()); - } - return VariantType{vec_string}; - } else { - PADDLE_THROW(phi::errors::Unimplemented( - "Unsupported ir Attribute type when casting it into " - "vector.")); - } - }}, -}; + }}, + }; + return kAttrCastMap[attr_type]; +} VariantType GetAttributeData(const pir::Attribute& attr) { AttrType attr_type = GetAttributeType(attr); - return kAttrCastMap[attr_type](attr); + return GetAttrCast(attr_type)(attr); } paddle::any TransAttrToAny(const pir::Attribute& attr) { AttrType attr_type = GetAttributeType(attr); - return kAttrCastMap[attr_type](attr); + return GetAttrCast(attr_type)(attr); } bool IsLegacyOp(const std::string& name) { return LegacyOpList.count(name); } @@ -302,7 +324,9 @@ std::set GetRegisterDataType(const std::string& op_name) { data_type.insert(phi::DataTypeToString(info_pair.first.dtype())); } } - + if (data_type.empty()) { + VLOG(6) << "No data type is registered for " << op_name; + } return data_type; } @@ -323,16 +347,6 @@ phi::DataType GetValueDataType(const pir::Type& type) { } else { return phi::DataType::UNDEFINED; } - } else if (type.isa()) { - return 
dialect::TransToPhiDataType( - type.dyn_cast().dtype()); - } else if (type.isa()) { - return dialect::TransToPhiDataType( - type.dyn_cast().dtype()); - } else if (type.isa()) { - return dialect::TransToPhiDataType( - type.dyn_cast() - .dtype()); } else { PADDLE_THROW( phi::errors::InvalidType("Currently, we can only get dtype for " @@ -344,43 +358,7 @@ phi::DataType GetValueDataType(const pir::Value& value) { if (value.impl() == nullptr) { return phi::DataType::UNDEFINED; } - if (value.type().isa()) { - return dialect::TransToPhiDataType( - value.type().dyn_cast().dtype()); - } else if (value.type().isa()) { - return dialect::TransToPhiDataType( - value.type().dyn_cast().dtype()); - } else if (value.type().isa()) { - return dialect::TransToPhiDataType( - value.type().dyn_cast().dtype()); - } else if (value.type().isa()) { - auto vec_value = value.type().dyn_cast(); - if (vec_value.size() > 0) { - return GetValueDataType(vec_value[0]); - } else { - return phi::DataType::UNDEFINED; - } - } else if (value.type().isa()) { - return dialect::TransToPhiDataType( - value.type() - .dyn_cast() - .dtype()); - } else if (value.type().isa()) { - return dialect::TransToPhiDataType( - value.type() - .dyn_cast() - .dtype()); - } else if (value.type() - .isa()) { - return dialect::TransToPhiDataType( - value.type() - .dyn_cast() - .dtype()); - } else { - PADDLE_THROW( - phi::errors::InvalidType("Currently, we can only get dtype for " - "DenseTensorType and SelectedRowsType.")); - } + return GetValueDataType(value.type()); } void DoValueCheck(const pir::Value& value, @@ -512,17 +490,6 @@ std::vector ParseValueShape(const pir::Value& shape, } vec_shape = std::vector(shape_size, -1); *is_from_tensor = true; - } else if (shape.type().isa()) { - common::DDim shape_dim = - shape.type() - .dyn_cast() - .dims(); - size_t shape_size = common::product(shape_dim); - if (common::contain_unknown_dim(shape_dim)) { - shape_size = 1; - } - vec_shape = std::vector(shape_size, -1); - *is_from_tensor = true; } else { PADDLE_THROW( phi::errors::Unimplemented("Only support VectorType or DenseTensorType " @@ -531,5 +498,69 @@ std::vector ParseValueShape(const pir::Value& shape, return vec_shape; } +const std::unordered_map& CppTypeToAttrTypeMap() { + static const std::unordered_map attr_type_map = { + {"bool", "pir::BoolAttribute"}, + {"int", "pir::Int32Attribute"}, + {"float", "pir::FloatAttribute"}, + {"int64_t", "pir::Int64Attribute"}, + {"std::string", "pir::StrAttribute"}, + {"std::vector", "pir::ArrayAttribute"}, + {"std::vector", "pir::ArrayAttribute"}, + {"std::vector", "pir::ArrayAttribute"}, + {"std::vector", "pir::ArrayAttribute"}}; + return attr_type_map; +} + +const std::unordered_map& StringToDataTypeMap() { + static std::unordered_map data_type_map{ + {"bool", phi::DataType::BOOL}, + {"uint8", phi::DataType::UINT8}, + {"int8", phi::DataType::INT8}, + {"uint16", phi::DataType::UINT16}, + {"int16", phi::DataType::INT16}, + {"uint32", phi::DataType::UINT32}, + {"int32", phi::DataType::INT32}, + {"uint64", phi::DataType::UINT64}, + {"int64", phi::DataType::INT64}, + {"float32", phi::DataType::FLOAT32}, + {"complex64", phi::DataType::COMPLEX64}, + {"complex128", phi::DataType::COMPLEX128}, + {"Undefined", phi::DataType::UNDEFINED}, + {"psting", phi::DataType::PSTRING}, + {"float16", phi::DataType::FLOAT16}, + {"bfloat16", phi::DataType::BFLOAT16}, + {"float64", phi::DataType::FLOAT64}}; + return data_type_map; +} + +const std::unordered_map& StringToPlaceMap() { + static std::unordered_map place_map{ + {"cpu", 
phi::CPUPlace{}}, + {"gpu", phi::GPUPlace{}}, + {"gpu_pinned", phi::GPUPinnedPlace{}}, + {"xpu", phi::XPUPlace{}}, + {"ipu", phi::IPUPlace{}}, + {":", phi::CustomPlace{}}, + {"undefined", phi::Place{}}}; + return place_map; +} + +const std::unordered_map& +StringToDataLayoutMap() { + static std::unordered_map data_layout_map{ + {"NHWC", phi::DataLayout::kNHWC}, + {"NCHW", phi::DataLayout::kNCHW}, + {"Undefined", phi::DataLayout::kAnyLayout}, + {"ONEDNN", phi::DataLayout::ONEDNN}, + {"SPARSE_COO", phi::DataLayout::SPARSE_COO}, + {"SPARSE_CSR", phi::DataLayout::SPARSE_CSR}, + {"NDHWC", phi::DataLayout::kNDHWC}, + {"NCDHW", phi::DataLayout::kNCDHW}, + {"PSTRING_UNION", phi::DataLayout::PSTRING_UNION}, + {"STRIDED", phi::DataLayout::STRIDED}}; + return data_layout_map; +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/utils/utils.h b/paddle/fluid/pir/dialect/operator/utils/utils.h index fd8ec68401b08..9402458477319 100644 --- a/paddle/fluid/pir/dialect/operator/utils/utils.h +++ b/paddle/fluid/pir/dialect/operator/utils/utils.h @@ -167,5 +167,13 @@ phi::DataType GetValueDataType(const pir::Value& value); std::vector ParseValueShape(const pir::Value& shape_, bool* is_from_tensor); +const std::unordered_map& CppTypeToAttrTypeMap(); + +const std::unordered_map& StringToDataTypeMap(); + +const std::unordered_map& StringToPlaceMap(); + +const std::unordered_map& StringToDataLayoutMap(); + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/drr/CMakeLists.txt b/paddle/fluid/pir/drr/CMakeLists.txt index 512e3927004e4..b23774a431795 100644 --- a/paddle/fluid/pir/drr/CMakeLists.txt +++ b/paddle/fluid/pir/drr/CMakeLists.txt @@ -54,7 +54,7 @@ add_custom_command( set(DRR_SRCS ${DRR_SRCS} ${pd_op_creator_file}) -if(WITH_CINN AND NOT CINN_ONLY) +if(WITH_CINN) set(cinn_op_yaml_file ${PADDLE_BINARY_DIR}/paddle/cinn/hlir/dialect/generated/ops.parsed.yaml) @@ -128,4 +128,4 @@ endif() cc_library( drr SRCS ${DRR_SRCS} - DEPS op_dialect_vjp ${CINN_DEPS} pir) + DEPS op_dialect_vjp ${CINN_DEPS} pir pir_general_functions) diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md index 1c5de89780c6f..d9b435160c41d 100644 --- a/paddle/fluid/pir/drr/README.md +++ b/paddle/fluid/pir/drr/README.md @@ -9,9 +9,9 @@ DRR can reduce the development cost of PASS, allowing developers to focus on pro Taking PASS to eliminate redundant CastOp as an example, the code example developed using DRR is as follows: ~~~ c++ // 1. Inherit class from DrPatternBase -class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase { public: - std::string name() const override { return "RemoveRedundentCastPattern"; } + std::string name() const override { return "RemoveRedundantCastPattern"; } // 2. Overload operator() void operator()(paddle::drr::DrrPatternContext *ctx) const override { diff --git a/paddle/fluid/pir/drr/README_cn.md b/paddle/fluid/pir/drr/README_cn.md index e621e7112ac30..c01b21febeda3 100644 --- a/paddle/fluid/pir/drr/README_cn.md +++ b/paddle/fluid/pir/drr/README_cn.md @@ -9,9 +9,9 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P 以消除冗余 CastOp 的 PASS 为例,使用 DRR 的代码开发示例如下: ~~~ c++ // 1. 
继承 DrrPatternBase 类 -class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase { public: - std::string name() const override { return "RemoveRedundentCastPattern"; } + std::string name() const override { return "RemoveRedundantCastPattern"; } // 2. 重载 operator() void operator()(paddle::drr::DrrPatternContext *ctx) const override { diff --git a/paddle/fluid/pir/drr/include/drr_pattern_context.h b/paddle/fluid/pir/drr/include/drr_pattern_context.h index af70dee24b8d4..b7755f659e85d 100644 --- a/paddle/fluid/pir/drr/include/drr_pattern_context.h +++ b/paddle/fluid/pir/drr/include/drr_pattern_context.h @@ -101,12 +101,12 @@ class Constraint { ConstraintFunction IsContextMatchConstraint_; }; -class DrrPatternContext { +class TEST_API DrrPatternContext { public: DrrPatternContext(); ~DrrPatternContext() = default; - TEST_API drr::SourcePattern SourcePattern(); + drr::SourcePattern SourcePattern(); std::shared_ptr source_pattern_graph() const { return source_pattern_graph_; @@ -122,20 +122,19 @@ class DrrPatternContext { friend class drr::SourcePattern; friend class drr::ResultPattern; - TEST_API const Op& SourceOpPattern( + const Op& SourceOpPattern( const std::string& op_type, const std::unordered_map& attributes = {}); - TEST_API const drr::Tensor& SourceTensorPattern(const std::string& name); + drr::Tensor& SourceTensorPattern(const std::string& name); - TEST_API const Op& ResultOpPattern( + const Op& ResultOpPattern( const std::string& op_type, const std::unordered_map& attributes = {}); - TEST_API drr::Tensor& ResultTensorPattern(const std::string& name); + drr::Tensor& ResultTensorPattern(const std::string& name); // void RequireEqual(const Attribute& first, const Attribute& second); void RequireEqual(const TensorShape& first, const TensorShape& second); - TEST_API void RequireEqual(const TensorDataType& first, - const TensorDataType& second); + void RequireEqual(const TensorDataType& first, const TensorDataType& second); void RequireNativeCall(const ConstraintFunction& custom_fn); std::shared_ptr source_pattern_graph_; @@ -147,17 +146,15 @@ class DrrPatternContext { class Op { public: - const std::string& name() const { return op_type_name_; } - - TEST_API void operator()(const Tensor& arg, const Tensor* out) const; + TEST_API const std::string& name() const { return op_type_name_; } TEST_API Tensor& operator()() const; - + TEST_API void operator()(const Tensor& arg, const Tensor* out) const; TEST_API Tensor& operator()(const Tensor& arg) const; TEST_API Tensor& operator()(const Tensor& arg0, const Tensor& arg1) const; - Tensor& operator()(const Tensor& arg0, - const Tensor& arg1, - const Tensor& arg2) const; + TEST_API Tensor& operator()(const Tensor& arg0, + const Tensor& arg1, + const Tensor& arg2) const; TEST_API void operator()(const std::vector& args, const std::vector& outputs) const; // const Tensor& operator()(const Tensor& arg0, const Tensor& arg1, const @@ -169,9 +166,6 @@ class Op { static const char* prefix; private: - friend class DrrPatternContext; - friend class OpCall; - Op(const std::string& op_type_name, const std::unordered_map& attributes, PatternGraph* pattern_graph) @@ -183,29 +177,37 @@ class Op { return attributes_; } - thread_local static int64_t count; + friend class DrrPatternContext; + friend class OpCall; std::string op_type_name_; std::unordered_map attributes_; PatternGraph* pattern_graph_{nullptr}; + + thread_local static int64_t count; }; -class Tensor { +class TEST_API 
Tensor { public: - static const char INPUT_NONE_TENSOR_NAME[]; - static const char OUTPUT_NONE_TENSOR_NAME[]; + static const char RESULT_INPUT_NONE_TENSOR_NAME[]; + static const char RESULT_OUTPUT_NONE_TENSOR_NAME[]; + static const char SOURCE_INPUT_NONE_TENSOR_NAME[]; + static const char SOURCE_OUTPUT_NONE_TENSOR_NAME[]; TensorShape shape() const { return TensorShape(name()); } TensorDataType dtype() const { return TensorDataType(name()); } bool is_none() const { - return name_ == INPUT_NONE_TENSOR_NAME || name_ == OUTPUT_NONE_TENSOR_NAME; + return name_ == RESULT_INPUT_NONE_TENSOR_NAME || + name_ == RESULT_OUTPUT_NONE_TENSOR_NAME || + name_ == SOURCE_INPUT_NONE_TENSOR_NAME || + name_ == SOURCE_OUTPUT_NONE_TENSOR_NAME; } - TEST_API void Assign(const Tensor& other); + void Assign(const Tensor& other); - TEST_API void operator=(const Tensor& other) const; // NOLINT + void operator=(const Tensor& other) const; // NOLINT const std::string& name() const { return name_; } @@ -215,24 +217,26 @@ class Tensor { void set_producer(OpCall* producer) { producer_ = producer; } - const std::vector& consumers() const { return consumers_; } + const std::unordered_set& consumers() const { + return consumers_; + } - void AddConsumer(const OpCall* consumer) { consumers_.push_back(consumer); } + void AddConsumer(const OpCall* consumer) { consumers_.insert(consumer); } private: - friend class DrrPatternContext; - friend class Op; - Tensor(const std::string& name, PatternGraph* pattern_graph) : name_(name), pattern_graph_(pattern_graph) {} + friend class DrrPatternContext; + friend class Op; + std::string name_; OpCall* producer_{nullptr}; - std::vector consumers_; + std::unordered_set consumers_; PatternGraph* pattern_graph_{nullptr}; }; -class OpCall { +class TEST_API OpCall { public: OpCall(const Op* op, const std::vector& inputs, @@ -259,17 +263,13 @@ class OpCall { std::unordered_map attributes_; }; -class ResultPattern { +class TEST_API ResultPattern { public: const drr::Op& Op( const std::string& op_type, - const std::unordered_map& attributes = {}) { - return ctx_->ResultOpPattern(op_type, attributes); - } + const std::unordered_map& attributes = {}); - drr::Tensor& Tensor(const std::string& name) { - return ctx_->ResultTensorPattern(name); - } + drr::Tensor& Tensor(const std::string& name); // Represent the input tensor which is none. // Example: @@ -278,9 +278,7 @@ class ResultPattern { // When scale is none, we can write a instance_norm op in drr as follow: // res.Op("instance_norm")(res.Tensor("x"), res.InputNoneTensor(), // res.Tensor("bias")); - drr::Tensor& InputNoneTensor() { - return ctx_->ResultTensorPattern(Tensor::INPUT_NONE_TENSOR_NAME); - } + drr::Tensor& InputNoneTensor(); // Represent the output tensor which is none. // Example: @@ -288,59 +286,31 @@ class ResultPattern { // it may be none). 
We can write a reshape op in drr as follow: // res.Op("reshape")({res.Tensor("x")}, {res.Tensor("out"), // res.OutputNoneTensor()}); - drr::Tensor& OutputNoneTensor() { - return ctx_->ResultTensorPattern(Tensor::OUTPUT_NONE_TENSOR_NAME); - } + drr::Tensor& OutputNoneTensor(); - Attribute StrAttr(const std::string& value) const { - return ComputeAttr( - [=](const MatchContext& match_ctx) -> std::string { return value; }); - } + Attribute StrAttr(const std::string& value) const; - Attribute BoolAttr(bool value) const { - return ComputeAttr( - [=](const MatchContext& match_ctx) -> bool { return value; }); - } + Attribute BoolAttr(bool value) const; - Attribute Int32Attr(int32_t value) const { - return ComputeAttr( - [=](const MatchContext& match_ctx) -> int32_t { return value; }); - } + Attribute Int32Attr(int32_t value) const; - Attribute Int64Attr(int64_t value) const { - return ComputeAttr( - [=](const MatchContext& match_ctx) -> int64_t { return value; }); - } + Attribute Int64Attr(int64_t value) const; - Attribute Float32Attr(float value) const { - return ComputeAttr( - [=](const MatchContext& match_ctx) -> float { return value; }); - } + Attribute Float32Attr(float value) const; - Attribute VectorInt64Attr(const std::vector& value) const { - return ComputeAttr( - [=](const MatchContext& match_ctx) -> std::vector { - return value; - }); - } + Attribute VectorInt64Attr(const std::vector& value) const; - Attribute VectorInt32Attr(const std::vector& value) const { - return ComputeAttr( - [=](const MatchContext& match_ctx) -> std::vector { - return value; - }); - } + Attribute VectorInt32Attr(const std::vector& value) const; - Attribute VectorFloatAttr(const std::vector& value) const { - return ComputeAttr( - [=](const MatchContext& match_ctx) -> std::vector { - return value; - }); - } + Attribute VectorFloatAttr(const std::vector& value) const; - Attribute ComputeAttr(const AttrComputeFunc& attr_compute_func) const { - return ComputeAttribute(attr_compute_func); - } + Attribute DataTypeAttr(const std::string& value) const; + + Attribute PlaceAttr(const std::string& value) const; + + Attribute DataLayoutAttr(const std::string& value) const; + + Attribute ComputeAttr(const AttrComputeFunc& attr_compute_func) const; private: friend class SourcePattern; @@ -350,34 +320,29 @@ class ResultPattern { DrrPatternContext* ctx_{nullptr}; }; -class SourcePattern { +class TEST_API SourcePattern { public: - drr::ResultPattern ResultPattern() const { return drr::ResultPattern(ctx_); } + drr::ResultPattern ResultPattern() const; const drr::Op& Op( const std::string& op_type, - const std::unordered_map& attributes = {}) { - return ctx_->SourceOpPattern(op_type, attributes); - } + const std::unordered_map& attributes = {}); - const drr::Tensor& Tensor(const std::string& name) { - return ctx_->SourceTensorPattern(name); - } + const drr::Tensor& Tensor(const std::string& name); - Attribute Attr(const std::string& attr_name) const { - return NormalAttribute(attr_name); - } + Attribute Attr(const std::string& attr_name) const; - void RequireEqual(const TensorShape& first, const TensorShape& second) { - ctx_->RequireEqual(first, second); - } - void RequireEqual(const TensorDataType& first, const TensorDataType& second) { - ctx_->RequireEqual(first, second); - } + void RequireEqual(const TensorShape& first, const TensorShape& second); - void RequireNativeCall(const ConstraintFunction& custom_fn) { - ctx_->RequireNativeCall(custom_fn); - } + void RequireEqual(const TensorDataType& first, const TensorDataType& 
second); + + void RequireNativeCall(const ConstraintFunction& custom_fn); + + // Same as a ResultPattern::InputNoneTensor + drr::Tensor& InputNoneTensor(); + + // Same as a ResultPattern::OutputNoneTensor + drr::Tensor& OutputNoneTensor(); private: friend class DrrPatternContext; diff --git a/paddle/fluid/pir/drr/src/attr_type_uilts.h b/paddle/fluid/pir/drr/src/attr_type_uilts.h index 02f5a4defc155..a6b08b8054195 100644 --- a/paddle/fluid/pir/drr/src/attr_type_uilts.h +++ b/paddle/fluid/pir/drr/src/attr_type_uilts.h @@ -37,18 +37,20 @@ PD_SPECIALIZE_CppTypeToIrAttribute(int32_t, pir::Int32Attribute); PD_SPECIALIZE_CppTypeToIrAttribute(int64_t, pir::Int64Attribute); PD_SPECIALIZE_CppTypeToIrAttribute(float, pir::FloatAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(std::string, pir::StrAttribute); -PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType, - paddle::dialect::DataTypeAttribute); -PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, pir::ArrayAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, paddle::dialect::IntArrayAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(std::vector, pir::ArrayAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataType, + paddle::dialect::DataTypeAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute); +PD_SPECIALIZE_CppTypeToIrAttribute(phi::DataLayout, + paddle::dialect::DataLayoutAttribute); PD_SPECIALIZE_CppTypeToIrAttribute(phi::IntArray, paddle::dialect::IntArrayAttribute); template -struct IrAttrbuteCreator { +struct IrAttributeCreator { typename CppTypeToIrAttribute::type operator()(T obj) const { return CppTypeToIrAttribute::type::template get( pir::IrContext::Instance(), obj); @@ -56,7 +58,7 @@ struct IrAttrbuteCreator { }; template <> -struct IrAttrbuteCreator> { +struct IrAttributeCreator> { pir::ArrayAttribute operator()(std::vector obj) const { std::vector attr_vec; attr_vec.reserve(obj.size()); @@ -69,7 +71,7 @@ struct IrAttrbuteCreator> { }; template <> -struct IrAttrbuteCreator> { +struct IrAttributeCreator> { pir::ArrayAttribute operator()(std::vector obj) const { std::vector attr_vec; attr_vec.reserve(obj.size()); diff --git a/paddle/fluid/pir/drr/src/ir_operation_factory.cc b/paddle/fluid/pir/drr/src/ir_operation_factory.cc index f792ccbdaff92..e625db38d1b8f 100644 --- a/paddle/fluid/pir/drr/src/ir_operation_factory.cc +++ b/paddle/fluid/pir/drr/src/ir_operation_factory.cc @@ -14,15 +14,20 @@ #include +#include "paddle/common/layout.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/fluid/pir/drr/src/attr_type_uilts.h" #include "paddle/fluid/pir/drr/src/ir_operation_factory.h" #include "paddle/phi/core/enforce.h" +#include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/builtin_op.h" #include "paddle/pir/include/core/operation.h" #include "paddle/pir/include/core/value.h" +#ifdef PADDLE_WITH_DNNL +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" +#endif namespace paddle { namespace drr { @@ -51,53 +56,264 @@ void OperationFactory::RegisterManualOpCreator() { return rewriter.Build(inputs); }); RegisterOperationCreator( - "pd_op.scale", + "builtin.slice", [](const std::vector& inputs, const pir::AttributeMap& attrs, pir::PatternRewriter& rewriter) { - return rewriter.Build( + return rewriter.Build( inputs[0], - inputs[1], - 
attrs.at("bias").dyn_cast().data(), - attrs.at("bias_after_scale").dyn_cast().data()); + attrs.at("index").dyn_cast().data()); + }); + RegisterOperationCreator( + "pd_op.scale", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + if (inputs.size() == 2) { + return rewriter.Build( + inputs[0], + inputs[1], + attrs.at("bias").dyn_cast().data(), + attrs.at("bias_after_scale") + .dyn_cast() + .data()); + } + return rewriter.Build(inputs[0], attrs); + }); + RegisterOperationCreator( + "pd_op.slice", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + if (inputs.size() == 3) { + PADDLE_ENFORCE_NE(attrs.find("axes"), + attrs.end(), + phi::errors::InvalidArgument( + "'axes' Attribute is expected for SliceOp. ")); + std::vector axes; + for (size_t i = 0; + i < attrs.at("axes").dyn_cast().size(); + i++) { + axes.push_back(attrs.at("axes") + .dyn_cast() + .at(i) + .dyn_cast() + .data()); + } + + PADDLE_ENFORCE_NE( + attrs.find("infer_flags"), + attrs.end(), + phi::errors::InvalidArgument( + "'infer_flags' Attribute is expected for SliceOp. ")); + std::vector infer_flags; + for (size_t i = 0; + i < + attrs.at("infer_flags").dyn_cast().size(); + i++) { + infer_flags.push_back(attrs.at("infer_flags") + .dyn_cast() + .at(i) + .dyn_cast() + .data()); + } + + PADDLE_ENFORCE_NE( + attrs.find("decrease_axis"), + attrs.end(), + phi::errors::InvalidArgument( + "'decrease_axis' Attribute is expected for SliceOp. ")); + std::vector decrease_axis; + for (size_t i = 0; + i < + attrs.at("decrease_axis").dyn_cast().size(); + i++) { + decrease_axis.push_back(attrs.at("decrease_axis") + .dyn_cast() + .at(i) + .dyn_cast() + .data()); + } + return rewriter.Build(inputs[0], + inputs[1], + inputs[2], + axes, + infer_flags, + decrease_axis); + } + return rewriter.Build(inputs[0], attrs); + }); +#ifdef PADDLE_WITH_DNNL + RegisterOperationCreator( + "onednn_op.conv2d_transpose_bias", + [](const std::vector& inputs, + const pir::AttributeMap& attrs, + pir::PatternRewriter& rewriter) { + if (inputs.size() == 4) { + PADDLE_ENFORCE_EQ( + attrs.find("strides") != attrs.end(), + true, + phi::errors::InvalidArgument("'strides' Attribute is expected " + "for Conv2dTransposeBiasOp. ")); + std::vector strides; + for (size_t i = 0; + i < attrs.at("strides").dyn_cast().size(); + i++) { + strides.push_back(attrs.at("strides") + .dyn_cast() + .at(i) + .dyn_cast() + .data()); + } + + PADDLE_ENFORCE_EQ( + attrs.find("paddings") != attrs.end(), + true, + phi::errors::InvalidArgument("'paddings' Attribute is expected " + "for Conv2dTransposeBiasOp. ")); + std::vector paddings; + for (size_t i = 0; + i < attrs.at("paddings").dyn_cast().size(); + i++) { + paddings.push_back(attrs.at("paddings") + .dyn_cast() + .at(i) + .dyn_cast() + .data()); + } + + PADDLE_ENFORCE_EQ(attrs.find("output_padding") != attrs.end(), + true, + phi::errors::InvalidArgument( + "'output_padding' Attribute is expected for " + "Conv2dTransposeBiasOp. ")); + std::vector output_padding; + for (size_t i = 0; i < attrs.at("output_padding") + .dyn_cast() + .size(); + i++) { + output_padding.push_back(attrs.at("output_padding") + .dyn_cast() + .at(i) + .dyn_cast() + .data()); + } + + PADDLE_ENFORCE_EQ(attrs.find("padding_algorithm") != attrs.end(), + true, + phi::errors::InvalidArgument( + "'padding_algorithm' Attribute is expected for " + "Conv2dTransposeBiasOp. 
")); + std::string padding_algorithm = attrs.at("padding_algorithm") + .dyn_cast() + .AsString(); + + PADDLE_ENFORCE_EQ( + attrs.find("groups") != attrs.end(), + true, + phi::errors::InvalidArgument("'groups' Attribute is expected for " + "Conv2dTransposeBiasOp. ")); + int groups = + attrs.at("groups").dyn_cast().data(); + + PADDLE_ENFORCE_EQ( + attrs.find("dilations") != attrs.end(), + true, + phi::errors::InvalidArgument("'dilations' Attribute is expected " + "for Conv2dTransposeBiasOp. ")); + std::vector dilations; + for (size_t i = 0; + i < attrs.at("dilations").dyn_cast().size(); + i++) { + dilations.push_back(attrs.at("dilations") + .dyn_cast() + .at(i) + .dyn_cast() + .data()); + } + + PADDLE_ENFORCE_EQ(attrs.find("data_format") != attrs.end(), + true, + phi::errors::InvalidArgument( + "'data_format' Attribute is expected for " + "Conv2dTransposeBiasOp. ")); + std::string data_format = + attrs.at("data_format").dyn_cast().AsString(); + + PADDLE_ENFORCE_EQ( + attrs.find("is_test") != attrs.end(), + true, + phi::errors::InvalidArgument("'is_test' Attribute is expected " + "for Conv2dTransposeBiasOp. ")); + bool is_test = + attrs.at("is_test").dyn_cast().data(); + + return rewriter.Build( + inputs[0], + inputs[1], + inputs[2], + inputs[3], + strides, + paddings, + output_padding, + padding_algorithm, + groups, + dilations, + data_format, + is_test); + } + + return rewriter.Build( + inputs[0], inputs[1], inputs[2], attrs); }); +#endif } pir::Attribute CreateIrAttribute(const std::any& obj) { - if (obj.type() == typeid(bool)) { - return IrAttrbuteCreator()(std::any_cast(obj)); - } else if (obj.type() == typeid(int32_t)) { - return IrAttrbuteCreator()(std::any_cast(obj)); - } else if (obj.type() == typeid(int64_t)) { - return IrAttrbuteCreator()(std::any_cast(obj)); - } else if (obj.type() == typeid(float)) { - return IrAttrbuteCreator()(std::any_cast(obj)); - } else if (obj.type() == typeid(std::string)) { - return IrAttrbuteCreator()(std::any_cast(obj)); - } else if (obj.type() == typeid(const char*)) { - return IrAttrbuteCreator()(std::any_cast(obj)); - } else if (obj.type() == typeid(phi::DataType)) { - return IrAttrbuteCreator()( - std::any_cast(obj)); - } else if (obj.type() == typeid(phi::Place)) { - return IrAttrbuteCreator()(std::any_cast(obj)); - } else if (obj.type() == typeid(std::vector)) { - return IrAttrbuteCreator>()( - std::any_cast>(obj)); - } else if (obj.type() == typeid(std::vector)) { - return IrAttrbuteCreator>()( - std::any_cast>(obj)); - } else if (obj.type() == typeid(std::vector)) { - return IrAttrbuteCreator>()( - std::any_cast>(obj)); - } else if (obj.type() == typeid(phi::IntArray)) { - return IrAttrbuteCreator()( - std::any_cast(obj)); - } else { - PADDLE_THROW( - phi::errors::Unimplemented("Type error. 
CreateIrAttribute for type(%s) " - "is unimplemented CreateInCurrently.", - obj.type().name())); + try { + if (obj.type() == typeid(bool)) { + return IrAttributeCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(int32_t)) { + return IrAttributeCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(int64_t)) { + return IrAttributeCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(float)) { + return IrAttributeCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(std::string)) { + return IrAttributeCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(const char*)) { + return IrAttributeCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(phi::DataType)) { + return IrAttributeCreator()( + std::any_cast(obj)); + } else if (obj.type() == typeid(phi::Place)) { + return IrAttributeCreator()(std::any_cast(obj)); + } else if (obj.type() == typeid(phi::DataLayout)) { + return IrAttributeCreator()( + std::any_cast(obj)); + } else if (obj.type() == typeid(std::vector)) { // NOLINT + return IrAttributeCreator>()( + std::any_cast>(obj)); + } else if (obj.type() == typeid(std::vector)) { + return IrAttributeCreator>()( + std::any_cast>(obj)); + } else if (obj.type() == typeid(std::vector)) { + return IrAttributeCreator>()( + std::any_cast>(obj)); + } else if (obj.type() == typeid(phi::IntArray)) { + return IrAttributeCreator()( + std::any_cast(obj)); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Type error. CreateIrAttribute for type(%s) " + "is unimplemented CreateInCurrently.", + obj.type().name())); + } + } catch (const std::bad_any_cast& e) { + PADDLE_THROW(phi::errors::Fatal( + "%s: CreateIrAttribute for type(%s) not successfully.", + e.what(), + obj.type().name())); } } diff --git a/paddle/fluid/pir/drr/src/ir_operation_factory.h b/paddle/fluid/pir/drr/src/ir_operation_factory.h index f0c78663de193..23095bf9a73e0 100644 --- a/paddle/fluid/pir/drr/src/ir_operation_factory.h +++ b/paddle/fluid/pir/drr/src/ir_operation_factory.h @@ -37,7 +37,7 @@ class OperationFactory { void RegisterOperationCreator(const std::string& op_name, const operation_create_fn& create_fn) { - op_creator_map.emplace(op_name, create_fn); + op_creator_map[op_name] = create_fn; } pir::Operation* CreateOperation( diff --git a/paddle/fluid/pir/drr/src/pattern_context.cc b/paddle/fluid/pir/drr/src/pattern_context.cc index f73115e96b44c..7bdee5d5dcafe 100644 --- a/paddle/fluid/pir/drr/src/pattern_context.cc +++ b/paddle/fluid/pir/drr/src/pattern_context.cc @@ -14,10 +14,14 @@ #include +#include "paddle/common/enforce.h" +#include "paddle/common/errors.h" +#include "paddle/common/layout.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/fluid/pir/drr/src/pattern_graph.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" -#include "paddle/phi/core/enforce.h" +#include "paddle/fluid/pir/utils/general_functions.h" +#include "paddle/phi/common/data_type.h" namespace paddle { namespace drr { @@ -39,8 +43,7 @@ const Op& DrrPatternContext::SourceOpPattern( return *owned_ops_.back(); } -const drr::Tensor& DrrPatternContext::SourceTensorPattern( - const std::string& name) { +drr::Tensor& DrrPatternContext::SourceTensorPattern(const std::string& name) { return source_pattern_graph_->AddTensor(std::shared_ptr( new drr::Tensor(name, source_pattern_graph_.get()))); } @@ -142,8 +145,14 @@ Tensor& Op::operator()() const { thread_local int64_t Op::count = 0; 
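For context, a rough sketch of how the DataTypeAttr / PlaceAttr / DataLayoutAttr helpers declared earlier in drr_pattern_context.h (and defined just below) might be used from a DRR pattern. The pattern name `IllustrativeCastFoldPattern` and the choice of `pd_op.cast` with a hard-coded `"float32"` are assumptions made for illustration only; they are not part of this patch.

~~~ c++
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"

// Illustration only: fold cast(cast(x)) into a single cast whose dtype is
// fixed via the new string-keyed ResultPattern::DataTypeAttr helper. A real
// pass would normally propagate pat.Attr("dtype2") instead of hard-coding.
class IllustrativeCastFoldPattern : public paddle::drr::DrrPatternBase {
 public:
  std::string name() const override { return "IllustrativeCastFoldPattern"; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &cast1 = pat.Op("pd_op.cast", {{"dtype", pat.Attr("dtype1")}});
    const auto &cast2 = pat.Op("pd_op.cast", {{"dtype", pat.Attr("dtype2")}});
    pat.Tensor("mid") = cast1(pat.Tensor("x"));
    pat.Tensor("out") = cast2(pat.Tensor("mid"));

    paddle::drr::ResultPattern res = pat.ResultPattern();
    // DataTypeAttr resolves "float32" through dialect::StringToDataTypeMap();
    // unknown strings hit the InvalidArgument check added in this patch.
    const auto &fused_cast =
        res.Op("pd_op.cast", {{"dtype", res.DataTypeAttr("float32")}});
    res.Tensor("out") = fused_cast(res.Tensor("x"));
  }
};
~~~

PlaceAttr("cpu") and DataLayoutAttr("NCHW") follow the same shape, backed by StringToPlaceMap() and StringToDataLayoutMap() respectively.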
const char* Op::prefix = "@drr_temp@_"; -const char Tensor::INPUT_NONE_TENSOR_NAME[] = "__@input_none_tensor@__"; -const char Tensor::OUTPUT_NONE_TENSOR_NAME[] = "__@output_none_tensor@__"; +const char Tensor::SOURCE_INPUT_NONE_TENSOR_NAME[] = + "__@source_input_none_tensor@__"; +const char Tensor::SOURCE_OUTPUT_NONE_TENSOR_NAME[] = + "__@source_output_none_tensor@__"; +const char Tensor::RESULT_INPUT_NONE_TENSOR_NAME[] = + "__@result_input_none_tensor@__"; +const char Tensor::RESULT_OUTPUT_NONE_TENSOR_NAME[] = + "__@result_output_none_tensor@__"; void Tensor::Assign(const Tensor& other) { dynamic_cast(pattern_graph_)->AssignTensor(*this, other); @@ -154,14 +163,154 @@ void Tensor::operator=(const Tensor& other) const { // NOLINT PADDLE_ENFORCE_EQ( this->pattern_graph_, other.pattern_graph_, - phi::errors::InvalidArgument("Matching failed." - "Two Tensors must be in the same pattern " - "graph to make the '=' judgment.")); + common::errors::InvalidArgument("Matching failed." + "Two Tensors must be in the same pattern " + "graph to make the '=' judgment.")); if (other.name_.find(Op::prefix) == 0 && name_.find(Op::prefix) == std::string::npos) { other.pattern_graph_->UpdateTmpTensor(other.name_, this->name_); } } +const drr::Op& ResultPattern::Op( + const std::string& op_type, + const std::unordered_map& attributes) { + return ctx_->ResultOpPattern(op_type, attributes); +} + +drr::Tensor& ResultPattern::Tensor(const std::string& name) { + return ctx_->ResultTensorPattern(name); +} + +drr::Tensor& ResultPattern::InputNoneTensor() { + return ctx_->ResultTensorPattern(Tensor::RESULT_INPUT_NONE_TENSOR_NAME); +} + +drr::Tensor& ResultPattern::OutputNoneTensor() { + return ctx_->ResultTensorPattern(Tensor::RESULT_OUTPUT_NONE_TENSOR_NAME); +} + +Attribute ResultPattern::StrAttr(const std::string& value) const { + return ComputeAttr( + [=](const MatchContext& match_ctx) -> std::string { return value; }); +} + +Attribute ResultPattern::BoolAttr(bool value) const { + return ComputeAttr( + [=](const MatchContext& match_ctx) -> bool { return value; }); +} + +Attribute ResultPattern::Int32Attr(int32_t value) const { + return ComputeAttr( + [=](const MatchContext& match_ctx) -> int32_t { return value; }); +} + +Attribute ResultPattern::Int64Attr(int64_t value) const { + return ComputeAttr( + [=](const MatchContext& match_ctx) -> int64_t { return value; }); +} + +Attribute ResultPattern::Float32Attr(float value) const { + return ComputeAttr( + [=](const MatchContext& match_ctx) -> float { return value; }); +} + +Attribute ResultPattern::VectorInt64Attr( + const std::vector& value) const { + return ComputeAttr( + [=](const MatchContext& match_ctx) -> std::vector { + return value; + }); +} + +Attribute ResultPattern::VectorInt32Attr( + const std::vector& value) const { + return ComputeAttr( + [=](const MatchContext& match_ctx) -> std::vector { + return value; + }); +} + +Attribute ResultPattern::VectorFloatAttr( + const std::vector& value) const { + return ComputeAttr([=](const MatchContext& match_ctx) -> std::vector { + return value; + }); +} + +Attribute ResultPattern::DataTypeAttr(const std::string& value) const { + return ComputeAttr([=](const MatchContext& match_ctx) -> phi::DataType { + PADDLE_ENFORCE_EQ(dialect::StringToDataTypeMap().count(value) > 0, + true, + common::errors::InvalidArgument( + "The DataTypeAttr %s is not supported.", value)); + return dialect::StringToDataTypeMap().at(value); + }); +} + +Attribute ResultPattern::PlaceAttr(const std::string& value) const { + return 
ComputeAttr([=](const MatchContext& match_ctx) -> phi::Place { + PADDLE_ENFORCE_EQ(dialect::StringToPlaceMap().count(value) > 0, + true, + common::errors::InvalidArgument( + "The PlaceAttr %s is not supported.", value)); + return dialect::StringToPlaceMap().at(value); + }); +} + +Attribute ResultPattern::DataLayoutAttr(const std::string& value) const { + return ComputeAttr([=](const MatchContext& match_ctx) -> phi::DataLayout { + PADDLE_ENFORCE_EQ(dialect::StringToDataLayoutMap().count(value) > 0, + true, + common::errors::InvalidArgument( + "The DataLayoutAttr %s is not supported.", value)); + return dialect::StringToDataLayoutMap().at(value); + }); +} + +Attribute ResultPattern::ComputeAttr( + const AttrComputeFunc& attr_compute_func) const { + return ComputeAttribute(attr_compute_func); +} + +drr::ResultPattern SourcePattern::ResultPattern() const { + return drr::ResultPattern(ctx_); +} + +const drr::Op& SourcePattern::Op( + const std::string& op_type, + const std::unordered_map& attributes) { + return ctx_->SourceOpPattern(op_type, attributes); +} + +const drr::Tensor& SourcePattern::Tensor(const std::string& name) { + return ctx_->SourceTensorPattern(name); +} + +Attribute SourcePattern::Attr(const std::string& attr_name) const { + return NormalAttribute(attr_name); +} + +void SourcePattern::RequireEqual(const TensorShape& first, + const TensorShape& second) { + ctx_->RequireEqual(first, second); +} +void SourcePattern::RequireEqual(const TensorDataType& first, + const TensorDataType& second) { + ctx_->RequireEqual(first, second); +} + +void SourcePattern::RequireNativeCall(const ConstraintFunction& custom_fn) { + ctx_->RequireNativeCall(custom_fn); +} + +drr::Tensor& SourcePattern::InputNoneTensor() { + return ctx_->SourceTensorPattern(Tensor::SOURCE_INPUT_NONE_TENSOR_NAME); +} + +drr::Tensor& SourcePattern::OutputNoneTensor() { + return ctx_->SourceTensorPattern(Tensor::SOURCE_OUTPUT_NONE_TENSOR_NAME); +} + } // namespace drr } // namespace paddle diff --git a/paddle/fluid/pir/drr/src/pattern_graph.cc b/paddle/fluid/pir/drr/src/pattern_graph.cc index a8c72a064d0b8..a6b0e0a04067a 100644 --- a/paddle/fluid/pir/drr/src/pattern_graph.cc +++ b/paddle/fluid/pir/drr/src/pattern_graph.cc @@ -16,6 +16,7 @@ #include +#include "paddle/common/errors.h" #include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/phi/core/enforce.h" @@ -98,20 +99,6 @@ void PatternGraph::UpdateTmpTensor(const std::string &tmp_tensor_name, size_t PatternGraph::CountOfOpCalls() const { return owned_op_call_.size(); } -OpCall *SourcePatternGraph::AnchorNode() const { - for (const auto &output_tensor : output_tensors_) { - OpCall *output_op_candidate = - id2owned_tensor_.at(output_tensor)->producer(); - if (std::all_of(output_op_candidate->outputs().begin(), - output_op_candidate->outputs().end(), - [this](const Tensor *output) -> bool { - return this->output_tensors().count(output->name()); - })) - return output_op_candidate; - } - IR_THROW("Unable to find a valid anchor"); -} - std::unordered_set SourcePatternGraph::OutputNodes() const { std::unordered_set output_op_set; for (const auto &output_tensor : output_tensors_) { @@ -124,6 +111,10 @@ std::unordered_set SourcePatternGraph::OutputNodes() const { })) output_op_set.insert(output_op_candidate); } + if (output_op_set.empty()) { + PADDLE_THROW(common::errors::InvalidArgument( + "Unable to find a valid anchor in drr's source result pattern!")); + } return output_op_set; } @@ -147,8 +138,8 @@ void GraphTopo::WalkGraphNodesTopoOrder( const 
std::unordered_set &inputs_tensor = graph_->input_tensors(); const std::unordered_map> - &id2owned_tensor = graph_->id2owend_tensor(); - const std::vector> &owend_opcall = + &id2owned_tensor = graph_->id2owned_tensor(); + const std::vector> &owned_opcall = graph_->owned_op_call(); std::queue opcall_queue; @@ -156,7 +147,7 @@ void GraphTopo::WalkGraphNodesTopoOrder( opcall_dependent; // init opcall_dependent - for (const std::shared_ptr &opcall_sptr : owend_opcall) { + for (const std::shared_ptr &opcall_sptr : owned_opcall) { if (opcall_sptr.get()->inputs().empty()) { // opcall inputs is empty opcall_queue.push(opcall_sptr.get()); } else { @@ -174,11 +165,11 @@ void GraphTopo::WalkGraphNodesTopoOrder( "The input tensor [%s] must exists " "in pattern graph to be obtained.", tensor_name)); - for (const auto &tensor_comsumer : + for (const auto &tensor_consumer : id2owned_tensor.at(tensor_name).get()->consumers()) { - opcall_dependent[tensor_comsumer].erase(tensor_name); - if (opcall_dependent[tensor_comsumer].empty()) { - opcall_queue.push(tensor_comsumer); + opcall_dependent[tensor_consumer].erase(tensor_name); + if (opcall_dependent[tensor_consumer].empty()) { + opcall_queue.push(tensor_consumer); } } } @@ -190,10 +181,10 @@ void GraphTopo::WalkGraphNodesTopoOrder( // update opcall_dependent for (const auto &output_tensor : opcall->outputs()) { - for (const auto &tensor_comsumer : output_tensor->consumers()) { - opcall_dependent[tensor_comsumer].erase(output_tensor->name()); - if (opcall_dependent[tensor_comsumer].empty()) { - opcall_queue.push(tensor_comsumer); + for (const auto &tensor_consumer : output_tensor->consumers()) { + opcall_dependent[tensor_consumer].erase(output_tensor->name()); + if (opcall_dependent[tensor_consumer].empty()) { + opcall_queue.push(tensor_consumer); } } } @@ -202,7 +193,7 @@ void GraphTopo::WalkGraphNodesTopoOrder( std::ostream &operator<<(std::ostream &os, const PatternGraph &pattern_graph) { os << "\nAll Tensors:\n"; - for (const auto &kv : pattern_graph.id2owend_tensor()) { + for (const auto &kv : pattern_graph.id2owned_tensor()) { os << " " << kv.first; } os << "\n\n"; diff --git a/paddle/fluid/pir/drr/src/pattern_graph.h b/paddle/fluid/pir/drr/src/pattern_graph.h index e5cd74b2fa217..fb9af1a781d25 100644 --- a/paddle/fluid/pir/drr/src/pattern_graph.h +++ b/paddle/fluid/pir/drr/src/pattern_graph.h @@ -57,7 +57,7 @@ class PatternGraph { } const std::unordered_map>& - id2owend_tensor() const { + id2owned_tensor() const { return id2owned_tensor_; } @@ -72,8 +72,6 @@ std::ostream& operator<<(std::ostream& os, const PatternGraph& pattern_graph); class SourcePatternGraph : public PatternGraph { public: - OpCall* AnchorNode() const; - std::unordered_set OutputNodes() const; private: diff --git a/paddle/fluid/pir/drr/src/rewrite_pattern.cc b/paddle/fluid/pir/drr/src/rewrite_pattern.cc index 68a7b14f81a3e..a5ea7ad074c9f 100644 --- a/paddle/fluid/pir/drr/src/rewrite_pattern.cc +++ b/paddle/fluid/pir/drr/src/rewrite_pattern.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
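The GraphTopo::WalkGraphNodesTopoOrder changes above keep the same dependency-driven traversal, only with the id2owned_tensor / owned_opcall / tensor_consumer spellings corrected. For readers unfamiliar with it, the sketch below restates the idea as a stand-alone Kahn-style walk; SimpleOp and the plain string tensor names are simplifications for illustration, not the real OpCall / Tensor classes.

~~~ c++
// Stand-alone sketch of the dependency-driven walk used by
// GraphTopo::WalkGraphNodesTopoOrder. An op becomes ready once all of its
// input tensors are either graph inputs or outputs of already visited ops.
#include <functional>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

struct SimpleOp {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

void WalkTopoOrder(const std::vector<SimpleOp *> &ops,
                   const std::set<std::string> &graph_inputs,
                   const std::function<void(SimpleOp *)> &visit) {
  std::unordered_map<SimpleOp *, std::set<std::string>> pending;  // unmet deps
  std::unordered_map<std::string, std::vector<SimpleOp *>> consumers;
  std::queue<SimpleOp *> ready;

  for (SimpleOp *op : ops) {
    std::set<std::string> &deps = pending[op];
    for (const std::string &in : op->inputs) {
      if (graph_inputs.count(in)) continue;  // already available
      deps.insert(in);
      consumers[in].push_back(op);
    }
    if (deps.empty()) {  // no producer to wait for
      ready.push(op);
      pending.erase(op);
    }
  }

  while (!ready.empty()) {
    SimpleOp *op = ready.front();
    ready.pop();
    visit(op);
    for (const std::string &out : op->outputs) {
      auto consumer_it = consumers.find(out);
      if (consumer_it == consumers.end()) continue;
      for (SimpleOp *consumer : consumer_it->second) {
        auto dep_it = pending.find(consumer);
        if (dep_it == pending.end()) continue;  // already queued
        dep_it->second.erase(out);
        if (dep_it->second.empty()) {
          ready.push(consumer);
          pending.erase(dep_it);
        }
      }
    }
  }
}
~~~

The real implementation keys the same bookkeeping by tensor name and OpCall pointer; the sketch only adds an explicit erase from `pending` so that an op cannot be queued twice.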
+#include #include #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" @@ -33,7 +34,7 @@ DrrRewritePattern::DrrRewritePattern( pir::PatternBenefit benefit, const std::shared_ptr& drr_pattern_owner) : pir::RewritePattern( - drr_context.source_pattern_graph()->AnchorNode()->name(), + (*drr_context.source_pattern_graph()->OutputNodes().begin())->name(), benefit, context, {}), @@ -58,7 +59,7 @@ bool DrrRewritePattern::MatchAndRewrite( if (PatternGraphMatch(op, src_match_ctx.get())) { VLOG(4) << "DRR pattern (" << pattern_name_ << ") is matched in program."; PatternGraphRewrite(*src_match_ctx, rewriter); - VLOG(4) << "DRR pattern (" << pattern_name_ << ") is rewrited in program."; + VLOG(4) << "DRR pattern (" << pattern_name_ << ") is rewritten in program."; return true; } return false; @@ -67,7 +68,7 @@ bool DrrRewritePattern::MatchAndRewrite( bool DrrRewritePattern::PatternGraphMatch( pir::Operation* op, MatchContextImpl* source_pattern_match_ctx) const { VLOG(6) << "PatternGraphMatch Start: op(" << op->name() << ")"; - const OpCall* anchor = source_pattern_graph_->AnchorNode(); + const OpCall* anchor = *source_pattern_graph_->OutputNodes().begin(); std::unordered_map> bind_map = FindCandidateIrOutputOp(op, anchor, *(source_pattern_graph_.get())); @@ -257,95 +258,143 @@ bool DrrRewritePattern::MatchFromOutputToInput( std::unordered_set ir_visited; std::queue drr_q; std::queue ir_q; - bool matched = true; - size_t step = 0; - for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) { - VLOG(6) << "match (" << it->first->name() << " @" << it->first << " : @" - << it->second << ") in source_pattern_graph "; - drr_q.push(it->first); - drr_visited.insert(it->first); - ir_q.push(it->second); - ir_visited.insert(it->second); - } - while (!drr_q.empty()) { - if (!matched) break; - auto* drr_node = drr_q.front(); - auto* ir_node = ir_q.front(); - drr_q.pop(); - ir_q.pop(); + // Initialize DRR matched queue. + const auto& InitDrrQueue = [&]() -> void { + for (auto it = output_op_map.begin(); it != output_op_map.end(); ++it) { + VLOG(6) << "match (" << it->first->name() << " @" << it->first << " : @" + << it->second << ") in source_pattern_graph "; + drr_q.push(it->first); + drr_visited.insert(it->first); + ir_q.push(it->second); + ir_visited.insert(it->second); + } + }; + // Check whether DrrNode and Operation have the same Operands and Results + // information. + const auto& IsSameOperandsAndResults = + [](const OpCall* drr_node, const pir::Operation* ir_node) -> bool { if (drr_node->name() != ir_node->name()) { - matched = false; VLOG(8) << "Match failed: drr_node(" << drr_node->name() << ") != pir_node(" << ir_node->name() << ")."; - break; + return false; } const auto& drr_input_tensors = drr_node->inputs(); auto ir_input_value_size = ir_node->num_operands(); if (drr_input_tensors.size() != ir_input_value_size) { - matched = false; VLOG(8) << drr_node->name() << " Match failed: drr input tensors(" << drr_input_tensors.size() << ") != pir input tensors(" << ir_input_value_size << ")."; - break; + return false; } if (drr_node->outputs().size() != ir_node->num_results()) { - matched = false; VLOG(8) << drr_node->name() << " Match failed: drr output tensors(" << drr_node->outputs().size() << ") != pir output tensors(" << ir_node->num_results() << ")."; + return false; + } + return true; + }; + // Check whether source_pattern_match_ctx has visited Operation's Operands. 
+ const auto& HasVisitedOperands = [&](const Tensor* drr_input_tensor, + pir::Value ir_value) -> bool { + const auto& tensor_name = drr_input_tensor->name(); + if (ir_value.isa()) { + VLOG(8) << "Match Attention! Found BlockArgument as input of " + << tensor_name; + } + return source_pattern_match_ctx->tensor_map().count(tensor_name) != 0 && + ir_value != source_pattern_match_ctx->tensor_map().at(tensor_name); + }; + // Update drr_q et.al information. Return false if faild. + const auto& TryUpdateDrrQueue = [&](const OpCall* drr_producer_op, + pir::Operation* ir_producer_op) -> bool { + // still return true if both visited. + if (drr_visited.count(drr_producer_op) && + ir_visited.count(ir_producer_op)) { + return true; + } + // insert map if both not visited. + if (!drr_visited.count(drr_producer_op) && + !ir_visited.count(ir_producer_op)) { + drr_q.push(drr_producer_op); + ir_q.push(ir_producer_op); + drr_visited.insert(drr_producer_op); + ir_visited.insert(ir_producer_op); + return true; + } + return false; + }; + // Check whether Drr Tensor and IR Value is None. + const auto& IsNoneTensorAndValue = [](const Tensor* drr_input_tensor, + pir::Value ir_value) { + return drr_input_tensor->is_none() && ir_value == nullptr; + }; + // Step 1: Initialize DRR matched queue. + bool matched = true; + size_t step = 0; + InitDrrQueue(); + + while (!drr_q.empty()) { + if (!matched) break; + auto* drr_node = drr_q.front(); + auto* ir_node = ir_q.front(); + drr_q.pop(); + ir_q.pop(); + if (!IsSameOperandsAndResults(drr_node, ir_node)) { + matched = false; break; } + // Step 1: Bind Operation of current op to match_ctx. source_pattern_match_ctx->BindIrOperation(drr_node, ir_node); - // binding input_tensor of current_op + + // Step 2: Bind input_tensor of current op to match_ctx. + const auto& drr_input_tensors = drr_node->inputs(); + auto ir_input_values = ir_node->operands_source(); for (size_t i = 0; i < drr_input_tensors.size(); ++i) { - if (source_pattern_match_ctx->tensor_map().count( - drr_input_tensors[i]->name()) != 0 && - ir_node->operand(i).source() != - source_pattern_match_ctx->tensor_map().at( - drr_input_tensors[i]->name())) { + if (drr_input_tensors[i]->is_none()) { + if (IsNoneTensorAndValue(drr_input_tensors[i], ir_input_values[i])) { + continue; + } else { + VLOG(8) << drr_node->name() << "Match failed:drr_input[" << i + << "] != pir_intput[" << i << "] , drr_input_tensor[" << i + << "] is None."; + matched = false; + break; + } + } + if (HasVisitedOperands(drr_input_tensors[i], ir_input_values[i])) { matched = false; VLOG(8) << " tensor_map key[" << drr_input_tensors[i]->name() << "] already exists,but value is different!"; break; - } else { - source_pattern_match_ctx->BindIrValue(drr_input_tensors[i]->name(), - ir_node->operand(i).source()); } - - if (ir_node->operand_source(i).isa()) { - VLOG(8) << "Match Attention! Found BlockArgument as input of " - << drr_node->name(); - } - + source_pattern_match_ctx->BindIrValue(drr_input_tensors[i]->name(), + ir_input_values[i]); + // Skip it while drr_producer_op is nullptr for trigger pattern boundary. auto* drr_producer_op = drr_input_tensors[i]->producer(); if (drr_producer_op == nullptr) { continue; } - + // Check whether tensor and value have the same use_count. 
if (drr_input_tensors[i]->consumers().size() != - ir_node->operand(i).source().use_count()) { + ir_input_values[i].use_count()) { matched = false; VLOG(8) << drr_node->name() << " Match failed: consumers of drr intput[" << i << "] { " << drr_node->outputs().size() << " } != consumers of pir intput[" << i << "] { " - << ir_node->operand(i).source().use_count() << " }."; + << ir_input_values[i].use_count() << " }."; break; } - auto* ir_producer_op = ir_node->operand_source(i).defining_op(); - // bfs producer_op of current_op - if (drr_visited.count(drr_producer_op) && - ir_visited.count(ir_producer_op)) { - continue; + auto* ir_producer_op = ir_input_values[i].defining_op(); + // Trigger early stop when the operand is a BlockArgument with + // producer_op == nullptr. + if (drr_producer_op && ir_producer_op == nullptr) { + matched = false; + break; } - - if (!drr_visited.count(drr_producer_op) && - !ir_visited.count(ir_producer_op)) { - drr_q.push(drr_producer_op); - ir_q.push(ir_producer_op); - drr_visited.insert(drr_producer_op); - ir_visited.insert(ir_producer_op); - } else { + // bfs producer_op of current_op + if (!TryUpdateDrrQueue(drr_producer_op, ir_producer_op)) { matched = false; VLOG(8) << "Match failed: status of visiting for" << drr_node->name() << " is different."; @@ -414,13 +463,13 @@ MatchContextImpl DrrRewritePattern::CreateOperations( // add input tensors info for res_match_ctx for (const auto& in_tensor : result_pattern_graph.input_tensors()) { PADDLE_ENFORCE_NE( - result_pattern_graph.id2owend_tensor().count(in_tensor), + result_pattern_graph.id2owned_tensor().count(in_tensor), 0, phi::errors::NotFound("Not found the input tensor." "Drr input tensor [%s] must exist in the result " "pattern graph to be obtained.", in_tensor)); - if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) { + if (!result_pattern_graph.id2owned_tensor().at(in_tensor)->is_none()) { res_match_ctx.BindIrValue(in_tensor, src_match_ctx.GetIrValue(in_tensor)); } } @@ -436,7 +485,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations( GraphTopo graph_topo_visit(&result_pattern_graph); graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { // set insert point - size_t max_input_op_index = 0; + size_t max_input_op_index = 0UL; pir::Operation* max_index_op = nullptr; for (const Tensor* input : op_call.inputs()) { if (input->is_none()) { @@ -446,7 +495,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations( if (ir_val) { pir::Operation* ir_input_op = ir_val.defining_op(); if (op_2_temp_program_index.count(ir_input_op) == 0) { - max_input_op_index = 0UL; + // do nothing } else if (max_input_op_index < op_2_temp_program_index.at(ir_input_op)) { max_input_op_index = op_2_temp_program_index.at(ir_input_op); @@ -471,10 +520,10 @@ MatchContextImpl DrrRewritePattern::CreateOperations( } if (max_input_op_index == 0UL) { VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; - pir::Operation* source_patter_first_op = src_match_ctx.IrOperation( + pir::Operation* source_pattern_first_op = src_match_ctx.IrOperation( source_pattern_graph.owned_op_call()[0].get()); - max_input_op_index = op_2_temp_program_index[source_patter_first_op]; - rewriter.set_insertion_point(source_patter_first_op); + max_input_op_index = op_2_temp_program_index[source_pattern_first_op]; + rewriter.set_insertion_point(source_pattern_first_op); } else { rewriter.SetInsertionPointAfter(max_index_op); } @@ -508,7 +557,7 @@ void DrrRewritePattern::ReplaceOutputTensor( const MatchContextImpl& res_match_ctx,
pir::PatternRewriter& rewriter) const { // NOLINT for (const auto& output_name : result_pattern_graph_->output_tensors()) { - if (source_pattern_graph_->id2owend_tensor().count(output_name)) { + if (source_pattern_graph_->id2owned_tensor().count(output_name)) { const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); rewriter.ReplaceAllUsesWith(src_ir_tensor, res_ir_tensor); diff --git a/paddle/fluid/pir/transforms/CMakeLists.txt b/paddle/fluid/pir/transforms/CMakeLists.txt index bc2c3050fc2a5..627fcb78d8563 100644 --- a/paddle/fluid/pir/transforms/CMakeLists.txt +++ b/paddle/fluid/pir/transforms/CMakeLists.txt @@ -11,8 +11,19 @@ if(NOT WITH_MKLDNN) list(REMOVE_ITEM transforms_srcs ${onednn_srcs}) endif() -set(transforms_deps drr op_dialect op_dialect_vjp standalone_executor pir - device_event_base) +if(NOT WITH_XPU) + file(GLOB_RECURSE xpu_srcs "xpu/*.cc") + list(REMOVE_ITEM transforms_srcs ${xpu_srcs}) +endif() + +set(transforms_deps + drr + op_dialect + op_dialect_vjp + standalone_executor + pir + pir_general_functions + device_event_base) if(WITH_CINN) set(transforms_deps ${transforms_deps} cinn_op_dialect cinnapi) diff --git a/paddle/fluid/pir/transforms/build_cinn_pass.cc b/paddle/fluid/pir/transforms/build_cinn_pass.cc index 48c872c23b527..4daa4be6445b2 100644 --- a/paddle/fluid/pir/transforms/build_cinn_pass.cc +++ b/paddle/fluid/pir/transforms/build_cinn_pass.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/pir/transforms/build_cinn_pass.h" +#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/fluid/pir/transforms/sub_graph_detector.h" #include "paddle/pir/include/core/builtin_op.h" @@ -24,28 +25,85 @@ namespace { using GroupOpsVec = std::vector; using CompatibleInfo = cinn::hlir::framework::pir::CompatibleInfo; +void VerifyOperationOrder(const pir::Block& block); + class BuildCinnPass : public pir::Pass { public: BuildCinnPass() : pir::Pass("build_cinn_pass", /*opt_level=*/1) {} void Run(pir::Operation* op) override { - auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "build_cinn_pass should run on module op."); - auto& block = module_op.block(); + for (uint32_t i = 0; i < op->num_regions(); ++i) { + for (auto& block : op->region(i)) { + ProcessBlock(&block); + VerifyOperationOrder(block); + } + } + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->num_regions() > 0 && !op->isa() && + !op->isa(); + } + private: + void ProcessBlock(pir::Block* block) { std::vector groups = - ::pir::SubgraphDetector(&block, CompatibleInfo::IsSupportCinn)(); + ::pir::SubgraphDetector(block, CompatibleInfo::IsSupportForCinn)(); AddStatistics(groups.size()); for (auto& group_ops : groups) { + if (group_ops.size() == 1 && group_ops[0]->name() == "pd_op.full") { + continue; + } VLOG(4) << "current group_ops.size(): " << group_ops.size(); - ::pir::ReplaceWithGroupOp(&block, group_ops); + ::pir::ReplaceWithGroupOp(block, group_ops); } } +}; - bool CanApplyOn(pir::Operation* op) const override { - return op->isa() && op->num_regions() > 0; +void VerifyOperationOrder(const pir::Block& block) { + auto order_info = + [&]() -> std::unordered_map { + std::unordered_map map; + // initialize the position index with block size by default. 
+ const int64_t block_size = block.size(); + for (auto& op : block) map[&op] = block_size; + return map; + }(); + const auto& CheckOpOrder = [&](const pir::Operation* op) -> void { + const pir::Operation* current_op = op; + for (auto& value : op->operands_source()) { + if (!value || !value.defining_op()) continue; + pir::Operation* defining_op = value.defining_op(); + if (order_info.count(defining_op) == 0) continue; + if (op->GetParentOp() && + op->GetParentOp()->isa()) { + current_op = op->GetParentOp(); + } + CHECK(order_info.at(defining_op) < order_info.at(current_op)) + << "The order of operations is not correct!" + << " Received defining_op(" << defining_op->id() << " " + << order_info.at(defining_op) << ") is behind current_op(" + << current_op->id() << " " << order_info.at(current_op) << ")"; + } + }; + const auto& CheckGroupOpOrder = [&](pir::Operation* op) -> void { + auto group_op = op->dyn_cast(); + for (auto& inner_op : *group_op.block()) { + CheckOpOrder(&inner_op); + } + }; + + int64_t index = 0; + for (auto& op : block) { + order_info[&op] = index++; + if (op.isa()) { + CheckGroupOpOrder(&op); + } else { + CheckOpOrder(&op); + } } -}; +} + } // namespace namespace pir { diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc b/paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc similarity index 96% rename from paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc rename to paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc index 4f5c4c0e4cd6b..4f076c3e8b247 100644 --- a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.cc +++ b/paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.cc @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/auto_mixed_precision_pass.h" +#include "paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h" + #include #include #include @@ -31,7 +32,7 @@ #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/phi/common/backend.h" #include "paddle/phi/common/bfloat16.h" @@ -60,17 +61,23 @@ class AutoMixedPrecisionPass : public pir::Pass { precision_mode_(phi::DataType::FLOAT16) {} bool Initialize(pir::IrContext* context) override { - IR_ENFORCE(Has(pir::kPlaceAttr), - "Pass initialize failed." - "When using AutoMixedPrecisionPass, place attribute is required!" - "Use Set method to set the place attribute."); - IR_ENFORCE(Has("__mixed_precision_mode__"), - "Pass initialize failed." - "When using AutoMixedPrecisionPass, precison_mode attribute is " - "required!" - "Use Set method to set the scope attribute."); - - place_ = Get(pir::kPlaceAttr); + PADDLE_ENFORCE_EQ( + Has(pir::Pass::kPlaceAttr), + true, + phi::errors::InvalidArgument( + "Pass initialize failed." + "When using AutoMixedPrecisionPass, place attribute is required!" + "Use Set method to set the place attribute.")); + PADDLE_ENFORCE_EQ( + Has("__mixed_precision_mode__"), + true, + phi::errors::InvalidArgument( + "Pass initialize failed." + "When using AutoMixedPrecisionPass, precision_mode attribute is " + "required!" 
+ "Use Set method to set the scope attribute.")); + + place_ = Get(pir::Pass::kPlaceAttr); precision_mode_ = Get("__mixed_precision_mode__"); context_ = context; enable_low_precision_io_ = false; @@ -224,13 +231,13 @@ class AutoMixedPrecisionPass : public pir::Pass { precision_updated = true; } if (!OpRunLowPrecision(op)) continue; - // if the producer's output is in float VectorType, then the precsion + // if the producer's output is in float VectorType, then the precision // between two op should be the same for (size_t idx = 0; idx < op->num_operands(); ++idx) { if (!op->operand_source(idx)) continue; auto operand = op->operand(idx); if (operand.type() && operand.type().isa()) { - // check if there are all float in the vectortype + // check if there are all float in the vector type auto vec_type = operand.type().dyn_cast(); if (IsVectorTypeFloat(vec_type)) { auto input_operation = GetDefiningOpForInput(op, idx); diff --git a/paddle/fluid/pir/transforms/auto_mixed_precision_pass.h b/paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/auto_mixed_precision_pass.h rename to paddle/fluid/pir/transforms/general/auto_mixed_precision_pass.h diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/general/constant_folding_pass.cc similarity index 95% rename from paddle/fluid/pir/transforms/constant_folding_pass.cc rename to paddle/fluid/pir/transforms/general/constant_folding_pass.cc index d7834f9195bfd..bf1bc26850c56 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/general/constant_folding_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/constant_folding_pass.h" +#include "paddle/fluid/pir/transforms/general/constant_folding_pass.h" #include #include @@ -27,7 +27,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/common/errors.h" #include "paddle/phi/common/place.h" @@ -461,24 +461,27 @@ class ConstantFoldingPatternForTrain : public ConstantFoldingPattern { class ConstantFoldingPass : public pir::Pass { public: - ConstantFoldingPass() - : pir::Pass("constant_folding_pass", 1), - place_(phi::CPUPlace{}), - scope_(nullptr) {} + ConstantFoldingPass() : pir::Pass("constant_folding_pass", 1) {} private: bool Initialize(pir::IrContext* context) override { - IR_ENFORCE(Has(pir::kPlaceAttr), - "Pass initialize failed." - "When using ConstantFoldingPass, place attribute is required!" - "Use Set method to set the place attribute."); - IR_ENFORCE(Has(pir::kParamScopeAttr), - "Pass initialize failed." - "When using ConstantFoldingPass, scope attribute is required!" - "Use Set method to set the scope attribute."); - - place_ = Get(pir::kPlaceAttr); - scope_ = &Get(pir::kParamScopeAttr); + PADDLE_ENFORCE_EQ( + Has(pir::Pass::kPlaceAttr), + true, + phi::errors::InvalidArgument( + "Pass initialize failed." + "When using ConstantFoldingPass, place attribute is required!" + "Use Set method to set the place attribute.")); + PADDLE_ENFORCE_EQ( + Has(pir::Pass::kParamScopeAttr), + true, + phi::errors::InvalidArgument( + "Pass initialize failed." 
+ "When using ConstantFoldingPass, scope attribute is required!" + "Use Set method to set the scope attribute.")); + + place_ = Get(pir::Pass::kPlaceAttr); + scope_ = &Get(pir::Pass::kParamScopeAttr); PADDLE_ENFORCE_NOT_NULL( scope_, phi::errors::InvalidArgument("scope can not be nullptr")); @@ -523,7 +526,7 @@ class ConstantFoldingPass : public pir::Pass { private: size_t suffix_{0}; - phi::Place place_; + phi::Place place_{phi::CPUPlace{}}; paddle::framework::Scope* scope_{nullptr}; paddle::framework::interpreter::ExecutionConfig exe_config_{}; std::vector deleted_vars_; diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.h b/paddle/fluid/pir/transforms/general/constant_folding_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/constant_folding_pass.h rename to paddle/fluid/pir/transforms/general/constant_folding_pass.h diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc b/paddle/fluid/pir/transforms/general/dead_code_elimination_pass.cc similarity index 88% rename from paddle/fluid/pir/transforms/dead_code_elimination_pass.cc rename to paddle/fluid/pir/transforms/general/dead_code_elimination_pass.cc index 442aec918e08f..5ec283eea6810 100644 --- a/paddle/fluid/pir/transforms/dead_code_elimination_pass.cc +++ b/paddle/fluid/pir/transforms/general/dead_code_elimination_pass.cc @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" +#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h" +#include #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" @@ -31,7 +32,12 @@ class DeadCodeEliminationPass : public pir::Pass { void Run(pir::Operation* op) override { VLOG(6) << "apply dead_code_elimination_pass"; int64_t num_erasers{0}; - EraseOp(*op->GetParentProgram()->block(), &num_erasers); + bool updated{true}; + while (updated) { + int64_t pre_num_erasers = num_erasers; + EraseOp(*op->GetParentProgram()->block(), &num_erasers); + updated = pre_num_erasers != num_erasers; + } AddStatistics(num_erasers); } diff --git a/paddle/fluid/pir/transforms/dead_code_elimination_pass.h b/paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/dead_code_elimination_pass.h rename to paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/general/identity_op_clean_pass.cc similarity index 93% rename from paddle/fluid/pir/transforms/identity_op_clean_pass.cc rename to paddle/fluid/pir/transforms/general/identity_op_clean_pass.cc index cf27800512b0b..fe2369e71a551 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/general/identity_op_clean_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/pir/transforms/identity_op_clean_pass.h" +#include "paddle/fluid/pir/transforms/general/identity_op_clean_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" @@ -53,9 +53,9 @@ class RemoveUselessScalePattern : public paddle::drr::DrrPatternBase { } }; -class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantScalePattern : public paddle::drr::DrrPatternBase { public: - std::string name() const override { return "RemoveRedundentScalePattern"; } + std::string name() const override { return "RemoveRedundantScalePattern"; } void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -83,7 +83,7 @@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &bais_attr = res.ComputeAttr( + const auto &bias_attr = res.ComputeAttr( [](const paddle::drr::MatchContext &match_ctx) -> float { float res_bias_1 = 0.f; float res_bias_2 = 0.f; @@ -115,7 +115,7 @@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { {"place", pat.Attr("place_1")}}); const auto &scale_op_res = res.Op("pd_op.scale", - {{"bias", bais_attr}, {"bias_after_scale", res.BoolAttr(true)}}); + {{"bias", bias_attr}, {"bias_after_scale", res.BoolAttr(true)}}); scale_op_res({&res.Tensor("x"), &full_op_res()}, {&res.Tensor("scale_2_out")}); } @@ -154,9 +154,9 @@ class RemoveUselessConcatPattern : public paddle::drr::DrrPatternBase { } }; -class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase { public: - std::string name() const override { return "RemoveRedundentCastPattern"; } + std::string name() const override { return "RemoveRedundantCastPattern"; } void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); @@ -245,10 +245,10 @@ class ReplaceDropoutWithScalePattern : public paddle::drr::DrrPatternBase { } }; -class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantTransposePattern : public paddle::drr::DrrPatternBase { public: std::string name() const override { - return "RemoveRedundentTransposePattern"; + return "RemoveRedundantTransposePattern"; } void operator()(paddle::drr::DrrPatternContext *ctx) const override { @@ -271,10 +271,10 @@ class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { } return new_perm; }); - const auto &tranpose_continuous = + const auto &transpose_continuous = res.Op("pd_op.transpose", {{"perm", new_perm_attr}}); - res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); + res.Tensor("ret") = transpose_continuous(res.Tensor("arg_transpose")); } }; @@ -286,13 +286,13 @@ class IdentityOpCleanPass : public pir::PatternRewritePass { pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); ps.Add(paddle::drr::Create(context)); - ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context)); ps.Add(paddle::drr::Create(context)); ps.Add(paddle::drr::Create(context)); - ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context)); ps.Add(paddle::drr::Create(context)); ps.Add(paddle::drr::Create(context)); - ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context)); return ps; } }; diff --git 
a/paddle/fluid/pir/transforms/identity_op_clean_pass.h b/paddle/fluid/pir/transforms/general/identity_op_clean_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/identity_op_clean_pass.h rename to paddle/fluid/pir/transforms/general/identity_op_clean_pass.h diff --git a/paddle/fluid/pir/transforms/inplace_pass.cc b/paddle/fluid/pir/transforms/general/inplace_pass.cc similarity index 95% rename from paddle/fluid/pir/transforms/inplace_pass.cc rename to paddle/fluid/pir/transforms/general/inplace_pass.cc index b5574685bd113..6c1044957a958 100644 --- a/paddle/fluid/pir/transforms/inplace_pass.cc +++ b/paddle/fluid/pir/transforms/general/inplace_pass.cc @@ -28,8 +28,8 @@ #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" -#include "paddle/fluid/pir/transforms/inplace_pass.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/transforms/general/inplace_pass.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/core/builtin_op.h" #include "paddle/pir/include/core/operation.h" #include "paddle/pir/include/pass/pass.h" @@ -184,8 +184,8 @@ bool IsNoNeedBuffer(pir::Operation* op, pir::Value value) { info_interface->get_op_info_(op_name), paddle::dialect::IsLegacyOp(op_name)); auto& no_need_buffer_ids = info_parser.NoNeedBufferIds(); - for (size_t id = 0; id < no_need_buffer_ids.size(); id++) { - if (value == op->operand_source(no_need_buffer_ids[id])) { + for (auto no_need_buffer_id : no_need_buffer_ids) { + if (value == op->operand_source(no_need_buffer_id)) { return true; } } @@ -203,8 +203,11 @@ std::unordered_set GetSkipDeletionValues(const pir::Block& block) { 0) { continue; } - IR_ENFORCE(op.attributes().count("op_name") > 0, - "kernel_dialect op should own an 'op_name' attribute."); + PADDLE_ENFORCE_GT( + op.attributes().count("op_name"), + 0UL, + phi::errors::InvalidArgument( + "kernel_dialect op should own an 'op_name' attribute.")); auto upper_op_name = op.attributes().at("op_name").dyn_cast().AsString(); @@ -213,6 +216,7 @@ std::unordered_set GetSkipDeletionValues(const pir::Block& block) { skip_dels.insert(op.result(0)); continue; } + // TODO(chenxi67) add logic for shadow_feed_tensors op if (upper_op_name == "pd_op.fetch" || upper_op_name == "builtin.shadow_output") { skip_dels.insert(op.operand_source(0)); @@ -233,8 +237,11 @@ void GetEagerDelValueOfOp( std::string upper_op_name = op.name(); if (op.dialect()->name().compare(paddle::dialect::KernelDialect::name()) == 0) { - IR_ENFORCE(op.attributes().count("op_name") > 0, - "kernel_dialect op should own an 'op_name' attribute."); + PADDLE_ENFORCE_GT( + op.attributes().count("op_name"), + 0UL, + phi::errors::InvalidArgument( + "kernel_dialect op should own an 'op_name' attribute.")); upper_op_name = op.attributes() .at("op_name") .dyn_cast() @@ -478,9 +485,11 @@ class InplacePass : public pir::Pass { .AsString(); pir::Block::Iterator insert_pos = std::find(block.begin(), block.end(), *kv.first); - IR_ENFORCE(insert_pos != block.end(), - "Operator %s not found in block.", - kv.first->name()); + PADDLE_ENFORCE_NE( + insert_pos, + block.end(), + phi::errors::InvalidArgument("Operator %s not found in block.", + kv.first->name())); kv.first->set_attribute( "op_name", diff --git a/paddle/fluid/pir/transforms/inplace_pass.h b/paddle/fluid/pir/transforms/general/inplace_pass.h similarity index 100% rename from 
paddle/fluid/pir/transforms/inplace_pass.h rename to paddle/fluid/pir/transforms/general/inplace_pass.h diff --git a/paddle/fluid/pir/transforms/map_op_to_another_pass.cc b/paddle/fluid/pir/transforms/general/map_op_to_another_pass.cc similarity index 97% rename from paddle/fluid/pir/transforms/map_op_to_another_pass.cc rename to paddle/fluid/pir/transforms/general/map_op_to_another_pass.cc index 54e274a28f007..86facef865413 100644 --- a/paddle/fluid/pir/transforms/map_op_to_another_pass.cc +++ b/paddle/fluid/pir/transforms/general/map_op_to_another_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/map_op_to_another_pass.h" +#include "paddle/fluid/pir/transforms/general/map_op_to_another_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" diff --git a/paddle/fluid/pir/transforms/map_op_to_another_pass.h b/paddle/fluid/pir/transforms/general/map_op_to_another_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/map_op_to_another_pass.h rename to paddle/fluid/pir/transforms/general/map_op_to_another_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/general/matmul_scale_fuse_pass.cc similarity index 92% rename from paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc rename to paddle/fluid/pir/transforms/general/matmul_scale_fuse_pass.cc index d167d7293fec2..ee0e1bf397b55 100644 --- a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/general/matmul_scale_fuse_pass.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h" +#include "paddle/fluid/pir/transforms/general/matmul_scale_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" @@ -33,7 +33,7 @@ class MatmulScaleFusePattern : public paddle::drr::DrrPatternBase { {{"transpose_x", pat.Attr("transpose_x")}, {"transpose_y", pat.Attr("transpose_y")}}); - matmul_op({&pat.Tensor("x"), &pat.Tensor("y")}, + matmul_op({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("matmul_out")}); const auto &full_op = pat.Op(paddle::dialect::FullOp::name(), {{"shape", pat.Attr("shape")}, @@ -48,6 +48,9 @@ class MatmulScaleFusePattern : public paddle::drr::DrrPatternBase { {&pat.Tensor("scale_out")}); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + if (!pir::ValueIsPersistable(match_ctx.Tensor("w"))) { + return false; + } return std::abs(match_ctx.Attr("bias")) <= 1e-6; }); @@ -65,7 +68,7 @@ class MatmulScaleFusePattern : public paddle::drr::DrrPatternBase { res.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", pat.Attr("transpose_x")}, {"transpose_y", pat.Attr("transpose_y")}}); - scale_op_res({&res.Tensor("y"), &full_op_res()}, + scale_op_res({&res.Tensor("w"), &full_op_res()}, {&res.Tensor("scale_res_out")}); matmul_op_res({&res.Tensor("x"), &res.Tensor("scale_res_out")}, {&res.Tensor("scale_out")}); diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h b/paddle/fluid/pir/transforms/general/matmul_scale_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h rename to paddle/fluid/pir/transforms/general/matmul_scale_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/general/matmul_transpose_fuse_pass.cc b/paddle/fluid/pir/transforms/general/matmul_transpose_fuse_pass.cc new file mode 100644 index 0000000000000..4f5dd31024a9d --- /dev/null +++ b/paddle/fluid/pir/transforms/general/matmul_transpose_fuse_pass.cc @@ -0,0 +1,206 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/pir/transforms/general/matmul_transpose_fuse_pass.h" + +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/utils/general_functions.h" + +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace { + +class MatmulOutTransposeFusePattern : public paddle::drr::DrrPatternBase { + public: + std::string name() const override { return "MatmulOutTransposeFusePattern"; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul_op = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + const auto &transpose_op = pat.Op(paddle::dialect::TransposeOp::name(), + {{"perm", pat.Attr("perm")}}); + + pat.Tensor("matmul_op_out") = matmul_op(pat.Tensor("x"), pat.Tensor("y")); + pat.Tensor("transpose_op_out") = transpose_op(pat.Tensor("matmul_op_out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto y_shape = pir::GetShapeFromValue(match_ctx.Tensor("y")); + if (x_shape.size() < 2 || y_shape.size() < 2) return false; + const auto &perm = match_ctx.Attr>("perm"); + const int perm_size = perm.size(); + for (int i = 0; i < perm_size - 2; ++i) { + if (perm[i] != i) return false; + } + if ((perm[perm_size - 1] != perm_size - 2) && + (perm[perm_size - 2] != perm_size - 1)) + return false; + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + // transpose x y + const auto &transpose_x = + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { + bool transpose_status_x = !match_ctx.Attr("transpose_x"); + return transpose_status_x; + }); + const auto &transpose_y = + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { + bool transpose_status_y = !match_ctx.Attr("transpose_y"); + return transpose_status_y; + }); + const auto &fused_matmul_transpose_op = + res.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", transpose_y}, {"transpose_y", transpose_x}}); + res.Tensor("transpose_op_out") = + fused_matmul_transpose_op(res.Tensor("y"), res.Tensor("x")); + } +}; + +class MatmulXTransposeFusePattern : public paddle::drr::DrrPatternBase { + public: + std::string name() const override { return "MatmulXTransposeFusePattern"; } + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul_op = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + const auto &transpose_op = pat.Op(paddle::dialect::TransposeOp::name(), + {{"perm", pat.Attr("perm")}}); + + pat.Tensor("x_transpose_out") = transpose_op(pat.Tensor("x")); + pat.Tensor("matmul_op_out") = + matmul_op(pat.Tensor("x_transpose_out"), pat.Tensor("y")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto y_shape = pir::GetShapeFromValue(match_ctx.Tensor("y")); + if (x_shape.size() < 2 || y_shape.size() < 2) return false; + const auto &perm = match_ctx.Attr>("perm"); + const int perm_size = perm.size(); + for (int i = 0; i < perm_size - 2; ++i) { + if (perm[i] != i) return false; + } + if ((perm[perm_size - 
1] != perm_size - 2) && + (perm[perm_size - 2] != perm_size - 1)) + return false; + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + // transpose x y + const auto &transpose_x = + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { + bool transpose_status_x = !match_ctx.Attr("transpose_x"); + return transpose_status_x; + }); + const auto &transpose_y = + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { + bool transpose_status_y = match_ctx.Attr("transpose_y"); + return transpose_status_y; + }); + const auto &fused_matmul_transpose_op = + res.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", transpose_x}, {"transpose_y", transpose_y}}); + res.Tensor("matmul_op_out") = + fused_matmul_transpose_op(res.Tensor("x"), res.Tensor("y")); + } +}; + +class MatmulYTransposeFusePattern : public paddle::drr::DrrPatternBase { + public: + std::string name() const override { return "MatmulYTransposeFusePattern"; } + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &matmul_op = pat.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + const auto &transpose_op = pat.Op(paddle::dialect::TransposeOp::name(), + {{"perm", pat.Attr("perm")}}); + + pat.Tensor("y_transpose_out") = transpose_op(pat.Tensor("y")); + + pat.Tensor("matmul_op_out") = + matmul_op(pat.Tensor("x"), pat.Tensor("y_transpose_out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto x_shape = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto y_shape = pir::GetShapeFromValue(match_ctx.Tensor("y")); + if (x_shape.size() < 2 || y_shape.size() < 2) return false; + const auto &perm = match_ctx.Attr>("perm"); + const int perm_size = perm.size(); + for (int i = 0; i < perm_size - 2; ++i) { + if (perm[i] != i) return false; + } + if ((perm[perm_size - 1] != perm_size - 2) && + (perm[perm_size - 2] != perm_size - 1)) + return false; + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + // transpose x y + const auto &transpose_x = + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { + bool transpose_status_x = match_ctx.Attr("transpose_x"); + return transpose_status_x; + }); + const auto &transpose_y = + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { + bool transpose_status_y = !match_ctx.Attr("transpose_y"); + return transpose_status_y; + }); + const auto &fused_matmul_transpose_op = + res.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", transpose_x}, {"transpose_y", transpose_y}}); + res.Tensor("matmul_op_out") = + fused_matmul_transpose_op(res.Tensor("x"), res.Tensor("y")); + } +}; + +class MatmulTransposeFusePass : public pir::PatternRewritePass { + public: + MatmulTransposeFusePass() + : pir::PatternRewritePass("matmul_transpose_fuse_pass", 2) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context)); + // Add three pattern here + return ps; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateMatmulTransposeFusePass() { + return std::make_unique(); +} +} // namespace pir + +REGISTER_IR_PASS(matmul_transpose_fuse_pass, MatmulTransposeFusePass); diff --git 
a/paddle/fluid/pir/transforms/general/matmul_transpose_fuse_pass.h b/paddle/fluid/pir/transforms/general/matmul_transpose_fuse_pass.h new file mode 100644 index 0000000000000..8f4ba43ebf3d4 --- /dev/null +++ b/paddle/fluid/pir/transforms/general/matmul_transpose_fuse_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateMatmulTransposeFusePass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc b/paddle/fluid/pir/transforms/general/params_sync_among_devices_pass.cc similarity index 68% rename from paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc rename to paddle/fluid/pir/transforms/general/params_sync_among_devices_pass.cc index 10d6e66634179..01e1621eb96a6 100644 --- a/paddle/fluid/pir/transforms/params_sync_among_devices_pass.cc +++ b/paddle/fluid/pir/transforms/general/params_sync_among_devices_pass.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/params_sync_among_devices_pass.h" +#include "paddle/fluid/pir/transforms/general/params_sync_among_devices_pass.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/fluid/platform/place.h" #include "paddle/common/errors.h" @@ -37,25 +37,23 @@ class ParamsSyncAmongDevicesPass : public pir::Pass { : pir::Pass("params_sync_among_devices_pass", 0) {} bool Initialize(pir::IrContext* context) override { - IR_ENFORCE(Has(pir::kPlaceAttr), - "Pass initialize failed." - "When using ConstantFoldingPass, place attribute is required!" - "Use Set method to set the place attribute."); - IR_ENFORCE(Has(pir::kParamScopeAttr), - "Pass initialize failed." - "When using ConstantFoldingPass, scope attribute is required!" - "Use Set method to set the scope attribute."); + PADDLE_ENFORCE_EQ( + Has(pir::Pass::kPlaceAttr), + true, + phi::errors::InvalidArgument( + "Pass initialize failed." + "When using ConstantFoldingPass, place attribute is required!" + "Use Set method to set the place attribute.")); + PADDLE_ENFORCE_EQ( + Has(pir::Pass::kParamScopeAttr), + true, + phi::errors::InvalidArgument( + "Pass initialize failed." + "When using ConstantFoldingPass, scope attribute is required!" 
+ "Use Set method to set the scope attribute.")); - place_ = Get(pir::kPlaceAttr); - scope_ = &Get(pir::kParamScopeAttr); - - PADDLE_ENFORCE_NOT_NULL( - scope_, phi::errors::InvalidArgument("scope can not be nullptr")); - PADDLE_ENFORCE( - paddle::platform::is_gpu_place(place_) || - paddle::platform::is_cpu_place(place_), - phi::errors::PreconditionNotMet( - "params_sync_among_devices_pass should run on cpu or gpu.")); + place_ = Get(pir::Pass::kPlaceAttr); + scope_ = &Get(pir::Pass::kParamScopeAttr); return true; } @@ -100,11 +98,30 @@ class ParamsSyncAmongDevicesPass : public pir::Pass { } bool CanApplyOn(pir::Operation* op) const override { + PADDLE_ENFORCE_NOT_NULL( + scope_, phi::errors::InvalidArgument("scope can not be nullptr")); +#ifdef PADDLE_WITH_XPU + PADDLE_ENFORCE(paddle::platform::is_xpu_place(place_) || + paddle::platform::is_cpu_place(place_), + phi::errors::PreconditionNotMet( + "The Place attr in params_sync_among_devices_pass " + "should be cpu or xpu.")); +#endif +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE(paddle::platform::is_gpu_place(place_) || + paddle::platform::is_cpu_place(place_), + phi::errors::PreconditionNotMet( + "The Place attr in params_sync_among_devices_pass " + "should be cpu or gpu.")); +#endif + if (paddle::platform::is_cpu_place(place_)) { + return false; + } return op->isa<::pir::ModuleOp>() && op->num_regions() > 0; } private: - phi::Place place_; + phi::Place place_{phi::CPUPlace{}}; paddle::framework::Scope* scope_{nullptr}; }; diff --git a/paddle/fluid/pir/transforms/params_sync_among_devices_pass.h b/paddle/fluid/pir/transforms/general/params_sync_among_devices_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/params_sync_among_devices_pass.h rename to paddle/fluid/pir/transforms/general/params_sync_among_devices_pass.h diff --git a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc b/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc similarity index 96% rename from paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc rename to paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc index b3b1d14b49412..9bb8e539c2def 100644 --- a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc +++ b/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h" +#include "paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/core/builtin_op.h" diff --git a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h b/paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h rename to paddle/fluid/pir/transforms/general/replace_fetch_with_shadow_output_pass.h diff --git a/paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.cc new file mode 100644 index 0000000000000..619b9eeb3ec17 --- /dev/null +++ b/paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.cc @@ -0,0 +1,299 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.h" + +#include + +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" + +#include "paddle/fluid/pir/utils/general_functions.h" +#include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/core/value.h" +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace { + +class RmsNormFusePattern : public paddle::drr::DrrPatternBase { + private: + const bool is_half_weight_; + + public: + explicit RmsNormFusePattern(bool is_half_weight) + : is_half_weight_(is_half_weight) {} + + std::string name() const override { return "RmsNormFusePattern"; } + + uint32_t benefit() const override { return 3; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &pow = pat.Op(paddle::dialect::PowOp::name()); + const auto &mean = + pat.Op(paddle::dialect::MeanOp::name(), {{"axis", pat.Attr("axis")}}); + const auto &full = pat.Op(paddle::dialect::FullOp::name()); + const auto &scale = + pat.Op(paddle::dialect::ScaleOp::name(), {{"bias", pat.Attr("bias")}}); + const auto &rsqrt = pat.Op(paddle::dialect::RsqrtOp::name()); + const auto &multiply1 = pat.Op(paddle::dialect::MultiplyOp::name()); + const auto &multiply2 = pat.Op(paddle::dialect::MultiplyOp::name()); + if (is_half_weight_) { + const auto &cast1 = pat.Op(paddle::dialect::CastOp::name(), + {{"dtype", pat.Attr("cast_type_1")}}); + pat.Tensor("cast_1_out") = cast1(pat.Tensor("x")); + pat.Tensor("pow_out") = pow(pat.Tensor("cast_1_out")); + pat.Tensor("mean_out") = mean(pat.Tensor("pow_out")); + pat.Tensor("scale_out") = scale(pat.Tensor("mean_out"), full()); + pat.Tensor("rsqrt_out") = rsqrt(pat.Tensor("scale_out")); + pat.Tensor("multiply_out1") = + multiply1(pat.Tensor("rsqrt_out"), pat.Tensor("cast_1_out")); + const auto &cast2 = pat.Op(paddle::dialect::CastOp::name(), + {{"dtype", pat.Attr("cast_type_2")}}); + pat.Tensor("cast_2_out") = cast2(pat.Tensor("multiply_out1")); + pat.Tensor("multiply_out2") = + multiply2(pat.Tensor("cast_2_out"), pat.Tensor("w")); + } else { + pat.Tensor("pow_out") = pow(pat.Tensor("x")); + pat.Tensor("mean_out") = mean(pat.Tensor("pow_out")); + pat.Tensor("scale_out") = scale(pat.Tensor("mean_out"), full()); + pat.Tensor("rsqrt_out") = rsqrt(pat.Tensor("scale_out")); + pat.Tensor("multiply_out1") = + multiply1(pat.Tensor("rsqrt_out"), pat.Tensor("x")); + pat.Tensor("multiply_out2") = + multiply2(pat.Tensor("multiply_out1"), pat.Tensor("w")); + } + pat.RequireNativeCall([this](const paddle::drr::MatchContext &match_ctx) { + auto axis = match_ctx.Attr>("axis"); + if (axis.size() > 1) { + return false; + } + if (this->is_half_weight_) { + auto w_type = pir::GetDataTypeFromValue(match_ctx.Tensor("w")); + if (!(w_type.isa() || + w_type.isa())) { + return false; + } + + auto cast_type_1 = 
match_ctx.Attr("cast_type_1"); + auto cast_type_2 = match_ctx.Attr("cast_type_2"); + if (cast_type_1 != phi::DataType::FLOAT32) { + return false; + } + if (w_type.isa() && + cast_type_2 != phi::DataType::FLOAT16) { + return false; + } + if (w_type.isa() && + cast_type_2 != phi::DataType::BFLOAT16) { + return false; + } + } + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + const auto &begin_norm_axis = + res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int { + const auto &axis = match_ctx.Attr>("axis"); + auto pow_out_shape = + pir::GetShapeFromValue(match_ctx.Tensor("pow_out")); + return axis[0] == -1 ? static_cast(pow_out_shape.size()) - 1 + : axis[0]; + }); + + const auto &rms_norm = res.Op(paddle::dialect::RmsNormOp::name(), + {{ + {"epsilon", pat.Attr("bias")}, + {"begin_norm_axis", begin_norm_axis}, + {"quant_scale", res.Float32Attr(-1.0)}, + {"quant_round_type", res.Int32Attr(0)}, + {"quant_max_bound", res.Float32Attr(0.0)}, + {"quant_min_bound", res.Float32Attr(0.0)}, + }}); + + rms_norm( + { + &res.Tensor("x"), + &res.InputNoneTensor(), + &res.InputNoneTensor(), + &res.Tensor("w"), + &res.InputNoneTensor(), + }, + {&res.Tensor("multiply_out2"), + &res.Tensor("residual_out"), + &res.Tensor("inv_var")}); + } +}; + +class AddRmsNormFusePattern : public paddle::drr::DrrPatternBase { + private: + const bool extra_add_; + + public: + explicit AddRmsNormFusePattern(bool extra_add) : extra_add_(extra_add) {} + + uint32_t benefit() const override { return extra_add_ ? 2 : 1; } + + std::string name() const override { return "AddRmsNormFusePattern"; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + const auto &pat_rms_norm = + pat.Op(paddle::dialect::RmsNormOp::name(), + { + {"epsilon", pat.Attr("epsilon")}, + {"begin_norm_axis", pat.Attr("begin_norm_axis")}, + {"quant_scale", pat.Attr("quant_scale")}, + {"quant_round_type", pat.Attr("quant_round_type")}, + {"quant_max_bound", pat.Attr("quant_max_bound")}, + {"quant_min_bound", pat.Attr("quant_min_bound")}, + }); + pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual")); + pat_rms_norm({&pat.Tensor("add_out"), + &pat.Tensor("bias"), + &pat.InputNoneTensor(), + &pat.Tensor("w"), + &pat.InputNoneTensor()}, + {&pat.Tensor("rms_norm_out"), + &pat.Tensor("residual_out_0"), + &pat.Tensor("inv_var_0")}); + // TODO(bukejiyu) :DRR support matching placeholder op, + // the following needs to be deleted + if (extra_add_) { + const auto &add1 = pat.Op(paddle::dialect::AddOp::name()); + pat.Tensor("add_out1") = + add1(pat.Tensor("add_out"), pat.Tensor("any_tensor")); + } + paddle::drr::ResultPattern res = pat.ResultPattern(); + const auto &res_rms_norm = + res.Op(paddle::dialect::RmsNormOp::name(), + { + {"epsilon", pat.Attr("epsilon")}, + {"begin_norm_axis", pat.Attr("begin_norm_axis")}, + {"quant_scale", pat.Attr("quant_scale")}, + {"quant_round_type", pat.Attr("quant_round_type")}, + {"quant_max_bound", pat.Attr("quant_max_bound")}, + {"quant_min_bound", pat.Attr("quant_min_bound")}, + }); + + res_rms_norm( + { + &res.Tensor("x"), + &res.Tensor("bias"), + &res.Tensor("residual"), + &res.Tensor("w"), + &res.InputNoneTensor(), + }, + {&res.Tensor("rms_norm_out"), + &res.Tensor("add_out"), + &res.Tensor("inv_var")}); + } +}; + +class AddLayerNormFusePattern : public paddle::drr::DrrPatternBase { + private: + const bool extra_add_; + + public: + explicit 
AddLayerNormFusePattern(bool extra_add) : extra_add_(extra_add) {} + + uint32_t benefit() const override { return extra_add_ ? 2 : 1; } + std::string name() const override { return "AddLayerNormFusePattern"; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + const auto &layer_norm = + pat.Op(paddle::dialect::LayerNormOp::name(), + {{"epsilon", pat.Attr("epsilon")}, + {"begin_norm_axis", pat.Attr("begin_norm_axis")}}); + pat.Tensor("add_out") = add(pat.Tensor("x"), pat.Tensor("residual")); + layer_norm({&pat.Tensor("add_out"), &pat.Tensor("w"), &pat.Tensor("bias")}, + {&pat.Tensor("layer_norm_out"), + &pat.Tensor("mean_out_0"), + &pat.Tensor("variance_out_0")}); + // TODO(bukejiyu) :DRR support matching placeholder op, + // the following needs to be deleted + if (extra_add_) { + const auto &add1 = pat.Op(paddle::dialect::AddOp::name()); + pat.Tensor("add_out1") = + add1(pat.Tensor("add_out"), pat.Tensor("any_tensor")); + } + + paddle::drr::ResultPattern res = pat.ResultPattern(); + const auto &fuse_layer_norm = + res.Op(paddle::dialect::FusedBiasResidualLayernormOp::name(), + {{"epsilon", pat.Attr("epsilon")}, + {"residual_alpha", res.Float32Attr(1.0)}, + {"begin_norm_axis", pat.Attr("begin_norm_axis")}, + {"quant_scale", res.Float32Attr(-1.0)}, + {"quant_round_type", res.Int32Attr(0)}, + {"quant_max_bound", res.Float32Attr(0.0)}, + {"quant_min_bound", res.Float32Attr(0.0)}}); + + fuse_layer_norm( + { + &res.Tensor("x"), + &res.Tensor("bias"), + &res.Tensor("residual"), + &res.Tensor("w"), + &res.InputNoneTensor(), + }, + {&res.Tensor("layer_norm_out"), + &res.Tensor("add_out"), + &res.Tensor("mean_out"), + &res.Tensor("variance_out")}); + } +}; + +class AddNormFusePass : public pir::PatternRewritePass { + public: + AddNormFusePass() : pir::PatternRewritePass("add_norm_fuse_pass", 2) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + // x-pow-mean-scale->rsqrt- + // mul-- + // x----------------------- + // mul --->rms_norm + // w----------------------------- + bool is_half_weight = true; + bool extra_add = true; + ps.Add(paddle::drr::Create(context, !is_half_weight)); + ps.Add(paddle::drr::Create(context, is_half_weight)); + // x-------- + // add-rms_norm ---> rms_norm + // residual- + ps.Add(paddle::drr::Create(context, !extra_add)); + ps.Add(paddle::drr::Create(context, extra_add)); + // x-------- + // add-layer_norm ----> fused_bias_residual_layernorm + // residual- + ps.Add(paddle::drr::Create(context, !extra_add)); + ps.Add(paddle::drr::Create(context, extra_add)); + return ps; + } +}; +} // namespace + +namespace pir { +std::unique_ptr CreateAddNormFusePass() { + return std::make_unique(); +} +} // namespace pir + +REGISTER_IR_PASS(add_norm_fuse_pass, AddNormFusePass); diff --git a/paddle/fluid/string/printf.h b/paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.h similarity index 64% rename from paddle/fluid/string/printf.h rename to paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.h index 40cc5450f4159..e57f32775a9bc 100644 --- a/paddle/fluid/string/printf.h +++ b/paddle/fluid/pir/transforms/gpu/add_norm_fuse_pass.h @@ -1,10 +1,10 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // -// http://www.apache.org/licenses/LICENSE-2.0 +// http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, @@ -13,4 +13,14 @@ // limitations under the License. #pragma once -#include "paddle/utils/string/printf.h" + +#include +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateAddNormFusePass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/conv2d_add_act_fuse_pass.cc similarity index 93% rename from paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc rename to paddle/fluid/pir/transforms/gpu/conv2d_add_act_fuse_pass.cc index 9e950dc2d11b9..b842e529a63f0 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/conv2d_add_act_fuse_pass.cc @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" +#include "paddle/fluid/pir/transforms/gpu/conv2d_add_act_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" @@ -43,9 +43,13 @@ class Conv2dAddActFusePattern if (!conv2d_out.HasOneUse()) return false; pir::Value add_input = op.x(); - IR_ENFORCE(add_input == conv2d_out); + PADDLE_ENFORCE_EQ( + add_input, + conv2d_out, + phi::errors::PreconditionNotMet("The type of add input should be the " + "same as the type of conv2d's out.")); - if (!pir::ValueIsPersitable(op.y())) return false; + if (!pir::ValueIsPersistable(op.y())) return false; pir::Value add_out = op.out(); if (!add_out.HasOneUse()) return false; @@ -119,7 +123,7 @@ class Conv2dAdd2ActFusePattern ->dyn_cast(); if (!add1_op) return false; - if (!pir::ValueIsPersitable(add1_op.y())) return false; + if (!pir::ValueIsPersistable(add1_op.y())) return false; pir::Value add1_out = add1_op.out(); if (!add1_out.HasOneUse()) return false; @@ -207,7 +211,7 @@ class Conv2dAddActFusePass : public pir::PatternRewritePass { 1, std::vector{ paddle::dialect::FusedConv2dAddActOp::name()}); - auto conv2d_doublue_add_act_fuse_pattern = + auto conv2d_double_add_act_fuse_pattern = std::make_unique( context, 1, @@ -215,7 +219,7 @@ class Conv2dAddActFusePass : public pir::PatternRewritePass { paddle::dialect::FusedConv2dAddActOp::name()}); // conv2d+add+add+act->fused_conv2d_add_act - ps.Add(std::move(conv2d_doublue_add_act_fuse_pattern)); + ps.Add(std::move(conv2d_double_add_act_fuse_pattern)); // conv2d+add+act->fused_conv2d_add_act ps.Add(std::move(conv2d_add_act_fuse_pattern)); return ps; diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h b/paddle/fluid/pir/transforms/gpu/conv2d_add_act_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h rename to 
paddle/fluid/pir/transforms/gpu/conv2d_add_act_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/conv2d_add_fuse_pass.cc similarity index 95% rename from paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc rename to paddle/fluid/pir/transforms/gpu/conv2d_add_fuse_pass.cc index 9c1cec5b9b645..dfd2b0ed588e2 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/conv2d_add_fuse_pass.cc @@ -12,14 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" +#include "paddle/fluid/pir/transforms/gpu/conv2d_add_fuse_pass.h" #include #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/core/builtin_op.h" #include "paddle/pir/include/core/value.h" #include "paddle/pir/include/pass/pass.h" @@ -47,7 +47,7 @@ class Conv2dAddFusePattern : public paddle::drr::DrrPatternBase { pat.Tensor("add_out") = add(pat.Tensor("conv2d_out"), pat.Tensor("bias")); pat.RequireNativeCall( [](const paddle::drr::MatchContext &match_ctx) -> bool { - if (!pir::ValueIsPersitable(match_ctx.Tensor("bias"))) { + if (!pir::ValueIsPersistable(match_ctx.Tensor("bias"))) { return false; } @@ -107,7 +107,6 @@ class Conv2dAddFusePass : public pir::PatternRewritePass { } // namespace namespace pir { - std::unique_ptr CreateConv2dAddFusePass() { return std::make_unique(); } diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h b/paddle/fluid/pir/transforms/gpu/conv2d_add_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h rename to paddle/fluid/pir/transforms/gpu/conv2d_add_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/conv2d_bn_fuse_pass.cc similarity index 94% rename from paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc rename to paddle/fluid/pir/transforms/gpu/conv2d_bn_fuse_pass.cc index d72e9167b118c..231aaaba7ce05 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/conv2d_bn_fuse_pass.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h" +#include "paddle/fluid/pir/transforms/gpu/conv2d_bn_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" @@ -57,6 +57,13 @@ class Conv2dBnFusePattern return false; } if (!conv2d_op.out().HasOneUse()) return false; + // (bukejiyu): The bn + // outputs(mean_out\variance_out\saved_mean\saved_variance) + // cannot be used in conv bn fusion + if (!op.mean_out().use_empty()) return false; + if (!op.variance_out().use_empty()) return false; + if (!op.saved_mean().use_empty()) return false; + if (!op.saved_variance().use_empty()) return false; pir::Value conv2d_filter = conv2d_op.filter(); pir::Value bn_mean = op.mean(); diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h b/paddle/fluid/pir/transforms/gpu/conv2d_bn_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h rename to paddle/fluid/pir/transforms/gpu/conv2d_bn_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/embedding_eltwise_layernorm_fuse_pass.cc similarity index 98% rename from paddle/fluid/pir/transforms/fusion/embedding_eltwise_layernorm_fuse_pass.cc rename to paddle/fluid/pir/transforms/gpu/embedding_eltwise_layernorm_fuse_pass.cc index c8a61af1aef27..58409b2fbcb15 100644 --- a/paddle/fluid/pir/transforms/fusion/embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/embedding_eltwise_layernorm_fuse_pass.cc @@ -12,10 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/fusion/embedding_eltwise_layernorm_fuse_pass.h" +#include "paddle/fluid/pir/transforms/gpu/embedding_eltwise_layernorm_fuse_pass.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/pass/pass.h" diff --git a/paddle/fluid/pir/transforms/fusion/embedding_eltwise_layernorm_fuse_pass.h b/paddle/fluid/pir/transforms/gpu/embedding_eltwise_layernorm_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/embedding_eltwise_layernorm_fuse_pass.h rename to paddle/fluid/pir/transforms/gpu/embedding_eltwise_layernorm_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/fc_elementwise_layernorm_fuse_pass.cc similarity index 96% rename from paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc rename to paddle/fluid/pir/transforms/gpu/fc_elementwise_layernorm_fuse_pass.cc index 826d40854fa7c..d3e4ed862e741 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/fc_elementwise_layernorm_fuse_pass.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h" +#include "paddle/fluid/pir/transforms/gpu/fc_elementwise_layernorm_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" diff --git a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h b/paddle/fluid/pir/transforms/gpu/fc_elementwise_layernorm_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h rename to paddle/fluid/pir/transforms/gpu/fc_elementwise_layernorm_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/fc_fuse_pass.cc similarity index 97% rename from paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc rename to paddle/fluid/pir/transforms/gpu/fc_fuse_pass.cc index b62402c096091..187c4e34f5962 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/fc_fuse_pass.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h" +#include "paddle/fluid/pir/transforms/gpu/fc_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h b/paddle/fluid/pir/transforms/gpu/fc_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h rename to paddle/fluid/pir/transforms/gpu/fc_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc b/paddle/fluid/pir/transforms/gpu/fused_dot_product_attention_pass.cc similarity index 99% rename from paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc rename to paddle/fluid/pir/transforms/gpu/fused_dot_product_attention_pass.cc index dce6483742d38..69882f537a9bb 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/fused_dot_product_attention_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.h" +#include "paddle/fluid/pir/transforms/gpu/fused_dot_product_attention_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.h b/paddle/fluid/pir/transforms/gpu/fused_dot_product_attention_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.h rename to paddle/fluid/pir/transforms/gpu/fused_dot_product_attention_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/gpu/fused_dropout_add_pass.cc similarity index 98% rename from paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc rename to paddle/fluid/pir/transforms/gpu/fused_dropout_add_pass.cc index a235a8b4ecf67..ccc66d848ecbe 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/fused_dropout_add_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" +#include "paddle/fluid/pir/transforms/gpu/fused_dropout_add_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h b/paddle/fluid/pir/transforms/gpu/fused_dropout_add_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h rename to paddle/fluid/pir/transforms/gpu/fused_dropout_add_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.cc similarity index 99% rename from paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc rename to paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.cc index 6eeb899d67710..0d76f9e569d7f 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h" +#include "paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h b/paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h rename to paddle/fluid/pir/transforms/gpu/fused_gemm_epilogue_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/gpu/fused_linear_param_grad_add_pass.cc similarity index 96% rename from paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc rename to paddle/fluid/pir/transforms/gpu/fused_linear_param_grad_add_pass.cc index 120b882a67194..8bb56c51ea3a5 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/fused_linear_param_grad_add_pass.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" +#include "paddle/fluid/pir/transforms/gpu/fused_linear_param_grad_add_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" @@ -67,7 +67,7 @@ class FusedMatmulAddGradAddPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &muti_precision_attr = + const auto &multi_precision_attr = res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); @@ -78,7 +78,7 @@ class FusedMatmulAddGradAddPattern : public paddle::drr::DrrPatternBase { {"transpose_y", res.BoolAttr(true)}}); const auto &fused_linear_param_grad_add = res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, + {{{"multi_precision", multi_precision_attr}, {"has_bias", res.BoolAttr(true)}}}); matmul({&res.Tensor("fwd_add_out_grad"), &res.Tensor("weight")}, @@ -122,7 +122,7 @@ class FusedMatmulGradAddPattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &muti_precision_attr = + const auto &multi_precision_attr = res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); @@ -133,7 +133,7 @@ class FusedMatmulGradAddPattern : public paddle::drr::DrrPatternBase { {"transpose_y", res.BoolAttr(true)}}); const auto &fused_linear_param_grad_add = res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, + {{{"multi_precision", multi_precision_attr}, {"has_bias", 
res.BoolAttr(false)}}}); matmul({&res.Tensor("out_grad"), &res.Tensor("weight")}, @@ -194,7 +194,7 @@ class FusedMatmulReshapeMatmulAddPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &muti_precision_attr = + const auto &multi_precision_attr = res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("w_grad"))); @@ -202,7 +202,7 @@ class FusedMatmulReshapeMatmulAddPattern : public paddle::drr::DrrPatternBase { const auto &fused_linear_param_grad_add = res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, + {{{"multi_precision", multi_precision_attr}, {"has_bias", res.BoolAttr(false)}}}); fused_linear_param_grad_add( @@ -239,7 +239,7 @@ class FusedMatmulAddaPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &muti_precision_attr = + const auto &multi_precision_attr = res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); @@ -247,7 +247,7 @@ class FusedMatmulAddaPattern : public paddle::drr::DrrPatternBase { const auto &fused_linear_param_grad_add = res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, + {{{"multi_precision", multi_precision_attr}, {"has_bias", res.BoolAttr(false)}}}); fused_linear_param_grad_add( {&res.Tensor("x"), @@ -283,7 +283,7 @@ class FusedMatmulAddbPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &muti_precision_attr = + const auto &multi_precision_attr = res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); @@ -291,7 +291,7 @@ class FusedMatmulAddbPattern : public paddle::drr::DrrPatternBase { const auto &fused_linear_param_grad_add = res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, + {{{"multi_precision", multi_precision_attr}, {"has_bias", res.BoolAttr(false)}}}); fused_linear_param_grad_add( {&res.Tensor("x"), @@ -341,7 +341,7 @@ class FusedMatmulAddGradAddaPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &muti_precision_attr = + const auto &multi_precision_attr = res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); @@ -349,7 +349,7 @@ class FusedMatmulAddGradAddaPattern : public paddle::drr::DrrPatternBase { const auto &fused_linear_param_grad_add = res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, + {{{"multi_precision", multi_precision_attr}, {"has_bias", res.BoolAttr(true)}}}); fused_linear_param_grad_add( {&res.Tensor("x"), @@ -399,14 +399,14 @@ class FusedMatmulAddGradAddbPattern : public paddle::drr::DrrPatternBase { }); paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &muti_precision_attr = + const auto &multi_precision_attr = res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> bool { return 
!(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &fused_linear_param_grad_add = res.Op(paddle::dialect::FusedLinearParamGradAddOp::name(), - {{{"multi_precision", muti_precision_attr}, + {{{"multi_precision", multi_precision_attr}, {"has_bias", res.BoolAttr(true)}}}); fused_linear_param_grad_add( {&res.Tensor("x"), diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h b/paddle/fluid/pir/transforms/gpu/fused_linear_param_grad_add_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h rename to paddle/fluid/pir/transforms/gpu/fused_linear_param_grad_add_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc similarity index 51% rename from paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc rename to paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc index bf4ea92af67b2..17bd3f48461e2 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h" +#include "paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/place.h" @@ -37,9 +37,20 @@ int getSMVersion() { return sm_version; } -class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { +class FusedWeightOnlyLinearWithBiasPattern + : public paddle::drr::DrrPatternBase { + private: + bool reverse_add_; + public: - std::string name() const override { return "FusedWeightOnlyLinearPattern"; } + explicit FusedWeightOnlyLinearWithBiasPattern(bool reverse_add) + : reverse_add_(reverse_add) {} + + std::string name() const override { + return "FusedWeightOnlyLinearWithBiasPattern"; + } + + uint32_t benefit() const override { return 2; } void operator()(paddle::drr::DrrPatternContext *ctx) const override { // @@ -50,21 +61,31 @@ class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { src.Op(paddle::dialect::MatmulOp::name(), {{"transpose_x", src.Attr("matmul_transpose_x")}, {"transpose_y", src.Attr("matmul_transpose_y")}}); - const auto ¶meter = src.Op(pir::ParameterOp::name()); - src.Tensor("w") = parameter(); src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w")); const auto &add = src.Op(paddle::dialect::AddOp::name()); - src.Tensor("add_out") = add(src.Tensor("matmul_out"), src.Tensor("bias")); + + src.Tensor("add_out") = + reverse_add_ ? add(src.Tensor("matmul_out"), src.Tensor("bias")) + : add(src.Tensor("bias"), src.Tensor("matmul_out")); // // Constraints. 
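// The reverse_add_ flag threaded through FusedWeightOnlyLinearWithBiasPattern
// above lets one pattern class match both operand orders of the elementwise
// add, add(matmul_out, bias) and add(bias, matmul_out), so the bias-carrying
// fusion fires regardless of how the graph was built. Its benefit() of 2 also
// ranks it above FusedWeightOnlyLinearNoBiasPattern (benefit 1, defined
// below), so the richer with-bias rewrite is attempted first.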
// src.RequireNativeCall( [](const paddle::drr::MatchContext &match_ctx) -> bool { + if (!pir::ValueIsPersistable(match_ctx.Tensor("w"))) { + return false; + } bool matmul_trans_x = match_ctx.Attr("matmul_transpose_x"); bool matmul_trans_y = match_ctx.Attr("matmul_transpose_y"); if (matmul_trans_x || matmul_trans_y) return false; + auto w_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("w")); + if (!w_dtype.isa() && + !w_dtype.isa()) { + return false; + } + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); auto bias_dims = pir::GetShapeFromValue(match_ctx.Tensor("bias")); @@ -73,6 +94,75 @@ class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { return false; } + if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false; + if (x_dims.at(x_dims.size() - 1) != w_dims.at(0)) return false; + + return true; + }); + // + // Result Pattern. + // + paddle::drr::ResultPattern res = src.ResultPattern(); + + const auto &weight_quantize = + res.Op(paddle::dialect::WeightQuantizeOp::name(), + {{"algo", res.StrAttr("weight_only_int8")}, + {"arch", res.Int32Attr(getSMVersion())}, + {"group_size", res.Int32Attr(-1)}}); + weight_quantize({&res.Tensor("w")}, + {&res.Tensor("quanted_weight_tensor"), + &res.Tensor("weight_scale_tensor")}); + + const auto &weight_only_linear = + res.Op(paddle::dialect::WeightOnlyLinearOp::name(), + {{"weight_dtype", res.StrAttr("int8")}, + {"arch", res.Int32Attr(getSMVersion())}, + {"group_size", res.Int32Attr(-1)}}); + weight_only_linear({&res.Tensor("x"), + &res.Tensor("quanted_weight_tensor"), + &res.Tensor("bias"), + &res.Tensor("weight_scale_tensor")}, + {&res.Tensor("add_out")}); + } +}; + +class FusedWeightOnlyLinearNoBiasPattern : public paddle::drr::DrrPatternBase { + public: + std::string name() const override { + return "FusedWeightOnlyLinearNoBiasPattern"; + } + + uint32_t benefit() const override { return 1; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + // + // Source Pattern. + // + paddle::drr::SourcePattern src = ctx->SourcePattern(); + const auto &matmul = + src.Op(paddle::dialect::MatmulOp::name(), + {{"transpose_x", src.Attr("matmul_transpose_x")}, + {"transpose_y", src.Attr("matmul_transpose_y")}}); + src.Tensor("matmul_out") = matmul(src.Tensor("x"), src.Tensor("w")); + + // + // Constraints. 
+ // + src.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + if (!pir::ValueIsPersistable(match_ctx.Tensor("w"))) { + return false; + } + bool matmul_trans_x = match_ctx.Attr("matmul_transpose_x"); + bool matmul_trans_y = match_ctx.Attr("matmul_transpose_y"); + if (matmul_trans_x || matmul_trans_y) return false; + + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + if (!(w_dims.size() == 2 && x_dims.size() >= 2)) { + return false; + } + if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false; auto w_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("w")); @@ -80,7 +170,7 @@ class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { !w_dtype.isa()) return false; - if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false; + if (x_dims.at(x_dims.size() - 1) != w_dims.at(0)) return false; return true; }); @@ -105,9 +195,9 @@ class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { {"group_size", res.Int32Attr(-1)}}); weight_only_linear({&res.Tensor("x"), &res.Tensor("quanted_weight_tensor"), - &res.Tensor("bias"), + &res.InputNoneTensor(), &res.Tensor("weight_scale_tensor")}, - {&res.Tensor("add_out")}); + {&res.Tensor("matmul_out")}); } }; @@ -118,14 +208,29 @@ class FusedWeightOnlyLinearPass : public pir::PatternRewritePass { pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); - ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context, + true)); + ps.Add(paddle::drr::Create(context, + false)); + ps.Add(paddle::drr::Create(context)); return ps; } + pir::GreedyRewriteConfig InitializeConfig() override { + pir::GreedyRewriteConfig config; + + // NOTE(liuyuanle): Ensure that WithBiasPattern is executed before + // NoBiasPattern. + config.use_top_down_traversal = false; + + config.max_iterations = 10; + return config; + } + bool CanApplyOn(pir::Operation *op) const override { - int sm_vesion = getSMVersion(); - if (sm_vesion != 70 && sm_vesion != 75 && sm_vesion != 80 && - sm_vesion != 86) { + int sm_version = getSMVersion(); + if (sm_version != 70 && sm_version != 75 && sm_version != 80 && + sm_version != 86) { return false; } return op->num_regions() > 0; diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h b/paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h rename to paddle/fluid/pir/transforms/gpu/fused_weight_only_linear_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/multihead_matmul_fuse_pass.cc similarity index 99% rename from paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.cc rename to paddle/fluid/pir/transforms/gpu/multihead_matmul_fuse_pass.cc index 760c93fd755ec..16884e5f9cd30 100644 --- a/paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/multihead_matmul_fuse_pass.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
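Taken together, the constraints added in fused_weight_only_linear_pass.cc gate the rewrite on weight layout, dtype, and GPU architecture. Below is a compact restatement of those eligibility checks as a standalone predicate; it is a sketch that mirrors the pass, with the dtype and persistability tests reduced to booleans for brevity:

#include <cstdint>
#include <vector>

// Mirrors the checks in FusedWeightOnlyLinear*Pattern and CanApplyOn():
// a persistable fp16/bf16 weight of shape [k, n] with k % 64 == 0 and
// n % 16 == 0, an input whose last dimension equals k, and an SM
// 70/75/80/86 device.
bool WeightOnlyLinearEligible(const std::vector<int64_t> &w_dims,
                              const std::vector<int64_t> &x_dims,
                              bool w_is_fp16_or_bf16,
                              bool w_is_persistable,
                              int sm_version) {
  if (!w_is_persistable || !w_is_fp16_or_bf16) return false;
  if (w_dims.size() != 2 || x_dims.size() < 2) return false;
  if (w_dims[0] % 64 != 0 || w_dims[1] % 16 != 0) return false;
  if (x_dims.back() != w_dims[0]) return false;
  return sm_version == 70 || sm_version == 75 || sm_version == 80 ||
         sm_version == 86;
}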
-#include "paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h" +#include "paddle/fluid/pir/transforms/gpu/multihead_matmul_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" diff --git a/paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h b/paddle/fluid/pir/transforms/gpu/multihead_matmul_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h rename to paddle/fluid/pir/transforms/gpu/multihead_matmul_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/silu_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/silu_fuse_pass.cc similarity index 97% rename from paddle/fluid/pir/transforms/fusion/silu_fuse_pass.cc rename to paddle/fluid/pir/transforms/gpu/silu_fuse_pass.cc index a84b331134f08..00112bfa79124 100644 --- a/paddle/fluid/pir/transforms/fusion/silu_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/silu_fuse_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/transforms/fusion/silu_fuse_pass.h" +#include "paddle/fluid/pir/transforms/gpu/silu_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" diff --git a/paddle/fluid/pir/transforms/fusion/silu_fuse_pass.h b/paddle/fluid/pir/transforms/gpu/silu_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/silu_fuse_pass.h rename to paddle/fluid/pir/transforms/gpu/silu_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/fusion/transpose_flatten_concat_fuse_pass.cc b/paddle/fluid/pir/transforms/gpu/transpose_flatten_concat_fuse_pass.cc similarity index 98% rename from paddle/fluid/pir/transforms/fusion/transpose_flatten_concat_fuse_pass.cc rename to paddle/fluid/pir/transforms/gpu/transpose_flatten_concat_fuse_pass.cc index f9a247f3c01cf..fa439a2c0344d 100644 --- a/paddle/fluid/pir/transforms/fusion/transpose_flatten_concat_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/gpu/transpose_flatten_concat_fuse_pass.cc @@ -12,11 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/pir/transforms/fusion/transpose_flatten_concat_fuse_pass.h" +#include "paddle/fluid/pir/transforms/gpu/transpose_flatten_concat_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" diff --git a/paddle/fluid/pir/transforms/fusion/transpose_flatten_concat_fuse_pass.h b/paddle/fluid/pir/transforms/gpu/transpose_flatten_concat_fuse_pass.h similarity index 100% rename from paddle/fluid/pir/transforms/fusion/transpose_flatten_concat_fuse_pass.h rename to paddle/fluid/pir/transforms/gpu/transpose_flatten_concat_fuse_pass.h diff --git a/paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.cc b/paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.cc index 67177d9cee390..d75d00dbdb83a 100644 --- a/paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_registry.h" @@ -52,14 +53,16 @@ class ConvBiasFusePattern : public paddle::drr::DrrPatternBase { const auto &add = pat.Op(paddle::dialect::AddOp::name()); conv({&pat.Tensor("input"), &pat.Tensor("filter")}, {&pat.Tensor("conv_out")}); - const auto ¶meter_bias = pat.Op( - pir::ParameterOp::name(), {{"parameter_name", pat.Attr("param_name")}}); - pat.Tensor("bias") = parameter_bias(); + pat.Tensor("add_out") = add(pat.Tensor("conv_out"), pat.Tensor("bias")); if (conv_name_ == paddle::dialect::Conv2dOp::name() || conv_name_ == paddle::onednn::dialect::FusedConv2dOp::name()) { pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + if (!pir::ValueIsPersistable(match_ctx.Tensor("bias"))) { + return false; + } + std::set padding_algorithm = {"EXPLICIT", "SAME", "VALID"}; std::set data_format = {"NCHW", "NHWC", "AnyLayout"}; if (padding_algorithm.count( @@ -73,6 +76,10 @@ class ConvBiasFusePattern : public paddle::drr::DrrPatternBase { }); } else { pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + if (!pir::ValueIsPersistable(match_ctx.Tensor("bias"))) { + return false; + } + std::set padding_algorithm = {"EXPLICIT", "SAME", "VALID"}; std::set data_format = {"NDHWC", "NCDHW"}; if (padding_algorithm.count( @@ -117,26 +124,91 @@ class ConvBiasFusePattern : public paddle::drr::DrrPatternBase { } }; -class FusedConvAddFusePattern : public paddle::drr::DrrPatternBase { - private: - std::string conv_name_; - std::string fused_conv_name_; +class ConvTransposeBiasFusePattern : public paddle::drr::DrrPatternBase { + std::string name() const override { return "ConvTransposeBiasFusePattern"; } - public: - FusedConvAddFusePattern(const std::string &conv_name, - const std::string &fused_conv_name) - : conv_name_(conv_name), fused_conv_name_(fused_conv_name) {} + uint32_t benefit() const override { return 2; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &conv = + pat.Op(paddle::dialect::Conv2dTransposeOp::name(), + {{"strides", 
pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"output_padding", pat.Attr("output_padding")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}}); + + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + conv({&pat.Tensor("input"), + &pat.Tensor("filter"), + &pat.Tensor("output_size")}, + {&pat.Tensor("conv_out")}); + + pat.Tensor("add_out") = add(pat.Tensor("conv_out"), pat.Tensor("bias")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + if (!pir::ValueIsPersistable(match_ctx.Tensor("bias"))) { + return false; + } + + std::set padding_algorithm = {"EXPLICIT", "SAME", "VALID"}; + std::set data_format = {"NCHW", "NHWC", "AnyLayout"}; + if (padding_algorithm.count( + match_ctx.Attr("padding_algorithm")) == 0 || + data_format.count(match_ctx.Attr("data_format")) == 0 || + match_ctx.Attr("groups") < 1) { + return false; + } + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &fused_conv = + res.Op(paddle::onednn::dialect::Conv2dTransposeBiasOp::name(), + {{ + {"strides", pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"output_padding", pat.Attr("output_padding")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}, + {"force_fp32_output", res.BoolAttr(false)}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"fuse_relu", res.BoolAttr(false)}, + {"fuse_activation", res.StrAttr("")}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}, + {"is_test", res.BoolAttr(true)}, + }}); + + fused_conv({&res.Tensor("input"), + &res.Tensor("filter"), + &res.Tensor("bias"), + &res.Tensor("output_size")}, + {&res.Tensor("add_out")}); + } +}; - std::string name() const override { return "FusedConvAddFusePattern"; } +class FusedConvTransposeAddFusePattern : public paddle::drr::DrrPatternBase { + std::string name() const override { + return "FusedConvTransposeAddFusePattern"; + } uint32_t benefit() const override { return 3; } void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); const auto &conv = - pat.Op(conv_name_, + pat.Op(paddle::dialect::Conv2dTransposeOp::name(), {{"strides", pat.Attr("strides")}, {"paddings", pat.Attr("paddings")}, + {"output_padding", pat.Attr("output_padding")}, {"padding_algorithm", pat.Attr("padding_algorithm")}, {"dilations", pat.Attr("dilations")}, {"groups", pat.Attr("groups")}, @@ -144,48 +216,33 @@ class FusedConvAddFusePattern : public paddle::drr::DrrPatternBase { const auto &add = pat.Op(paddle::dialect::AddOp::name()); const auto &add2 = pat.Op(paddle::dialect::AddOp::name()); - conv({&pat.Tensor("input"), &pat.Tensor("filter")}, + conv({&pat.Tensor("input"), + &pat.Tensor("filter"), + &pat.Tensor("output_size")}, {&pat.Tensor("conv_out")}); - const auto ¶meter_bias = pat.Op( - pir::ParameterOp::name(), {{"parameter_name", pat.Attr("param_name")}}); - pat.Tensor("bias") = parameter_bias(); pat.Tensor("add_out") = add(pat.Tensor("conv_out"), pat.Tensor("bias")); - - const auto ¶meter = pat.Op( - pir::ParameterOp::name(), {{"parameter_name", pat.Attr("param_name")}}); - pat.Tensor("other_param") = parameter(); pat.Tensor("result") = add2(pat.Tensor("add_out"), pat.Tensor("other_param")); - if (conv_name_ == 
paddle::dialect::Conv2dOp::name() || - conv_name_ == paddle::onednn::dialect::FusedConv2dOp::name()) { - pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - std::set padding_algorithm = {"EXPLICIT", "SAME", "VALID"}; - std::set data_format = {"NCHW", "NHWC", "AnyLayout"}; - if (padding_algorithm.count( - match_ctx.Attr("padding_algorithm")) == 0 || - data_format.count(match_ctx.Attr("data_format")) == - 0 || - match_ctx.Attr("groups") < 1) { - return false; - } - return true; - }); - } else { - pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - std::set padding_algorithm = {"EXPLICIT", "SAME", "VALID"}; - std::set data_format = {"NDHWC", "NCDHW"}; - if (padding_algorithm.count( - match_ctx.Attr("padding_algorithm")) == 0 || - data_format.count(match_ctx.Attr("data_format")) == - 0 || - match_ctx.Attr("groups") < 1) { - return false; - } - return true; - }); - } + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + if (!pir::ValueIsPersistable(match_ctx.Tensor("bias"))) { + return false; + } + if (!pir::ValueIsPersistable(match_ctx.Tensor("other_param"))) { + return false; + } + + std::set padding_algorithm = {"EXPLICIT", "SAME", "VALID"}; + std::set data_format = {"NCHW", "NHWC", "AnyLayout"}; + if (padding_algorithm.count( + match_ctx.Attr("padding_algorithm")) == 0 || + data_format.count(match_ctx.Attr("data_format")) == 0 || + match_ctx.Attr("groups") < 1) { + return false; + } + return true; + }); paddle::drr::ResultPattern res = pat.ResultPattern(); @@ -194,30 +251,28 @@ class FusedConvAddFusePattern : public paddle::drr::DrrPatternBase { fused_add(res.Tensor("bias"), res.Tensor("other_param")); const auto &fused_conv = - res.Op(fused_conv_name_, + res.Op(paddle::onednn::dialect::Conv2dTransposeBiasOp::name(), {{ {"strides", pat.Attr("strides")}, {"paddings", pat.Attr("paddings")}, + {"output_padding", pat.Attr("output_padding")}, {"padding_algorithm", pat.Attr("padding_algorithm")}, {"dilations", pat.Attr("dilations")}, {"groups", pat.Attr("groups")}, {"data_format", pat.Attr("data_format")}, + {"force_fp32_output", res.BoolAttr(false)}, {"mkldnn_data_type", res.StrAttr("float32")}, + {"fuse_relu", res.BoolAttr(false)}, {"fuse_activation", res.StrAttr("")}, - {"fuse_residual_connection", res.BoolAttr(false)}, - {"force_fp32_output", res.BoolAttr(false)}, {"fuse_alpha", res.Float32Attr(0.0f)}, {"fuse_beta", res.Float32Attr(0.0f)}, - {"scale_in", res.Float32Attr(1.0f)}, - {"scale_out", res.Float32Attr(1.0f)}, - {"scale_in_eltwise", res.Float32Attr(1.0f)}, - {"scale_weights", res.VectorFloatAttr({1.0f})}, + {"is_test", res.BoolAttr(true)}, }}); fused_conv({&res.Tensor("input"), &res.Tensor("filter"), &res.Tensor("bias2"), - &res.InputNoneTensor()}, + &res.Tensor("output_size")}, {&res.Tensor("result")}); } }; @@ -232,26 +287,22 @@ class Conv2dBiasFusePass : public pir::PatternRewritePass { context, paddle::dialect::Conv2dOp::name(), paddle::onednn::dialect::FusedConv2dOp::name())); - ps.Add(paddle::drr::Create( - context, - paddle::dialect::Conv2dOp::name(), - paddle::onednn::dialect::FusedConv2dOp::name())); return ps; } }; -// class Conv2dTransposeBiasFusePass : public pir::PatternRewritePass { -// public: -// Conv2dTransposeBiasFusePass() -// : pir::PatternRewritePass("conv2d_transpose_bias_fuse_pass", 2) {} +class Conv2dTransposeBiasFusePass : public pir::PatternRewritePass { + public: + Conv2dTransposeBiasFusePass() + : pir::PatternRewritePass("conv2d_transpose_bias_fuse_pass", 2) {} -// pir::RewritePatternSet 
InitializePatterns(pir::IrContext *context) override -// { -// pir::RewritePatternSet ps(context); -// ps.Add(paddle::drr::Create(context)); -// return ps; -// } -// }; + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context)); + return ps; + } +}; class Conv3dBiasFusePass : public pir::PatternRewritePass { public: @@ -263,10 +314,6 @@ class Conv3dBiasFusePass : public pir::PatternRewritePass { context, paddle::dialect::Conv3dOp::name(), paddle::onednn::dialect::FusedConv3dOp::name())); - ps.Add(paddle::drr::Create( - context, - paddle::dialect::Conv3dOp::name(), - paddle::onednn::dialect::FusedConv3dOp::name())); return ps; } }; @@ -281,10 +328,12 @@ std::unique_ptr CreateConv2dBiasFusePass() { return std::make_unique(); } -// std::unique_ptr CreateConv2dTransposeBiasFusePass() { -// // pd_op.conv2d_transpose + pd_op.add -> onednn_op.fused_conv2d -// return std::make_unique(); -// } +std::unique_ptr CreateConv2dTransposeBiasFusePass() { + // pd_op.conv2d_transpose + pd_op.add -> onednn_op.conv2d_transpose_bias + // onednn_op.conv2d_transpose_bias + pd_op.add -> + // onednn_op.conv2d_transpose_bias + pd_op.add + return std::make_unique(); +} std::unique_ptr CreateConv3dBiasFusePass() { // pd_op.conv3d + pd_op.add -> onednn_op.fused_conv3d @@ -294,6 +343,5 @@ std::unique_ptr CreateConv3dBiasFusePass() { } // namespace pir REGISTER_IR_PASS(conv2d_bias_fuse_pass, Conv2dBiasFusePass); -// REGISTER_IR_PASS(conv2d_transpose_bias_fuse_pass, -// Conv2dTransposeBiasFusePass); +REGISTER_IR_PASS(conv2d_transpose_bias_fuse_pass, Conv2dTransposeBiasFusePass); REGISTER_IR_PASS(conv3d_bias_fuse_pass, Conv3dBiasFusePass); diff --git a/paddle/fluid/pir/transforms/onednn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/pir/transforms/onednn/conv_elementwise_add_mkldnn_fuse_pass.cc new file mode 100644 index 0000000000000..4ecd752b85997 --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -0,0 +1,425 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
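With Conv2dTransposeBiasFusePass re-enabled and registered above, pd_op.conv2d_transpose followed by pd_op.add (plus an optional second add on a persistable operand) is lowered to onednn_op.conv2d_transpose_bias. A usage sketch follows; it assumes the usual PIR pass-manager interface (pir::PassManager with AddPass/Run) and an onednn conv_bias_fuse_pass.h header alongside the .cc changed here, so treat the include paths as illustrative:

#include "paddle/fluid/pir/transforms/onednn/conv_bias_fuse_pass.h"  // assumed path
#include "paddle/pir/include/pass/pass_manager.h"                    // assumed path

// Illustrative only: apply the newly enabled fusion to a pir::Program.
void ApplyConv2dTransposeBiasFuse(pir::IrContext *ctx, pir::Program *program) {
  pir::PassManager pm(ctx);
  pm.AddPass(pir::CreateConv2dTransposeBiasFusePass());
  pm.Run(program);  // conv2d_transpose + add -> conv2d_transpose_bias
}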
+ +#include "paddle/fluid/pir/transforms/onednn/conv_elementwise_add_mkldnn_fuse_pass.h" + +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/utils/general_functions.h" + +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace { + +class ConvElementwiseAddPattern : public paddle::drr::DrrPatternBase { + private: + std::string conv_name_; + std::string fused_conv_name_; + + public: + ConvElementwiseAddPattern(const std::string &conv_name, + const std::string &fused_conv_name) + : conv_name_(conv_name), fused_conv_name_(fused_conv_name) {} + + std::string name() const override { return "ConvElementwiseAddPattern"; } + + uint32_t benefit() const override { return 2; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &conv = + pat.Op(conv_name_, + {{"strides", pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}}); + + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + conv({&pat.Tensor("input"), &pat.Tensor("filter")}, + {&pat.Tensor("conv2d_out")}); + + pat.Tensor("add_out") = + add(pat.Tensor("conv2d_out"), pat.Tensor("residual_param")); + pat.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + auto padding_algorithm = + match_ctx.Attr("padding_algorithm"); + if (padding_algorithm != "EXPLICIT" && padding_algorithm != "SAME" && + padding_algorithm != "VALID") { + return false; + } + auto groups = match_ctx.Attr("groups"); + if (groups < 1) { + return false; + } + auto data_format = match_ctx.Attr("data_format"); + if (data_format != "NCHW" && data_format != "AnyLayout") { + return false; + } + return true; + }); + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &fused_conv2d_add = + res.Op(fused_conv_name_, + {{ + {"strides", pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"fuse_activation", res.StrAttr("")}, + {"fuse_residual_connection", res.BoolAttr(true)}, + {"force_fp32_output", res.BoolAttr(false)}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}, + {"scale_in", res.Float32Attr(1.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(1.0f)}, + {"scale_weights", res.VectorFloatAttr({1.0f})}, + }}); + + fused_conv2d_add({&res.Tensor("input"), + &res.Tensor("filter"), + &res.InputNoneTensor(), + &res.Tensor("residual_param")}, + {&res.Tensor("add_out")}); + } +}; + +class ConvElementwiseAddAsYPattern : public paddle::drr::DrrPatternBase { + private: + std::string conv_name_; + std::string fused_conv_name_; + + public: + ConvElementwiseAddAsYPattern(const std::string &conv_name, + const std::string &fused_conv_name) + : conv_name_(conv_name), fused_conv_name_(fused_conv_name) {} + + std::string name() const override { return "ConvElementwiseAddAsYPattern"; } + + uint32_t benefit() const override { return 2; } + + void operator()(paddle::drr::DrrPatternContext 
*ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &conv = + pat.Op(conv_name_, + {{"strides", pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}}); + + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + conv({&pat.Tensor("input"), &pat.Tensor("filter")}, + {&pat.Tensor("conv2d_out")}); + pat.Tensor("add_out") = + add(pat.Tensor("residual_param"), pat.Tensor("conv2d_out")); + + pat.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + auto padding_algorithm = + match_ctx.Attr("padding_algorithm"); + if (padding_algorithm != "EXPLICIT" && padding_algorithm != "SAME" && + padding_algorithm != "VALID") { + return false; + } + auto groups = match_ctx.Attr("groups"); + if (groups < 1) { + return false; + } + auto data_format = match_ctx.Attr("data_format"); + if (data_format != "NCHW" && data_format != "AnyLayout") { + return false; + } + return true; + }); + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &fused_conv2d_add = + res.Op(fused_conv_name_, + {{ + {"strides", pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"fuse_activation", res.StrAttr("")}, + {"fuse_residual_connection", res.BoolAttr(true)}, + {"force_fp32_output", res.BoolAttr(false)}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}, + {"scale_in", res.Float32Attr(1.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(1.0f)}, + {"scale_weights", res.VectorFloatAttr({1.0f})}, + }}); + + fused_conv2d_add({&res.Tensor("input"), + &res.Tensor("filter"), + &res.InputNoneTensor(), + &res.Tensor("residual_param")}, + {&res.Tensor("add_out")}); + } +}; + +class FusedConvBiasElementwiseAddPattern : public paddle::drr::DrrPatternBase { + private: + std::string conv_name_; + std::string fused_conv_name_; + + public: + FusedConvBiasElementwiseAddPattern(const std::string &conv_name, + const std::string &fused_conv_name) + : conv_name_(conv_name), fused_conv_name_(fused_conv_name) {} + + std::string name() const override { + return "FusedConvBiasElementwiseAddPattern"; + } + + uint32_t benefit() const override { return 2; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &conv = pat.Op( + conv_name_, + {{ + {"strides", pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fuse_residual_connection", pat.Attr("fuse_residual_connection")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}, + {"fuse_alpha", pat.Attr("fuse_alpha")}, + {"fuse_beta", pat.Attr("fuse_beta")}, + {"scale_in", pat.Attr("scale_in")}, + {"scale_out", pat.Attr("scale_out")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_weights", pat.Attr("scale_weights")}, + }}); + + const auto &add = 
pat.Op(paddle::dialect::AddOp::name()); + conv({&pat.Tensor("input"), + &pat.Tensor("filter"), + &pat.Tensor("bias"), + &pat.InputNoneTensor()}, + {&pat.Tensor("conv2d_out")}); + + pat.Tensor("add_out") = + add(pat.Tensor("conv2d_out"), pat.Tensor("residual_param")); + pat.RequireNativeCall( + [](const paddle::drr::MatchContext &match_ctx) -> bool { + auto padding_algorithm = + match_ctx.Attr("padding_algorithm"); + if (padding_algorithm != "EXPLICIT" && padding_algorithm != "SAME" && + padding_algorithm != "VALID") { + return false; + } + auto groups = match_ctx.Attr("groups"); + if (groups < 1) { + return false; + } + auto data_format = match_ctx.Attr("data_format"); + if (data_format != "NCHW" && data_format != "AnyLayout") { + return false; + } + return true; + }); + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &fused_conv2d_add = + res.Op(fused_conv_name_, + {{ + {"strides", pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fuse_residual_connection", res.BoolAttr(true)}, + {"force_fp32_output", pat.Attr("force_fp32_output")}, + {"fuse_alpha", pat.Attr("fuse_alpha")}, + {"fuse_beta", pat.Attr("fuse_beta")}, + {"scale_in", pat.Attr("scale_in")}, + {"scale_out", pat.Attr("scale_out")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_weights", pat.Attr("scale_weights")}, + }}); + + fused_conv2d_add({&res.Tensor("input"), + &res.Tensor("filter"), + &res.Tensor("bias"), + &res.Tensor("residual_param")}, + {&res.Tensor("add_out")}); + } +}; + +class FusedConvBiasElementwiseAddAsYPattern + : public paddle::drr::DrrPatternBase { + private: + std::string conv_name_; + std::string fused_conv_name_; + + public: + FusedConvBiasElementwiseAddAsYPattern(const std::string &conv_name, + const std::string &fused_conv_name) + : conv_name_(conv_name), fused_conv_name_(fused_conv_name) {} + + std::string name() const override { + return "FusedConvBiasElementwiseAddAsYPattern"; + } + + uint32_t benefit() const override { return 2; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &conv = pat.Op( + conv_name_, + {{ + {"strides", pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fuse_residual_connection", pat.Attr("fuse_residual_connection")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}, + {"fuse_alpha", pat.Attr("fuse_alpha")}, + {"fuse_beta", pat.Attr("fuse_beta")}, + {"scale_in", pat.Attr("scale_in")}, + {"scale_out", pat.Attr("scale_out")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_weights", pat.Attr("scale_weights")}, + }}); + + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + conv({&pat.Tensor("input"), + &pat.Tensor("filter"), + &pat.Tensor("bias"), + &pat.InputNoneTensor()}, + {&pat.Tensor("conv2d_out")}); + + pat.Tensor("add_out") = + add(pat.Tensor("residual_param"), pat.Tensor("conv2d_out")); + pat.RequireNativeCall( + [](const 
paddle::drr::MatchContext &match_ctx) -> bool { + auto padding_algorithm = + match_ctx.Attr("padding_algorithm"); + if (padding_algorithm != "EXPLICIT" && padding_algorithm != "SAME" && + padding_algorithm != "VALID") { + return false; + } + auto groups = match_ctx.Attr("groups"); + if (groups < 1) { + return false; + } + auto data_format = match_ctx.Attr("data_format"); + if (data_format != "NCHW" && data_format != "AnyLayout") { + return false; + } + return true; + }); + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &fused_conv2d_add = + res.Op(fused_conv_name_, + {{ + {"strides", pat.Attr("strides")}, + {"paddings", pat.Attr("paddings")}, + {"padding_algorithm", pat.Attr("padding_algorithm")}, + {"dilations", pat.Attr("dilations")}, + {"groups", pat.Attr("groups")}, + {"data_format", pat.Attr("data_format")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fuse_residual_connection", res.BoolAttr(true)}, + {"force_fp32_output", pat.Attr("force_fp32_output")}, + {"fuse_alpha", pat.Attr("fuse_alpha")}, + {"fuse_beta", pat.Attr("fuse_beta")}, + {"scale_in", pat.Attr("scale_in")}, + {"scale_out", pat.Attr("scale_out")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_weights", pat.Attr("scale_weights")}, + }}); + + fused_conv2d_add({&res.Tensor("input"), + &res.Tensor("filter"), + &res.Tensor("bias"), + &res.Tensor("residual_param")}, + {&res.Tensor("add_out")}); + } +}; + +class ConvElementwiseAddFusePass : public pir::PatternRewritePass { + public: + ConvElementwiseAddFusePass() + : pir::PatternRewritePass("conv_elementwise_add_mkldnn_fuse_pass", 3) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(paddle::drr::Create( + context, + paddle::dialect::Conv2dOp::name(), + paddle::onednn::dialect::FusedConv2dOp::name())); + ps.Add(paddle::drr::Create( + context, + paddle::dialect::Conv2dOp::name(), + paddle::onednn::dialect::FusedConv2dOp::name())); + // conv + bias -> fused_conv2d, fused_conv2d + residual -> fused_conv2d + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedConv2dOp::name(), + paddle::onednn::dialect::FusedConv2dOp::name())); + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedConv2dOp::name(), + paddle::onednn::dialect::FusedConv2dOp::name())); + + return ps; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateConvElementwiseAddFusePass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(conv_elementwise_add_mkldnn_fuse_pass, + ConvElementwiseAddFusePass); diff --git a/paddle/fluid/pir/transforms/onednn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/pir/transforms/onednn/conv_elementwise_add_mkldnn_fuse_pass.h new file mode 100644 index 0000000000000..2f199a0eb8a0a --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/conv_elementwise_add_mkldnn_fuse_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateConvElementwiseAddFusePass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.cc b/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.cc new file mode 100644 index 0000000000000..1db28281578d4 --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.cc @@ -0,0 +1,704 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.h" + +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" + +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace { +std::set act_ops = {{paddle::dialect::AbsOp::name()}, + {paddle::dialect::GeluOp::name()}, + {paddle::dialect::HardsigmoidOp::name()}, + {paddle::dialect::HardswishOp::name()}, + {paddle::dialect::LeakyReluOp::name()}, + {paddle::dialect::MishOp::name()}, + {paddle::dialect::ReluOp::name()}, + {paddle::dialect::Relu6Op::name()}, + {paddle::dialect::SigmoidOp::name()}, + {paddle::dialect::SqrtOp::name()}, + {paddle::dialect::SwishOp::name()}, + {paddle::dialect::TanhOp::name()}}; + +std::unordered_map activation_type = { + {paddle::dialect::AbsOp::name(), "abs"}, + {paddle::dialect::GeluOp::name(), "gelu"}, + {paddle::dialect::HardsigmoidOp::name(), "hard_sigmoid"}, + {paddle::dialect::HardswishOp::name(), "hard_swish"}, + {paddle::dialect::LeakyReluOp::name(), "leaky_relu"}, + {paddle::dialect::MishOp::name(), "mish"}, + {paddle::dialect::ReluOp::name(), "relu"}, + {paddle::dialect::Relu6Op::name(), "relu6"}, + {paddle::dialect::SigmoidOp::name(), "sigmoid"}, + {paddle::dialect::SqrtOp::name(), "sqrt"}, + {paddle::dialect::SwishOp::name(), "swish"}, + {paddle::dialect::TanhOp::name(), "tanh"}}; + +class MatmulActivationFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + std::string act_type_; + + public: + MatmulActivationFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit, + const std::string &act_type) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit), + act_type_(act_type) {} + + std::string name() const override { return "MatmulActivationFusePattern"; } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = pat.Op(matmul_name_, + {{"transpose_x", 
pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + std::unordered_map act_attrs; + if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + act_attrs.emplace("slope", pat.Attr("fuse_alpha")); + act_attrs.emplace("offset", pat.Attr("fuse_beta")); + } else if (act_type_ == paddle::dialect::LeakyReluOp::name()) { + act_attrs.emplace("negative_slope", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::GeluOp::name()) { + act_attrs.emplace("approximate", pat.Attr("approximate")); + } + + const auto &act = pat.Op(act_type_, act_attrs); + matmul({&pat.Tensor("X"), &pat.Tensor("Y")}, {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + + if (act_type_ == paddle::dialect::GeluOp::name()) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (result_gelu) return false; + return true; + }); + } + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", res.Float32Attr(1.0f)}, + {"fused_output_scale", res.Float32Attr(1.0f)}, + {"fused_reshape_x", res.VectorInt32Attr({})}, + {"fused_transpose_x", res.VectorInt32Attr({})}, + {"fused_reshape_y", res.VectorInt32Attr({})}, + {"fused_transpose_y", res.VectorInt32Attr({})}, + {"fused_reshape_out", res.VectorInt32Attr({})}, + {"fused_transpose_out", res.VectorInt32Attr({})}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"scale_x", res.Float32Attr(1.0f)}, + {"scale_y", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(0.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"force_fp32_output", res.BoolAttr(false)}}; + + if (act_type_ == paddle::dialect::HardswishOp::name()) { + fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f / 6.0f)); + fused_attrs.emplace("fuse_beta", res.Float32Attr(1.0f / 2.0f)); + } else if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + fused_attrs.emplace("fuse_beta", pat.Attr("fuse_beta")); + } else if (act_type_ == paddle::dialect::LeakyReluOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::SwishOp::name()) { + fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f)); + } else if (act_type_ == paddle::dialect::Relu6Op::name()) { + fused_attrs.emplace("fuse_beta", res.Float32Attr(6.0f)); + } + + fused_attrs.insert(std::make_pair("fuse_activation", + res.StrAttr(activation_type[act_type_]))); + fused_attrs.insert(std::make_pair("fuse_alpha", res.Float32Attr(0.0f))); + fused_attrs.insert(std::make_pair("fuse_beta", res.Float32Attr(0.0f))); + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.InputNoneTensor()}, + {&res.Tensor("act_out")}); + } +}; + +class MatmulGeluTanhFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + + public: + MatmulGeluTanhFusePattern(const std::string &matmul_name, + const std::string 
&fused_matmul_name, + uint32_t benefit) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit) {} + + std::string name() const override { return "MatmulActivationFusePattern"; } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = pat.Op(matmul_name_, + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + const auto &act = pat.Op(paddle::dialect::GeluOp::name(), + {{"approximate", pat.Attr("approximate")}}); + matmul({&pat.Tensor("X"), &pat.Tensor("Y")}, {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (!result_gelu) return false; + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", res.Float32Attr(1.0f)}, + {"fuse_activation", res.StrAttr("gelu_tanh")}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}, + {"fused_output_scale", res.Float32Attr(1.0f)}, + {"fused_reshape_x", res.VectorInt32Attr({})}, + {"fused_transpose_x", res.VectorInt32Attr({})}, + {"fused_reshape_y", res.VectorInt32Attr({})}, + {"fused_transpose_y", res.VectorInt32Attr({})}, + {"fused_reshape_out", res.VectorInt32Attr({})}, + {"fused_transpose_out", res.VectorInt32Attr({})}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"scale_x", res.Float32Attr(1.0f)}, + {"scale_y", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(0.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"force_fp32_output", res.BoolAttr(false)}}; + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.InputNoneTensor()}, + {&res.Tensor("act_out")}); + } +}; + +class MatmulClipFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + + public: + MatmulClipFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit) {} + + std::string name() const override { return "MatmulActivationFusePattern"; } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = pat.Op(matmul_name_, + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + const auto &full1 = + pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape1")}, {"value", pat.Attr("value1")}}); + const auto &full2 = + pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape2")}, {"value", pat.Attr("value2")}}); + pat.Tensor("min") = full1(); + pat.Tensor("max") = full2(); + + const 
auto &act = pat.Op(paddle::dialect::ClipOp::name()); + matmul({&pat.Tensor("X"), &pat.Tensor("Y")}, {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = + act(pat.Tensor("Out"), pat.Tensor("min"), pat.Tensor("max")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", res.Float32Attr(1.0f)}, + {"fuse_activation", res.StrAttr("clip")}, + {"fuse_alpha", pat.Attr("value1")}, + {"fuse_beta", pat.Attr("value2")}, + {"fused_output_scale", res.Float32Attr(1.0f)}, + {"fused_reshape_x", res.VectorInt32Attr({})}, + {"fused_transpose_x", res.VectorInt32Attr({})}, + {"fused_reshape_y", res.VectorInt32Attr({})}, + {"fused_transpose_y", res.VectorInt32Attr({})}, + {"fused_reshape_out", res.VectorInt32Attr({})}, + {"fused_transpose_out", res.VectorInt32Attr({})}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"scale_x", res.Float32Attr(1.0f)}, + {"scale_y", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(0.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"force_fp32_output", res.BoolAttr(false)}}; + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.InputNoneTensor()}, + {&res.Tensor("act_out")}); + } +}; + +class FusedMatmulActivationFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + std::string act_type_; + + public: + FusedMatmulActivationFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit, + const std::string &act_type) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit), + act_type_(act_type) {} + + std::string name() const override { + return "FusedMatmulActivationFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = + pat.Op(matmul_name_, + {{"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}); + + std::unordered_map act_attrs; + if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + act_attrs.emplace("slope", pat.Attr("fuse_alpha")); + act_attrs.emplace("offset", pat.Attr("fuse_beta")); + } else if 
(act_type_ == paddle::dialect::LeakyReluOp::name()) { + act_attrs.emplace("negative_slope", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::GeluOp::name()) { + act_attrs.emplace("approximate", pat.Attr("approximate")); + } + + const auto &act = pat.Op(act_type_, act_attrs); + matmul({&pat.Tensor("X"), &pat.Tensor("Y"), &pat.Tensor("residual")}, + {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + auto act_type = match_ctx.Attr("fuse_activation"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0 || + act_type != "") { + return false; + } + return true; + }); + if (act_type_ == paddle::dialect::GeluOp::name()) { + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (result_gelu) return false; + return true; + }); + } + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}; + + if (act_type_ == paddle::dialect::HardswishOp::name()) { + fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f / 6.0f)); + fused_attrs.emplace("fuse_beta", res.Float32Attr(1.0f / 2.0f)); + } else if (act_type_ == paddle::dialect::HardsigmoidOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + fused_attrs.emplace("fuse_beta", pat.Attr("fuse_beta")); + } else if (act_type_ == paddle::dialect::LeakyReluOp::name()) { + fused_attrs.emplace("fuse_alpha", pat.Attr("fuse_alpha")); + } else if (act_type_ == paddle::dialect::SwishOp::name()) { + fused_attrs.emplace("fuse_alpha", res.Float32Attr(1.0f)); + } else if (act_type_ == paddle::dialect::Relu6Op::name()) { + fused_attrs.emplace("fuse_beta", res.Float32Attr(6.0f)); + } + + fused_attrs.insert(std::make_pair("fuse_activation", + res.StrAttr(activation_type[act_type_]))); + fused_attrs.insert(std::make_pair("fuse_alpha", res.Float32Attr(0.0f))); + fused_attrs.insert(std::make_pair("fuse_beta", res.Float32Attr(0.0f))); + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("residual")}, + {&res.Tensor("act_out")}); + } +}; + +class FusedMatmulGeluTanhFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + + public: + FusedMatmulGeluTanhFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + 
benefit_(benefit) {} + + std::string name() const override { + return "FusedMatmulActivationFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = + pat.Op(matmul_name_, + {{"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}); + + const auto &act = pat.Op(paddle::dialect::GeluOp::name(), + {{"approximate", pat.Attr("approximate")}}); + matmul({&pat.Tensor("X"), &pat.Tensor("Y"), &pat.Tensor("residual")}, + {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = act(pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + auto act_type = match_ctx.Attr("fuse_activation"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0 || + act_type != "") { + return false; + } + return true; + }); + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto result_gelu = match_ctx.Attr("approximate"); + if (!result_gelu) return false; + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", res.StrAttr("gelu_tanh")}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}; + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("residual")}, + {&res.Tensor("act_out")}); + } +}; + +class FusedMatmulClipFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + std::string act_type_; + + public: + FusedMatmulClipFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + 
uint32_t benefit) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit) {} + + std::string name() const override { + return "FusedMatmulActivationFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = + pat.Op(matmul_name_, + {{"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}); + + const auto &full1 = + pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape1")}, {"value", pat.Attr("value1")}}); + const auto &full2 = + pat.Op(paddle::dialect::FullOp::name(), + {{"shape", pat.Attr("shape2")}, {"value", pat.Attr("value2")}}); + pat.Tensor("min") = full1(); + pat.Tensor("max") = full2(); + + const auto &act = pat.Op(paddle::dialect::ClipOp::name()); + matmul({&pat.Tensor("X"), &pat.Tensor("Y"), &pat.Tensor("residual")}, + {&pat.Tensor("Out")}); + + pat.Tensor("act_out") = + act(pat.Tensor("Out"), pat.Tensor("min"), pat.Tensor("max")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + auto act_type = match_ctx.Attr("fuse_activation"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0 || + act_type != "") { + return false; + } + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", res.StrAttr("clip")}, + {"fuse_alpha", pat.Attr("value1")}, + {"fuse_beta", pat.Attr("value2")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}; + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("residual")}, + {&res.Tensor("act_out")}); + } +}; + +class MatmulActivationFusePass : public pir::PatternRewritePass { + public: + 
MatmulActivationFusePass() + : pir::PatternRewritePass("matmul_activation_fuse_pass", 3) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + // std::vector bool_set = {false, true}; + int benefit_idx = 1; + for (auto act_op : act_ops) { + ps.Add(paddle::drr::Create( + context, + paddle::dialect::MatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx, + act_op)); + benefit_idx++; + } + ps.Add(paddle::drr::Create( + context, + paddle::dialect::MatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx++)); + ps.Add(paddle::drr::Create( + context, + paddle::dialect::MatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx++)); + for (auto act_op : act_ops) { + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedMatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx, + act_op)); + benefit_idx++; + } + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedMatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx++)); + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedMatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx++)); + return ps; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateMatmulActivationFusePass() { + // pd_op.matmul + pd_op.relu -> onednn_op.fused_matmul + // pd_op.matmul + pd_op.add + pd_op.relu(act) -> onednn_op.fused_matmul + + // pd_op.relu(act) -> onednn_op.fused_matmul + return std::make_unique(); +} +} // namespace pir + +REGISTER_IR_PASS(matmul_activation_fuse_pass, MatmulActivationFusePass); diff --git a/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.h b/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.h new file mode 100644 index 0000000000000..87de94566ce91 --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/matmul_activation_fuse_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateMatmulActivationFusePass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/onednn/matmul_elementwise_add_fuse_pass.cc b/paddle/fluid/pir/transforms/onednn/matmul_elementwise_add_fuse_pass.cc new file mode 100644 index 0000000000000..68354c52e2fe5 --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/matmul_elementwise_add_fuse_pass.cc @@ -0,0 +1,253 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
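One detail of the attribute plumbing above is worth spelling out: the activation-specific branches add fuse_alpha / fuse_beta with emplace, and the later insert calls only supply defaults, because neither std::unordered_map::insert nor emplace overwrites an existing key. A minimal standalone illustration of that behavior (plain C++, no Paddle types):

// Hedged illustration: defaults inserted after activation-specific values do
// not clobber them, since insert/emplace keep the existing entry for a key.
#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, float> attrs;
  attrs.emplace("fuse_alpha", 0.5f);   // activation-specific value set first
  attrs.insert({"fuse_alpha", 0.0f});  // default: ignored, key already present
  attrs.insert({"fuse_beta", 0.0f});   // default: inserted, key was missing
  assert(attrs["fuse_alpha"] == 0.5f);
  assert(attrs["fuse_beta"] == 0.0f);
  return 0;
}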
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/onednn/matmul_elementwise_add_fuse_pass.h" + +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" + +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace { +class MatmulElementwiseAddFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + bool as_x_; // Decide input direction of add + + public: + MatmulElementwiseAddFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit, + bool as_x) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit), + as_x_(as_x) {} + + std::string name() const override { + return "MatmulElementwiseAddFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = pat.Op(matmul_name_, + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + matmul({&pat.Tensor("X"), &pat.Tensor("Y")}, {&pat.Tensor("Out")}); + + pat.Tensor("add_out") = + as_x_ ? 
add(pat.Tensor("Out"), pat.Tensor("residual")) + : add(pat.Tensor("residual"), pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &fused_matmul = + res.Op(fused_matmul_name_, + {{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", res.Float32Attr(1.0f)}, + {"fuse_activation", res.StrAttr("")}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}, + {"fused_output_scale", res.Float32Attr(1.0f)}, + {"fused_reshape_x", res.VectorInt32Attr({})}, + {"fused_transpose_x", res.VectorInt32Attr({})}, + {"fused_reshape_y", res.VectorInt32Attr({})}, + {"fused_transpose_y", res.VectorInt32Attr({})}, + {"fused_reshape_out", res.VectorInt32Attr({})}, + {"fused_transpose_out", res.VectorInt32Attr({})}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"scale_x", res.Float32Attr(1.0f)}, + {"scale_y", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(0.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"force_fp32_output", res.BoolAttr(false)}, + }}); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("residual")}, + {&res.Tensor("add_out")}); + } +}; + +class FusedMatmulElementwiseAddFusePattern + : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + bool as_x_; // Decide input direction of add + + public: + FusedMatmulElementwiseAddFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit, + bool as_x) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit), + as_x_(as_x) {} + + std::string name() const override { + return "FusedMatmulElementwiseAddFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &matmul = + pat.Op(matmul_name_, + {{"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fuse_alpha", pat.Attr("fuse_alpha")}, + {"fuse_beta", pat.Attr("fuse_beta")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}); + + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + matmul({&pat.Tensor("X"), &pat.Tensor("Y"), &pat.Tensor("none")}, + {&pat.Tensor("Out")}); + + pat.Tensor("add_out") = + as_x_ ? 
add(pat.Tensor("Out"), pat.Tensor("residual")) + : add(pat.Tensor("residual"), pat.Tensor("Out")); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto none_tensor = match_ctx.Tensor("none"); + if (none_tensor.impl() != nullptr) { + return false; + } + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &fused_matmul = + res.Op(fused_matmul_name_, + {{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fuse_alpha", pat.Attr("fuse_alpha")}, + {"fuse_beta", pat.Attr("fuse_beta")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}, + }}); + + fused_matmul({&res.Tensor("X"), &res.Tensor("Y"), &res.Tensor("residual")}, + {&res.Tensor("add_out")}); + } +}; + +class MatmulElementwiseAddFusePass : public pir::PatternRewritePass { + public: + MatmulElementwiseAddFusePass() + : pir::PatternRewritePass("matmul_elementwise_add_fuse_pass", 3) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + std::vector bool_set = {false, true}; + int benefit_idx = 1; + for (auto as_x : bool_set) { + ps.Add(paddle::drr::Create( + context, + paddle::dialect::MatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx, + as_x)); + benefit_idx++; + } + + for (auto as_x : bool_set) { + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedMatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx, + as_x)); + benefit_idx++; + } + return ps; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateMatmulElementwiseAddFusePass() { + // pd_op.matmul + pd_op.add -> onednn_op.fused_matmul + // onednn_op.fused_matmul + pd_op.add -> onednn_op.fused_matmul + return std::make_unique(); +} +} // namespace pir + +REGISTER_IR_PASS(matmul_elementwise_add_fuse_pass, + MatmulElementwiseAddFusePass); diff --git a/paddle/fluid/pir/transforms/onednn/matmul_elementwise_add_fuse_pass.h b/paddle/fluid/pir/transforms/onednn/matmul_elementwise_add_fuse_pass.h new file mode 100644 index 0000000000000..039b97cba2e1b --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/matmul_elementwise_add_fuse_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
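A condensed restatement of the registration done in MatmulElementwiseAddFusePass above, with the DRR template parameter written out explicitly (a non-compilable fragment; the angle-bracket pattern-class argument to paddle::drr::Create is an assumption about the usual DRR parameterization, not text taken from this patch):

// Hedged sketch: each pattern is instantiated twice so both operand orders of
// the elementwise add (matmul output on the left or on the right) are matched,
// with each instance getting its own benefit value.
pir::RewritePatternSet ps(context);
int benefit_idx = 1;
for (bool as_x : {false, true}) {
  ps.Add(paddle::drr::Create<MatmulElementwiseAddFusePattern>(
      context,
      paddle::dialect::MatmulOp::name(),
      paddle::onednn::dialect::FusedMatmulOp::name(),
      benefit_idx++,
      as_x));
}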
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateMatmulElementwiseAddFusePass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/onednn/reshape_transpose_matmul_fuse_pass.cc b/paddle/fluid/pir/transforms/onednn/reshape_transpose_matmul_fuse_pass.cc new file mode 100644 index 0000000000000..d317fc006300c --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/reshape_transpose_matmul_fuse_pass.cc @@ -0,0 +1,355 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/onednn/reshape_transpose_matmul_fuse_pass.h" + +#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" + +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace { +class ReshapeTransposeMatmulFusePattern : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + bool as_x_; // decide if the output of transpose is for input_x of matmul + + public: + ReshapeTransposeMatmulFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit, + bool as_x) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit), + as_x_(as_x) {} + + std::string name() const override { + return "ReshapeTransposeMatmulFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto &full_int_array = pat.Op(paddle::dialect::FullIntArrayOp::name(), + {{"value", pat.Attr("int_array")}}); + pat.Tensor("shape") = full_int_array(); + + const auto &reshape = pat.Op(paddle::dialect::ReshapeOp::name()); + reshape({&pat.Tensor("reshape_in"), &pat.Tensor("shape")}, + {&pat.Tensor("reshape_out"), &pat.Tensor("Xshape")}); + + const auto &transpose = pat.Op(paddle::dialect::TransposeOp::name(), + {{"perm", pat.Attr("perm")}}); + pat.Tensor("transpose_out") = transpose(pat.Tensor("reshape_out")); + + const auto &matmul = pat.Op(matmul_name_, + {{"transpose_x", pat.Attr("transpose_x")}, + {"transpose_y", pat.Attr("transpose_y")}}); + if (as_x_) { + matmul({&pat.Tensor("transpose_out"), &pat.Tensor("other")}, + {&pat.Tensor("Out")}); + } else { + 
matmul({&pat.Tensor("other"), &pat.Tensor("transpose_out")}, + {&pat.Tensor("Out")}); + } + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto shape = match_ctx.Attr>("int_array"); + auto perm = match_ctx.Attr>("perm"); + if (shape.size() < 2 || shape.size() > 4) return false; + if (shape.size() != perm.size()) return false; + if (std::count(shape.begin(), shape.end(), -1) > 1) return false; + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", res.Float32Attr(1.0f)}, + {"fuse_activation", res.StrAttr("")}, + {"fuse_alpha", res.Float32Attr(0.0f)}, + {"fuse_beta", res.Float32Attr(0.0f)}, + {"fused_output_scale", res.Float32Attr(1.0f)}, + {"fused_reshape_out", res.VectorInt32Attr({})}, + {"fused_transpose_out", res.VectorInt32Attr({})}, + {"mkldnn_data_type", res.StrAttr("float32")}, + {"scale_x", res.Float32Attr(1.0f)}, + {"scale_y", res.Float32Attr(1.0f)}, + {"scale_in_eltwise", res.Float32Attr(0.0f)}, + {"scale_out", res.Float32Attr(1.0f)}, + {"force_fp32_output", res.BoolAttr(false)}}; + + const auto &fused_reshape_attr = res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { + std::vector int_array_value; + auto shape = match_ctx.Attr>("int_array"); + for (auto i : shape) { + int_array_value.emplace_back(static_cast(i)); + } + return int_array_value; + }); + + if (as_x_) { + fused_attrs.emplace("fused_reshape_x", fused_reshape_attr); + fused_attrs.emplace("fused_transpose_x", pat.Attr("perm")); + fused_attrs.emplace("fused_reshape_y", res.VectorInt32Attr({})); + fused_attrs.emplace("fused_transpose_y", res.VectorInt32Attr({})); + } else { + fused_attrs.emplace("fused_reshape_x", res.VectorInt32Attr({})); + fused_attrs.emplace("fused_transpose_x", res.VectorInt32Attr({})); + fused_attrs.emplace("fused_reshape_y", fused_reshape_attr); + fused_attrs.emplace("fused_transpose_y", pat.Attr("perm")); + } + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + if (as_x_) { + fused_matmul({&res.Tensor("reshape_in"), + &res.Tensor("other"), + &res.InputNoneTensor()}, + {&res.Tensor("Out")}); + } else { + fused_matmul({&res.Tensor("other"), + &res.Tensor("reshape_in"), + &res.InputNoneTensor()}, + {&res.Tensor("Out")}); + } + } +}; + +class ReshapeTransposeFusedMatmulFusePattern + : public paddle::drr::DrrPatternBase { + private: + std::string matmul_name_; + std::string fused_matmul_name_; + uint32_t benefit_; + bool as_x_; // decide if the output of transpose is for input_x of matmul + + public: + ReshapeTransposeFusedMatmulFusePattern(const std::string &matmul_name, + const std::string &fused_matmul_name, + uint32_t benefit, + bool as_x) + : matmul_name_(matmul_name), + fused_matmul_name_(fused_matmul_name), + benefit_(benefit), + as_x_(as_x) {} + + std::string name() const override { + return "ReshapeTransposFusedMatmulFusePattern"; + } + + uint32_t benefit() const override { return benefit_; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + + const auto 
&full_int_array = pat.Op(paddle::dialect::FullIntArrayOp::name(), + {{"value", pat.Attr("int_array")}}); + pat.Tensor("shape") = full_int_array(); + + const auto &reshape = pat.Op(paddle::dialect::ReshapeOp::name()); + reshape({&pat.Tensor("reshape_in"), &pat.Tensor("shape")}, + {&pat.Tensor("reshape_out"), &pat.Tensor("Xshape")}); + + const auto &transpose = pat.Op(paddle::dialect::TransposeOp::name(), + {{"perm", pat.Attr("perm")}}); + pat.Tensor("transpose_out") = transpose(pat.Tensor("reshape_out")); + + const auto &matmul = + pat.Op(matmul_name_, + {{"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fuse_alpha", pat.Attr("fuse_alpha")}, + {"fuse_beta", pat.Attr("fuse_beta")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_x", pat.Attr("fused_reshape_x")}, + {"fused_transpose_x", pat.Attr("fused_transpose_x")}, + {"fused_reshape_y", pat.Attr("fused_reshape_y")}, + {"fused_transpose_y", pat.Attr("fused_transpose_y")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}); + if (as_x_) { + matmul({&pat.Tensor("transpose_out"), + &pat.Tensor("other"), + &pat.Tensor("residual")}, + {&pat.Tensor("Out")}); + } else { + matmul({&pat.Tensor("other"), + &pat.Tensor("transpose_out"), + &pat.Tensor("residual")}, + {&pat.Tensor("Out")}); + } + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::set bool_sets = {true, false}; + auto result_x = match_ctx.Attr("transpose_x"); + auto result_y = match_ctx.Attr("transpose_y"); + if (bool_sets.count(result_x) == 0 || bool_sets.count(result_y) == 0) { + return false; + } + return true; + }); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + auto shape = match_ctx.Attr>("int_array"); + auto perm = match_ctx.Attr>("perm"); + if (shape.size() < 2 || shape.size() > 4) return false; + if (shape.size() != perm.size()) return false; + if (std::count(shape.begin(), shape.end(), -1) > 1) return false; + + return true; + }); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + if (as_x_) { + if (!(match_ctx.Attr>("fused_reshape_x").empty())) + return false; + } else { + if (!(match_ctx.Attr>("fused_reshape_y").empty())) + return false; + } + return true; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + std::unordered_map fused_attrs{ + {"trans_x", pat.Attr("transpose_x")}, + {"trans_y", pat.Attr("transpose_y")}, + {"matmul_alpha", pat.Attr("matmul_alpha")}, + {"fuse_activation", pat.Attr("fuse_activation")}, + {"fuse_alpha", pat.Attr("fuse_alpha")}, + {"fuse_beta", pat.Attr("fuse_beta")}, + {"fused_output_scale", pat.Attr("fused_output_scale")}, + {"fused_reshape_out", pat.Attr("fused_reshape_out")}, + {"fused_transpose_out", pat.Attr("fused_transpose_out")}, + {"mkldnn_data_type", pat.Attr("mkldnn_data_type")}, + {"scale_x", pat.Attr("scale_x")}, + {"scale_y", pat.Attr("scale_y")}, + {"scale_in_eltwise", pat.Attr("scale_in_eltwise")}, + {"scale_out", pat.Attr("scale_out")}, + {"force_fp32_output", pat.Attr("force_fp32_output")}}; + + const auto &fused_reshape_attr = 
res.ComputeAttr( + [](const paddle::drr::MatchContext &match_ctx) -> std::vector { + std::vector int_array_value; + auto shape = match_ctx.Attr>("int_array"); + for (auto i : shape) { + int_array_value.emplace_back(static_cast(i)); + } + return int_array_value; + }); + + if (as_x_) { + fused_attrs.emplace("fused_reshape_x", fused_reshape_attr); + fused_attrs.emplace("fused_transpose_x", pat.Attr("perm")); + fused_attrs.emplace("fused_reshape_y", pat.Attr("fused_reshape_y")); + fused_attrs.emplace("fused_transpose_y", pat.Attr("fused_transpose_y")); + } else { + fused_attrs.emplace("fused_reshape_x", pat.Attr("fused_reshape_x")); + fused_attrs.emplace("fused_transpose_x", pat.Attr("fused_transpose_x")); + fused_attrs.emplace("fused_reshape_y", fused_reshape_attr); + fused_attrs.emplace("fused_transpose_y", pat.Attr("perm")); + } + + const auto &fused_matmul = res.Op(fused_matmul_name_, fused_attrs); + + if (as_x_) { + fused_matmul({&res.Tensor("reshape_in"), + &res.Tensor("other"), + &res.Tensor("residual")}, + {&res.Tensor("Out")}); + } else { + fused_matmul({&res.Tensor("other"), + &res.Tensor("reshape_in"), + &res.Tensor("residual")}, + {&res.Tensor("Out")}); + } + } +}; + +class ReshapeTransposeMatmulFusePass : public pir::PatternRewritePass { + public: + ReshapeTransposeMatmulFusePass() + : pir::PatternRewritePass("reshape_transpose_matmul_fuse_pass", 3) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + std::vector bool_set = {false, true}; + int benefit_idx = 5; + for (auto as_x : bool_set) { + ps.Add(paddle::drr::Create( + context, + paddle::dialect::MatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx, + as_x)); + benefit_idx--; + } + + for (auto as_x : bool_set) { + ps.Add(paddle::drr::Create( + context, + paddle::onednn::dialect::FusedMatmulOp::name(), + paddle::onednn::dialect::FusedMatmulOp::name(), + benefit_idx, + as_x)); + benefit_idx--; + } + return ps; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateReshapeTransposeMatmulFusePass() { + // pd_op.reshape + pd_op.transpose + pd_op.matmul -> onednn_op.fused_matmul + // pd_op.reshape + pd_op.transpose + pd_op.fused_matmul -> + // onednn_op.fused_matmul + return std::make_unique(); +} +} // namespace pir + +REGISTER_IR_PASS(reshape_transpose_matmul_fuse_pass, + ReshapeTransposeMatmulFusePass); diff --git a/paddle/fluid/pir/transforms/onednn/reshape_transpose_matmul_fuse_pass.h b/paddle/fluid/pir/transforms/onednn/reshape_transpose_matmul_fuse_pass.h new file mode 100644 index 0000000000000..71b5fe47f034b --- /dev/null +++ b/paddle/fluid/pir/transforms/onednn/reshape_transpose_matmul_fuse_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
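The reshape/transpose folding above is gated by a shape constraint that is easy to restate in isolation: the target shape must have rank 2-4, the transpose perm must have the same rank, and at most one dimension may be -1. A standalone sketch of that check (the vector element types are assumptions, since the attribute template arguments are not spelled out in the patterns above):

// Hedged restatement of the constraint used by the reshape+transpose+matmul
// fusion: reject shapes outside rank 2-4, rank mismatches with the perm, and
// shapes with more than one inferred (-1) dimension.
#include <algorithm>
#include <cstdint>
#include <vector>

bool ReshapeTransposeFoldable(const std::vector<int64_t> &shape,
                              const std::vector<int> &perm) {
  if (shape.size() < 2 || shape.size() > 4) return false;
  if (shape.size() != perm.size()) return false;
  if (std::count(shape.begin(), shape.end(), -1) > 1) return false;
  return true;
}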
+ +#pragma once + +#include +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateReshapeTransposeMatmulFusePass(); + +} // namespace pir diff --git a/paddle/fluid/pir/transforms/passes.h b/paddle/fluid/pir/transforms/passes.h new file mode 100644 index 0000000000000..2423bfbc8efc2 --- /dev/null +++ b/paddle/fluid/pir/transforms/passes.h @@ -0,0 +1,55 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pir/include/pass/pass_registry.h" + +USE_PIR_PASS(dead_code_elimination_pass); +USE_PIR_PASS(multihead_matmul_fuse_pass); +USE_PIR_PASS(transpose_flatten_concat_fuse_pass); +USE_PIR_PASS(fused_gemm_epilogue_pass); +USE_PIR_PASS(fused_dropout_add_pass); +USE_PIR_PASS(fused_weight_only_linear_pass); +USE_PIR_PASS(fused_linear_param_grad_add_pass); +USE_PIR_PASS(inplace_pass); +USE_PIR_PASS(replace_fetch_with_shadow_output_pass); +USE_PIR_PASS(identity_op_clean_pass); +USE_PIR_PASS(map_op_to_another_pass); +USE_PIR_PASS(matmul_scale_fuse_pass); +USE_PIR_PASS(matmul_transpose_fuse_pass); +USE_PIR_PASS(fc_fuse_pass); +USE_PIR_PASS(silu_fuse_pass); +USE_PIR_PASS(fc_elementwise_layernorm_fuse_pass); +USE_PIR_PASS(conv2d_bn_fuse_pass); +USE_PIR_PASS(conv2d_add_fuse_pass); +USE_PIR_PASS(conv2d_add_act_fuse_pass); +USE_PIR_PASS(embedding_eltwise_layernorm_fuse_pass); +USE_PIR_PASS(add_norm_fuse_pass); +USE_PIR_PASS(fused_dot_product_attention_pass); + +#ifdef PADDLE_WITH_DNNL +USE_PIR_PASS(batch_norm_act_fuse_pass); +USE_PIR_PASS(conv2d_bias_fuse_pass); +USE_PIR_PASS(conv2d_transpose_bias_fuse_pass); +USE_PIR_PASS(conv3d_bias_fuse_pass); +USE_PIR_PASS(reshape_transpose_matmul_fuse_pass); +USE_PIR_PASS(matmul_elementwise_add_fuse_pass); +USE_PIR_PASS(matmul_activation_fuse_pass); +USE_PIR_PASS(conv_elementwise_add_mkldnn_fuse_pass); +#endif + +#ifdef PADDLE_WITH_XPU +USE_PIR_PASS(add_layernorm_xpu_fuse_pass); +#endif diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 3450140741e21..182aa009a020c 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -28,6 +28,7 @@ #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" @@ -38,7 +39,7 @@ #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" -#include 
"paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/fluid/platform/place.h" #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" @@ -75,6 +76,14 @@ pir::Type ConvertOpTypeToKernelType(pir::IrContext* ctx, } else if (op_type.isa()) { return AllocatedSelectedRowsType::get( ctx, place, op_type.dyn_cast()); + } else if (op_type.isa()) { + auto vec_type = op_type.dyn_cast(); + std::vector vec_target_type; + for (size_t i = 0; i < vec_type.size(); ++i) { + vec_target_type.push_back( + ConvertOpTypeToKernelType(ctx, vec_type[i], place)); + } + return pir::VectorType::get(ctx, vec_target_type); } PADDLE_THROW(platform::errors::Unimplemented( "Not support op type %s in ConvertOpTypeToKernelType.", op_type)); @@ -83,15 +92,15 @@ pir::Type ConvertOpTypeToKernelType(pir::IrContext* ctx, static const std::vector InferMetaByValue( pir::Operation* op, const std::vector& input_values, - const pir::AttributeMap& attribute_map) { + pir::AttributeMap* p_attribute_map) { // NOLINT pir::OpInfo op_info = pir::IrContext::Instance()->GetRegisteredOpInfo(op->name()); auto infer_meta_interface = op_info.GetInterfaceImpl(); std::vector output_types; if (infer_meta_interface) { - output_types = - infer_meta_interface->infer_meta_by_value_(input_values, attribute_map); + output_types = infer_meta_interface->infer_meta_by_value_(input_values, + p_attribute_map); } return output_types; } @@ -367,18 +376,35 @@ static pir::Value AddPlaceTransferOp(pir::Value in, pir::IrContext* ctx = pir::IrContext::Instance(); auto copy_kernel_key = kernel_key; + auto place2backend = [](phi::AllocationType new_place_type) { + auto new_backend = phi::Backend::GPU; + switch (new_place_type) { + case phi::AllocationType::GPU: + new_backend = phi::Backend::GPU; + break; + case phi::AllocationType::XPU: + new_backend = phi::Backend::XPU; + break; + default: + new_backend = phi::Backend::CPU; + break; + } + return new_backend; + }; std::unordered_map op_attribute; if ((src_place.GetType() == phi::AllocationType::CPU) && - (dst_place.GetType() == phi::AllocationType::GPU)) { - copy_kernel_key.set_backend(phi::Backend::GPU); + (dst_place.GetType() == phi::AllocationType::GPU || + dst_place.GetType() == phi::AllocationType::XPU)) { + copy_kernel_key.set_backend(place2backend(dst_place.GetType())); op_attribute = { {"op_name", pir::StrAttribute::get(ctx, "pd_op.memcpy_h2d")}, {"kernel_name", pir::StrAttribute::get(ctx, "memcpy_h2d")}, {"kernel_key", KernelAttribute::get(ctx, copy_kernel_key)}, {"dst_place_type", pir::Int32Attribute::get(ctx, 1)}}; - } else if ((src_place.GetType() == phi::AllocationType::GPU) && + } else if ((src_place.GetType() == phi::AllocationType::GPU || + src_place.GetType() == phi::AllocationType::XPU) && (dst_place.GetType() == phi::AllocationType::CPU)) { - copy_kernel_key.set_backend(phi::Backend::GPU); + copy_kernel_key.set_backend(place2backend(dst_place.GetType())); std::string copy_kernel_name = "memcpy_d2h"; if (in.type().isa()) { copy_kernel_name = "memcpy_d2h_multi_io"; @@ -643,8 +669,7 @@ static phi::DataType GetKernelDtypeByYaml( auto& data_type_info = op_info_parser->OpRuntimeInfo().kernel_key_dtype; phi::DataType kernel_data_type = phi::DataType::UNDEFINED; - for (size_t i = 0; i < data_type_info.size(); ++i) { - auto slot_name = data_type_info[i]; + for (auto slot_name : data_type_info) { auto& input_map = op_info_parser->InputName2Id(); bool is_complex_tag = false; @@ -729,8 +754,7 @@ 
static phi::Backend GetKernelBackendByYaml( auto& backend_info = op_info_parser->OpRuntimeInfo().kernel_key_backend; phi::Backend kernel_backend = phi::Backend::UNDEFINED; - for (size_t i = 0; i < backend_info.size(); ++i) { - auto slot_name = backend_info[i]; + for (auto slot_name : backend_info) { auto& input_map = op_info_parser->InputName2Id(); if (input_map.count(slot_name)) { @@ -812,7 +836,7 @@ std::string GetKernelName(const OpYamlInfoParser* op_info_parser, kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func; } - if (op_item->isa() || op_item->isa()) { + if (op_item->isa() || op_item->isa()) { if (op_item->result(0).type().isa()) { kernel_fn_str = "add_n_sr"; } @@ -1359,6 +1383,119 @@ phi::DataType ParsePhiDType(pir::Type type) { } } +void AddShadowFeedForValue( + size_t index, + pir::Operation* op_item, + pir::Operation* op_item_with_place, + pir::Block* block, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair) { + if (op_item->result(index).type().isa()) { + phi::KernelKey shadow_key{ + phi::Backend::GPU, + phi::DataLayout::ANY, + TransToPhiDataType( + op_item->result(index).type().dyn_cast().dtype())}; + std::unordered_map attr_map{ + {"op_name", pir::StrAttribute::get(ctx, "pd_op.shadow_feed")}, + {"kernel_name", pir::StrAttribute::get(ctx, "shadow_feed")}, + {"kernel_key", KernelAttribute::get(ctx, shadow_key)}}; + + auto out_type = AllocatedDenseTensorType::get( + ctx, + phi::TransToPhiPlace(shadow_key.backend()), + op_item->result(index).type().dyn_cast()); + + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(PhiKernelOp::name()); + pir::Operation* shadow_op = + pir::Operation::Create({op_item_with_place->result(index)}, + attr_map, + {out_type}, + phi_kernel_op_info); + block->push_back(shadow_op); + (*map_op_pair)[op_item] = shadow_op; + (*map_value_pair)[op_item->result(index)] = shadow_op->result(0); + } else if (op_item->result(index).type().isa()) { + auto vec_type = op_item->result(index).type().dyn_cast(); + for (size_t i = 0; i < vec_type.size(); ++i) { + PADDLE_ENFORCE_EQ( + vec_type[i].isa(), + true, + phi::errors::PreconditionNotMet( + "AddShadowFeedTensors only support DenseTensorType Now")); + } + // Add ShadowFeedTensors Op + phi::KernelKey shadow_key{ + phi::Backend::GPU, + phi::DataLayout::ANY, + TransToPhiDataType(vec_type[0].dyn_cast().dtype())}; + + std::unordered_map attr_map{ + {"op_name", pir::StrAttribute::get(ctx, "pd_op.shadow_feed_tensors")}, + {"kernel_name", pir::StrAttribute::get(ctx, "shadow_feed_tensors")}, + {"kernel_key", KernelAttribute::get(ctx, shadow_key)}}; + + pir::OpInfo phi_kernel_op_info = + ctx->GetRegisteredOpInfo(PhiKernelOp::name()); + + std::vector vec_out_types; + for (size_t i = 0; i < vec_type.size(); ++i) { + vec_out_types.push_back(AllocatedDenseTensorType::get( + ctx, + phi::TransToPhiPlace(shadow_key.backend()), + vec_type[i].dyn_cast())); + } + auto out_type = pir::VectorType::get(ctx, vec_out_types); + pir::Operation* shadow_tensors_op = + pir::Operation::Create({op_item_with_place->result(index)}, + attr_map, + {out_type}, + phi_kernel_op_info); + block->push_back(shadow_tensors_op); + (*map_op_pair)[op_item] = shadow_tensors_op; + (*map_value_pair)[op_item->result(index)] = shadow_tensors_op->result(0); + } else { + PADDLE_THROW( + phi::errors::Unimplemented("AddShadowFeed for value only support " + "DenseTensorType and VectorType Now")); + } +} + +void AddShadowFeedForTuplePopOp( + const phi::Place& place, + pir::Operation* op_item, + pir::Operation* 
op_item_with_undefined_place, + pir::Block* block, + pir::IrContext* ctx, + std::unordered_map* map_op_pair, + std::unordered_map* map_value_pair) { + VLOG(4) << "Add AddShadowFeed for op " << op_item->name(); + + bool add_shadow_feed = true; + if (op_item->attributes().count("place")) { + add_shadow_feed = (op_item->attributes() + .at("place") + .dyn_cast() + .data() + .GetType()) == phi::AllocationType::UNDEFINED; + } + + // if value place not gpu, add shadow feed op + if (platform::is_gpu_place(place) && add_shadow_feed) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + AddShadowFeedForValue(i, + op_item, + op_item_with_undefined_place, + block, + ctx, + map_op_pair, + map_value_pair); + } + } +} + void HandleForSpecialOp( const phi::Place& place, pir::Operation* op_item, @@ -1629,17 +1766,46 @@ void HandleForSpecialOp( } auto pop_back_op = op_item->dyn_cast<::pir::TuplePopOp>(); - for (size_t i = 0; i < op_item->num_results(); ++i) { - auto cur_inlet_element = pop_back_op.inlet_element(i); - PADDLE_ENFORCE_EQ(map_value_pair->count(cur_inlet_element), - true, - phi::errors::PreconditionNotMet( - "[%d]'s output of [%s] op MUST be in map pair", - i, - op_item->name())); - auto new_inlet_element = map_value_pair->at(cur_inlet_element); - op_output_types.push_back(new_inlet_element.type()); + if (pop_back_op.has_container()) { + // if TuplePopOp and TuplePushOp are in the same sub_program + for (size_t i = 0; i < op_item->num_results(); ++i) { + auto cur_inlet_element = pop_back_op.inlet_element(i); + PADDLE_ENFORCE_EQ(map_value_pair->count(cur_inlet_element), + true, + phi::errors::PreconditionNotMet( + "[%d]'s output of [%s] op MUST be in map pair", + i, + op_item->name())); + auto new_inlet_element = map_value_pair->at(cur_inlet_element); + + op_output_types.push_back(new_inlet_element.type()); + } + } else { + VLOG(4) << "TuplePopOp and TuplePushOp are in different sub_program."; + for (size_t i = 0; i < op_item->num_results(); ++i) { + auto cur_inlet_element = op_item->result(i); + auto out_place = phi::TransToPhiPlace(phi::Backend::UNDEFINED); + pir::Type new_inlet_element_type = + ConvertOpTypeToKernelType(ctx, cur_inlet_element.type(), out_place); + op_output_types.push_back(new_inlet_element_type); + } + + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name()); + pir::Operation* op = pir::Operation::Create( + vec_inputs, op_item->attributes(), op_output_types, op_info); + + block->push_back(op); + (*map_op_pair)[op_item] = op; + // only deal with single output + if (op_item->num_results() > 0) { + for (size_t i = 0; i < op_item->num_results(); ++i) { + (*map_value_pair)[op_item->result(i)] = op->result(i); + } + } + AddShadowFeedForTuplePopOp( + place, op_item, op, block, ctx, map_op_pair, map_value_pair); + return; } } @@ -1668,17 +1834,38 @@ void HandleForSpecialOp( } if (op_item->name() == "cinn_runtime.jit_kernel") { - if (op_item->num_operands() > 0) { - for (size_t i = 0; i < op_item->num_operands(); ++i) { - auto cur_in = op_item->operand_source(i); - if (!cur_in) { - vec_inputs.emplace_back(); - continue; + for (size_t i = 0; i < op_item->num_operands(); ++i) { + auto cur_in = op_item->operand_source(i); + if (!cur_in) { + vec_inputs.emplace_back(); + continue; + } + auto new_in = GetNewInput( + cur_in, *map_value_pair, static_cast(i), op_item->name()); + // For data transform + if (new_in.type().isa()) { + auto in_place = + new_in.type().dyn_cast().place(); + auto dst_backend = phi::TransToPhiBackend(place); + bool need_trans = + (in_place.GetType() != 
phi::AllocationType::UNDEFINED) && + (paddle::experimental::NeedTransformPlace( + in_place, dst_backend, {})); + if (need_trans) { + VLOG(6) << "need trans from " << in_place << " to " << dst_backend; + auto value_type = + op_item->operand_source(i).type().dyn_cast(); + auto out_place = phi::TransToPhiPlace(dst_backend); + auto out_type = + AllocatedDenseTensorType::get(ctx, out_place, value_type); + phi::KernelKey kernel_key(phi::Backend::GPU, + phi::DataLayout::ANY, + TransToPhiDataType(value_type.dtype())); + new_in = AddPlaceTransferOp( + new_in, out_type, in_place, out_place, kernel_key, block); } - auto new_in = GetNewInput( - cur_in, *map_value_pair, static_cast(i), op_item->name()); - vec_inputs.push_back(new_in); } + vec_inputs.push_back(new_in); } for (size_t i = 0; i < op_item->num_results(); ++i) { @@ -1925,7 +2112,7 @@ std::vector BuildOutputs( input_values.emplace_back(op_item->operand(i).source()); } std::vector output_types = - InferMetaByValue(op_item, input_values, attribute_map); + InferMetaByValue(op_item, input_values, &attribute_map); if (output_types.size() != 0) { PADDLE_ENFORCE_EQ( @@ -1959,7 +2146,7 @@ std::vector BuildOutputs( &op_output_types); } } else { - auto base_types = InferMetaByValue(op_item, new_vec_inputs, attribute_map); + auto base_types = InferMetaByValue(op_item, new_vec_inputs, &attribute_map); PADDLE_ENFORCE_EQ(base_types.size(), op_item->num_results(), phi::errors::PreconditionNotMet( @@ -2313,34 +2500,12 @@ void AddShadowFeedOpForDataOrFeed( .GetType() == phi::AllocationType::UNDEFINED); bool add_shadow_feed = feed_op_add_shadow_feed || data_op_add_shadow_feed; if (add_shadow_feed) { - // if shadow data op place not gpu,add shadow feed op - phi::KernelKey shadow_key{ - phi::Backend::GPU, - phi::DataLayout::ANY, - TransToPhiDataType( - op_item->result(0).type().dyn_cast().dtype())}; - std::unordered_map attr_map{ - {"op_name", pir::StrAttribute::get(ctx, "pd_op.shadow_feed")}, - {"kernel_name", pir::StrAttribute::get(ctx, "shadow_feed")}, - {"kernel_key", KernelAttribute::get(ctx, shadow_key)}}; - - auto out_type = AllocatedDenseTensorType::get( - ctx, - phi::TransToPhiPlace(shadow_key.backend()), - op_item->result(0).type().dyn_cast()); - - pir::OpInfo phi_kernel_op_info = - ctx->GetRegisteredOpInfo(PhiKernelOp::name()); - pir::Operation* shadow_op = pir::Operation::Create( - {kernel_op->result(0)}, attr_map, {out_type}, phi_kernel_op_info); - - (*map_op_pair)[op_item] = shadow_op; - block->push_back(shadow_op); - if (op_item->num_results() > 0) { - for (size_t i = 0; i < shadow_op->num_results(); ++i) { - (*map_value_pair)[op_item->result(i)] = shadow_op->result(i); - } - } + PADDLE_ENFORCE(op_item->num_results() == 1, + phi::errors::PreconditionNotMet( + "op_item should have only one result, but got %d", + op_item->num_results())); + AddShadowFeedForValue( + 0, op_item, kernel_op, block, ctx, map_op_pair, map_value_pair); } } diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.cc b/paddle/fluid/pir/transforms/shape_optimization_pass.cc index 80d56f75ae12b..d5ced352047da 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.cc +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.cc @@ -13,13 +13,33 @@ // limitations under the License. 
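The changes above collapse the previously duplicated shadow-feed construction into one helper that operates on a single result index; the tuple_pop path loops over all results while the data/feed path enforces a single result and passes index 0. A toy sketch of that call structure (the Op struct below is hypothetical, not the pir API):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

struct Op { std::vector<int> results; };  // stand-in for an operation's results

// One helper handles exactly one result, mirroring AddShadowFeedForValue.
void AddShadowFeedForResult(std::size_t index, const Op& op) {
  std::cout << "insert shadow_feed for result " << index
            << " (value " << op.results[index] << ")\n";
}

// tuple_pop: every popped result may need a shadow feed.
void AddShadowFeedForTuplePop(const Op& op) {
  for (std::size_t i = 0; i < op.results.size(); ++i) AddShadowFeedForResult(i, op);
}

// data/feed: exactly one result, so only index 0 is handled.
void AddShadowFeedForDataOrFeed(const Op& op) { AddShadowFeedForResult(0, op); }

int main() {
  AddShadowFeedForTuplePop(Op{{7, 8, 9}});
  AddShadowFeedForDataOrFeed(Op{{42}});
  return 0;
}
```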
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h" +#include "paddle/common/flags.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/core/dialect.h" +#include "paddle/pir/include/core/ir_printer.h" #include "paddle/pir/include/dialect/shape/ir/shape_attribute.h" +#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/include/pass/pass_manager.h" #include "paddle/pir/include/pass/pass_registry.h" -const int vlog_level = 3; +COMMON_DECLARE_bool(pir_apply_shape_optimization_pass); + +constexpr int vlog_level = 3; + +// TODO(zhangbopd): Some op results infered by InferSymbolicShape is NOT consist +// with the result infered by InferMeta and should be fixed. +namespace { +bool NeedCheckInferSymbolicWithInferMeta(const std::string& op_name, + size_t result_idx) { + static std::unordered_map> blacklist{ + {"pd_op.reshape", {1}}, + {"pd_op.empty", {0}}, + }; + const auto& iter = blacklist.find(op_name); + if (iter == blacklist.end()) return true; + return iter->second.count(result_idx) == 0; +} +} // namespace namespace pir { namespace { @@ -27,22 +47,84 @@ namespace { using PassPipelineRunner = std::function; -void PrintProgram(pir::ModuleOp m, std::string mgs) { +void PrintProgram(pir::ModuleOp m, std::string msg) { ShapeConstraintIRAnalysis& shape_analysis = ShapeAnalysisManager::Instance().Get(m.program()); - VLOG(vlog_level) << "===================== " << mgs - << " =====================\n" - << pir::CustomPrintHelper(*m.program(), - shape_analysis.PrintHook()); + if (VLOG_IS_ON(vlog_level)) { + std::cerr << "===================== [ShapeDialect]" << msg + << " =====================\n" + << pir::CustomPrintHelper(*m.program(), + shape_analysis.PrintHook()) + << std::endl; + } +} + +std::string PrintOperationWithNoRegion(Operation* op) { + std::ostringstream os; + pir::IrPrinter printer(os); + + // print OpResults + os << "("; + auto num_op_result = op->num_results(); + for (size_t idx = 0; idx < num_op_result; idx++) { + os << "%op_" << op->id() << "_" << idx; + if (idx < num_op_result - 1) os << ", "; + } + os << ")"; + + os << " ="; + + // print OpName & OpId + os << " \"" << op->name() << "(op_" << op->id() << ")" + << "\""; + + // print OpOperands + os << " ("; + auto num_op_operands = op->num_operands(); + for (size_t idx = 0; idx < num_op_operands; idx++) { + const pir::Value& input = op->operand_source(idx); + if (input.defining_op()) { + os << "op_" << input.defining_op()->id() << "_" + << input.dyn_cast().index(); + } else { + os << "op_NULL"; + } + if (idx < num_op_operands - 1) os << ", "; + } + os << ")"; + + printer.PrintAttributeMap(op); + os << " :"; + + // PrintOpSignature + printer.PrintOperandsType(op); + os << " -> "; + + printer.PrintOpReturnType(op); + + return os.str(); +} + +void PrintOpInfo(pir::Operation* op) { + if (VLOG_IS_ON(vlog_level)) { + VLOG(vlog_level) << op->name() << "(op_id: op_" << op->id() + << ", num_results=" << op->num_results() << ")" + << " has InferSymbolicShapeInterface.\n\t" + << PrintOperationWithNoRegion(op); + if (op->name() == "cinn_op.group") { + std::cerr << "<<<<<<<<<<<<<<<<<<<< " << op->name() << "(op_id: op_" + << op->id() << ") START..." 
<< std::endl; + } + } } void DebugPrintOpInfo( pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { - for (auto& res : op->results()) { - std::ostringstream print_stream; - - print_stream << " result(" << res.dyn_cast().index() << ") " + std::ostringstream print_stream; + for (uint32_t i = 0; i < op->num_results(); ++i) { + const auto& res = op->result(i); + print_stream << "\tresult(" << res.dyn_cast().index() << ") " << "ShapeOrData: {"; if (shape_analysis != nullptr) { @@ -74,8 +156,72 @@ void DebugPrintOpInfo( print_stream << "]"; } - print_stream << " }"; - VLOG(vlog_level) << print_stream.str(); + print_stream << " }\n"; + } + if (VLOG_IS_ON(vlog_level)) { + std::cerr << print_stream.str(); + } +} + +void CheckInferSymWithInferMeta( + pir::Operation* op, + pir::ShapeConstraintIRAnalysis* shape_analysis = nullptr) { + for (uint32_t i = 0; i < op->num_results(); ++i) { + const auto& res = op->result(i); + std::ostringstream print_stream; + + // InferMeta funcs of some Ops are not corrrect now, we don't check them. + if (!NeedCheckInferSymbolicWithInferMeta(op->name(), i)) continue; + + if (res.type().isa()) { + const std::vector& infer_meta_shape = common::vectorize( + res.type().dyn_cast().dims()); + const std::vector& infer_sym_shape = + shape_analysis->GetShapeOrDataForValue(res).shape(); + + // Check rank. + if (infer_meta_shape.size() != infer_sym_shape.size()) { + std::ostringstream print_stream; + print_stream << "Warning : Check InferSymbolicShape for " << op->name() + << " (op_" << op->id() << ") " + << " carefully! rank of infer_meta_shape is [" + << infer_meta_shape.size() + << "], but rank of infer_sym_shape is [" + << infer_sym_shape.size() << "]."; + VLOG(vlog_level) << print_stream.str(); + continue; + } + + // Check each dim. + for (size_t i = 0; i < infer_meta_shape.size(); ++i) { + // Check Static shape should NOT be a symbol. + if (infer_meta_shape[i] != -1) { + if (!infer_sym_shape[i].isa()) { + std::ostringstream print_stream; + print_stream + << "Warning : Check InferSymbolicShape for " << op->name() + << " (op_" << op->id() << ") " + << " carefully! " + << "shape[" << i + << "] of infer_sym_shape shoule be int64_t NOT a symbol!"; + VLOG(vlog_level) << print_stream.str(); + continue; + } + + // Check Static shape should be consist. + if (infer_meta_shape[i] != infer_sym_shape[i].dyn_cast()) { + std::ostringstream print_stream; + print_stream << "Warning : Check InferSymbolicShape for " + << op->name() << " (op_" << op->id() << ") " + << " carefully! " + << "infer_sym_shape is [" << infer_meta_shape[i] + << "], but infer_meta_shape is [" + << infer_sym_shape[i].dyn_cast() << "]."; + VLOG(vlog_level) << print_stream.str(); + } + } + } + } } } @@ -99,12 +245,15 @@ class ShapeOptimizationPass : public pir::Pass { << "===================== ShapeOptimizationPass Run start... " "====================="; auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); + PADDLE_ENFORCE_NOT_NULL( + module_op, + phi::errors::InvalidArgument( + "ShapeOptimizationPass should run on module op.")); PrintProgram(module_op, "Origin Program"); InferSymExprForAllValues(module_op); // Runner is for Canonicalizer. 
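CheckInferSymWithInferMeta above only warns: for each non-blacklisted output it requires that every statically known (non -1) InferMeta dimension shows up in the symbolic result as a concrete, equal integer, while dynamic dimensions may stay symbolic. A rough standalone model of that rule, using std::variant as a stand-in for symbol::DimExpr:

```cpp
#include <cstdint>
#include <iostream>
#include <string>
#include <variant>
#include <vector>

using SymDim = std::variant<std::int64_t, std::string>;  // concrete size or symbol

bool ConsistentWithInferMeta(const std::vector<std::int64_t>& meta,
                             const std::vector<SymDim>& sym) {
  if (meta.size() != sym.size()) return false;  // rank mismatch
  for (std::size_t i = 0; i < meta.size(); ++i) {
    if (meta[i] == -1) continue;  // dynamic dim: any symbol is acceptable
    const auto* value = std::get_if<std::int64_t>(&sym[i]);
    if (value == nullptr || *value != meta[i]) return false;  // must be equal int
  }
  return true;
}

int main() {
  std::vector<std::int64_t> meta{-1, 128, 768};
  std::vector<SymDim> ok{std::string("S0"), std::int64_t{128}, std::int64_t{768}};
  std::vector<SymDim> bad{std::string("S0"), std::string("S1"), std::int64_t{768}};
  std::cout << ConsistentWithInferMeta(meta, ok) << " "
            << ConsistentWithInferMeta(meta, bad) << "\n";  // prints "1 0"
  return 0;
}
```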
- PassPipelineRunner runner = [this](pir::PassManager& pm, pir::ModuleOp m) { + PassPipelineRunner runner = [](pir::PassManager& pm, pir::ModuleOp m) { pm.EnableIRPrinting(); return pm.Run(m.program()); }; @@ -127,12 +276,13 @@ void InferSymExprForBlock(const Block& block, auto infer_symbolic_shape_interface = op.dyn_cast(); if (infer_symbolic_shape_interface) { - VLOG(vlog_level) << op.name() << " has InferSymbolicShapeInterface."; + PrintOpInfo(&op); PADDLE_ENFORCE_EQ( infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis), true, "InferSymbolicShape for %s failed.", op.name()); + if (op.num_results() > 0) { // TODO(lanxianghit): deal with the ops which have more than 1 // ACTUAL results @@ -140,12 +290,11 @@ void InferSymExprForBlock(const Block& block, &op, shape_analysis->GetShapeOrDataForValue(op.result(0))); } } else { - VLOG(vlog_level) << op.name() + - " DOES NOT have InferSymbolicShapeInterface!"; PADDLE_THROW(phi::errors::Unimplemented( op.name() + " DOES NOT have InferSymbolicShapeInterface!")); } DebugPrintOpInfo(&op, shape_analysis); + CheckInferSymWithInferMeta(&op, shape_analysis); } } @@ -155,4 +304,38 @@ std::unique_ptr CreateShapeOptimizationPass() { } // namespace pir +namespace pir::shape { + +bool HasDynamicShape(const pir::Program& program) { + for (const auto& op : *program.block()) { + if (op.isa()) { + continue; + } + for (uint32_t i = 0; i < op.num_results(); ++i) { + if (op.result(i) && op.result(i).type()) { + auto shape_type = + op.result(i).type().dyn_cast(); + if (shape_type && shape_type.IsDynamicShape()) { + VLOG(vlog_level) << "###### HasDynamicShape == true"; + return true; + } + } + } + } + VLOG(vlog_level) << "###### HasDynamicShape == false"; + return false; +} + +void AddShapeOptimizationPass( + std::shared_ptr& pass_manager, // NOLINT + pir::Program& program) { // NOLINT + pir::IrContext* ctx = pir::IrContext::Instance(); + ctx->GetOrRegisterDialect(); + if (HasDynamicShape(program) && FLAGS_pir_apply_shape_optimization_pass) { + pass_manager->AddPass(pir::CreateShapeOptimizationPass()); + } +} + +} // namespace pir::shape + REGISTER_IR_PASS(shape_optimization_pass, pir::ShapeOptimizationPass); diff --git a/paddle/fluid/pir/transforms/shape_optimization_pass.h b/paddle/fluid/pir/transforms/shape_optimization_pass.h index a23de56f35d6e..5050ea727e678 100644 --- a/paddle/fluid/pir/transforms/shape_optimization_pass.h +++ b/paddle/fluid/pir/transforms/shape_optimization_pass.h @@ -17,6 +17,7 @@ #include #include "paddle/pir/include/core/dll_decl.h" #include "paddle/pir/include/dialect/shape/utils/shape_analysis.h" +#include "paddle/pir/include/pass/pass_manager.h" namespace pir { @@ -28,3 +29,12 @@ void InferSymExprForBlock(const Block &block, ShapeConstraintIRAnalysis *shape_analysis); } // namespace pir + +namespace pir::shape { +bool HasDynamicShape(const pir::Program &program); + +void AddShapeOptimizationPass( + std::shared_ptr &pass_manager, // NOLINT + pir::Program &program); // NOLINT + +} // namespace pir::shape diff --git a/paddle/fluid/pir/transforms/sub_graph_detector.cc b/paddle/fluid/pir/transforms/sub_graph_detector.cc index 0690bc1c8399c..92753e3353529 100644 --- a/paddle/fluid/pir/transforms/sub_graph_detector.cc +++ b/paddle/fluid/pir/transforms/sub_graph_detector.cc @@ -16,6 +16,7 @@ #include +#include #include #include #include @@ -83,17 +84,20 @@ std::vector InverselyTopologicalSort(pir::Block* block) { } auto* defined_op = operand.source().defining_op(); --pending_count[defined_op]; - if (defined_op && 
pending_count[defined_op] == 0) { + if (defined_op && pending_count[defined_op] == 0 && + defined_op->GetParent() == block) { queue.push(defined_op); } } } - IR_ENFORCE( - block->size() == sort_ops.size(), - "sort_ops.size() must be equal to block.size(), but received %d != %d", + PADDLE_ENFORCE_EQ( block->size(), - sort_ops.size()); + sort_ops.size(), + phi::errors::InvalidArgument("sort_ops.size() must be equal to " + "block.size(), but received %d != %d", + block->size(), + sort_ops.size())); return sort_ops; } @@ -109,7 +113,8 @@ std::vector GetProducerOpsReverseSort( continue; } auto* source_op = operand.source().defining_op(); - if (source_op && !producers.count(source_op)) { + if (source_op && !producers.count(source_op) && + source_op->GetParent() == op->GetParent()) { producers.insert(source_op); PADDLE_ENFORCE( op2id.count(source_op), @@ -134,7 +139,8 @@ std::unordered_set GetProducerOps(pir::Operation* op) { if (!operand || !(operand.source())) { continue; } - if (auto* source_op = operand.source().defining_op()) { + auto* source_op = operand.source().defining_op(); + if (source_op && source_op->GetParent() == op->GetParent()) { producers.insert(source_op); } } @@ -316,11 +322,11 @@ bool SubgraphDetector::FuseSubGraph(SubGraphPtr subgraph_ptr) { if (!consumer->substitute) { continue; } - // fast depency check. + // fast dependency check. if (IsDependencySimplify(producer, consumer, consumers)) { continue; } - // global depency check. + // global dependency check. if (IsDependency(producer, consumer, consumers)) { continue; } @@ -341,7 +347,7 @@ bool SubgraphDetector::FuseSubGraph(SubGraphPtr subgraph_ptr) { producer->ops.end(), candidate->ops.begin(), candidate->ops.end()); producer->op_set.insert(candidate->op_set.begin(), candidate->op_set.end()); - // update bound for check depency + // update bound for check dependency producer->max_depth = std::max(producer->max_depth, candidate->max_depth); producer->min_depth = std::min(producer->min_depth, candidate->min_depth); @@ -364,7 +370,7 @@ bool SubgraphDetector::FuseSubGraph(SubGraphPtr subgraph_ptr) { tmp->producers.erase(candidate); } - // remove candicate in producer/consumer + // remove candidate in producer/consumer producer->producers.erase(candidate); producer->consumers.erase(candidate); @@ -387,7 +393,7 @@ bool SubgraphDetector::FuseSubGraph(SubGraphPtr subgraph_ptr) { return true; } -// check exist depency. +// check exist dependency. 
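The ordering the detector relies on, which the hunks above now restrict to producers inside the same block, is a Kahn-style reverse topological sort driven by a pending-consumer count. A self-contained sketch on integer node ids (not the pir types):

```cpp
#include <iostream>
#include <queue>
#include <unordered_map>
#include <vector>

int main() {
  // node -> the producers it reads from (all assumed to live in one block)
  std::unordered_map<int, std::vector<int>> producers{
      {0, {}}, {1, {0}}, {2, {0, 1}}, {3, {2}}};

  // pending[n] = how many consumers of n have not been visited yet
  std::unordered_map<int, int> pending;
  for (const auto& kv : producers)
    for (int p : kv.second) ++pending[p];

  std::queue<int> ready;
  for (const auto& kv : producers)
    if (pending[kv.first] == 0) ready.push(kv.first);  // pure consumers first

  std::vector<int> order;
  while (!ready.empty()) {
    int node = ready.front();
    ready.pop();
    order.push_back(node);
    // a producer becomes ready once its last consumer has been emitted
    for (int p : producers[node])
      if (--pending[p] == 0) ready.push(p);
  }

  for (int n : order) std::cout << n << ' ';  // prints: 3 2 1 0
  std::cout << '\n';
  return 0;
}
```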
bool SubgraphDetector::IsDependency( const SubGraphPtr& producer_g, const SubGraphPtr& consumer, @@ -510,6 +516,74 @@ pir::Operation* FindInsertPoint(const GroupOpsVec& group_ops, } return insert_point_op; } + +struct IncrementalOrder { + bool operator()(const pir::Operation* lhs, const pir::Operation* rhs) const { + CHECK(lhs->GetParent() == rhs->GetParent()) + << "lhs and rhs should have same parent block."; + auto lhs_iter = lhs->operator Block::ConstIterator(); + auto rhs_iter = rhs->operator Block::ConstIterator(); + auto end_iter = lhs->GetParent()->end(); + while (lhs_iter != end_iter) { + lhs_iter++; + if (lhs_iter == rhs_iter) return true; + if (lhs_iter == end_iter) return false; + } + CHECK(false) << "rhs " << rhs->id() << " is not reachable from lhs " + << lhs->id(); + return false; + } +}; + +std::unordered_set GetUpstreamOpsAfterPosition( + const pir::Operation* position_op, + const pir::Block* block, + const pir::Operation* op, + std::unordered_set* visited_ops) { + std::unordered_set ops; + const auto& IsInBlock = [](const pir::Operation* src_op, + const pir::Block* block) { + for (auto& op : *block) { + if (src_op == &op) return true; + } + return false; + }; + + for (auto value : op->operands_source()) { + if (!value || !value.defining_op()) continue; + pir::Operation* defining_op = value.defining_op(); + if (visited_ops->count(defining_op)) continue; + visited_ops->insert(defining_op); + if (!IsInBlock(defining_op, block)) continue; + if (IncrementalOrder()(defining_op, position_op)) continue; + + ops.insert(defining_op); + auto recursive_ops = GetUpstreamOpsAfterPosition( + position_op, block, defining_op, visited_ops); + ops.insert(recursive_ops.begin(), recursive_ops.end()); + } + return ops; +} + +void MoveUpstreamOpBeforeGroup(const GroupOpsVec& group_ops, + pir::Block* block, + pir::Operation* insert_point_op) { + const auto moved_ops = [&]() { + std::set ops_set; + std::unordered_set visited_ops; + for (auto& op : group_ops) { + auto upstream_ops = + GetUpstreamOpsAfterPosition(insert_point_op, block, op, &visited_ops); + ops_set.insert(upstream_ops.begin(), upstream_ops.end()); + } + return ops_set; + }(); + + for (auto& op : moved_ops) { + VLOG(5) << "Move " << op->name() << " before " << insert_point_op->name(); + op->MoveTo(block, insert_point_op->operator Block::Iterator()); + } +} } // namespace void ReplaceWithGroupOp(pir::Block* block, @@ -524,6 +598,7 @@ void ReplaceWithGroupOp(pir::Block* block, // step 1: Analysis and insert group op before insert_point. auto* insert_point = FindInsertPoint(group_ops, outputs); + MoveUpstreamOpBeforeGroup(group_ops, block, insert_point); builder.set_insertion_point(insert_point); VLOG(6) << "Insert GroupOp after " << insert_point->name(); diff --git a/paddle/fluid/pir/transforms/sub_graph_detector.h b/paddle/fluid/pir/transforms/sub_graph_detector.h index 1b7ec2bc5da6a..424855b02ddcc 100644 --- a/paddle/fluid/pir/transforms/sub_graph_detector.h +++ b/paddle/fluid/pir/transforms/sub_graph_detector.h @@ -51,7 +51,7 @@ class SubgraphDetector { void DoSubGraphFusion(); bool FuseSubGraph(SubGraphPtr subgraph_ptr); - // check exist depency. + // check exist dependency. 
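MoveUpstreamOpBeforeGroup above pulls any producer of the group's inputs that sits after the chosen insertion point back in front of it before the group op is created. A compact sketch of that relocation on integer op ids (a simplification of what the pass does with Operation::MoveTo):

```cpp
#include <algorithm>
#include <iostream>
#include <set>
#include <vector>

int main() {
  std::vector<int> block{0, 1, 2, 3, 5, 4};  // op ids in current block order
  const int insert_point = 3;                // the group op will be created here
  const std::set<int> upstream_after{5};     // producers of group inputs found after it

  // Stable removal plus reinsertion keeps the relative order of everything else.
  std::vector<int> reordered;
  for (int op : block)
    if (upstream_after.count(op) == 0) reordered.push_back(op);
  auto pos = std::find(reordered.begin(), reordered.end(), insert_point);
  reordered.insert(pos, upstream_after.begin(), upstream_after.end());

  for (int op : reordered) std::cout << op << ' ';  // prints: 0 1 2 5 3 4
  std::cout << '\n';
  return 0;
}
```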
bool IsDependency(const SubGraphPtr& producer_g, const SubGraphPtr& consumer, const std::unordered_set& consumers); diff --git a/paddle/fluid/pir/transforms/sub_graph_extract_pass.cc b/paddle/fluid/pir/transforms/sub_graph_extract_pass.cc index 6f513e8cf5b1c..513a7f238f282 100644 --- a/paddle/fluid/pir/transforms/sub_graph_extract_pass.cc +++ b/paddle/fluid/pir/transforms/sub_graph_extract_pass.cc @@ -46,7 +46,10 @@ class SubGraphExtractPass : public pir::Pass { void Run(pir::Operation* op) override { auto module_op = op->dyn_cast(); - IR_ENFORCE(module_op, "sub_graph_extract_pass should run on module op."); + PADDLE_ENFORCE_NOT_NULL( + module_op, + phi::errors::InvalidArgument( + "sub_graph_extract_pass should run on module op.")); auto& block = module_op.block(); std::vector groups = diff --git a/paddle/fluid/pir/transforms/xpu/add_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/xpu/add_layernorm_fuse_pass.cc new file mode 100644 index 0000000000000..7cb7f09095c08 --- /dev/null +++ b/paddle/fluid/pir/transforms/xpu/add_layernorm_fuse_pass.cc @@ -0,0 +1,90 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/pir/transforms/xpu/add_layernorm_fuse_pass.h" + +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/utils/general_functions.h" + +#include "paddle/pir/include/pass/pass.h" +#include "paddle/pir/include/pass/pass_registry.h" + +namespace { + +class AddLayernormPattern : public paddle::drr::DrrPatternBase { + public: + std::string name() const override { return "AddLayernormPattern"; } + + void operator()(paddle::drr::DrrPatternContext *ctx) const override { + paddle::drr::SourcePattern pat = ctx->SourcePattern(); + const auto &add = pat.Op(paddle::dialect::AddOp::name()); + const auto &layernorm = + pat.Op(paddle::dialect::LayerNormOp::name(), + {{"epsilon", pat.Attr("epsilon")}, + {"begin_norm_axis", pat.Attr("begin_norm_axis")}}); + add({&pat.Tensor("x"), &pat.Tensor("y")}, {&pat.Tensor("add_out")}); + layernorm( + {&pat.Tensor("add_out"), &pat.Tensor("scale"), &pat.Tensor("bias")}, + {&pat.Tensor("layernorm_out"), + &pat.Tensor("layernorm_mean"), + &pat.Tensor("layernorm_variance")}); + + pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { + std::vector x_shape = + pir::GetShapeFromValue(match_ctx.Tensor("x")); + std::vector y_shape = + pir::GetShapeFromValue(match_ctx.Tensor("y")); + if (x_shape.size() == y_shape.size()) { + return true; + } + return false; + }); + + paddle::drr::ResultPattern res = pat.ResultPattern(); + + const auto &add_layernorm_xpu = + res.Op(paddle::dialect::AddLayernormXpuOp::name(), + {{{"epsilon", pat.Attr("epsilon")}, + {"begin_norm_axis", pat.Attr("begin_norm_axis")}}}); + add_layernorm_xpu({&res.Tensor("x"), + &res.Tensor("y"), + &res.Tensor("scale"), + &res.Tensor("bias")}, + {&res.Tensor("layernorm_out")}); + } +}; + +class 
AddLayernormXpuFusePass : public pir::PatternRewritePass { + public: + AddLayernormXpuFusePass() + : pir::PatternRewritePass("add_layernorm_xpu_fuse_pass", 2) {} + + pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { + pir::RewritePatternSet ps(context); + ps.Add(paddle::drr::Create(context)); + return ps; + } +}; + +} // namespace + +namespace pir { +std::unique_ptr CreateAddLayernormXpuFusePass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(add_layernorm_xpu_fuse_pass, AddLayernormXpuFusePass); diff --git a/paddle/fluid/pir/transforms/xpu/add_layernorm_fuse_pass.h b/paddle/fluid/pir/transforms/xpu/add_layernorm_fuse_pass.h new file mode 100644 index 0000000000000..b154e7270d700 --- /dev/null +++ b/paddle/fluid/pir/transforms/xpu/add_layernorm_fuse_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/pir/include/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateAddLayernormXpuFusePass(); + +} // namespace pir diff --git a/paddle/fluid/pir/utils/CMakeLists.txt b/paddle/fluid/pir/utils/CMakeLists.txt new file mode 100644 index 0000000000000..943c4306d1160 --- /dev/null +++ b/paddle/fluid/pir/utils/CMakeLists.txt @@ -0,0 +1,4 @@ +cc_library( + pir_general_functions + SRCS general_functions.cc + DEPS op_dialect op_dialect_vjp pir) diff --git a/paddle/fluid/pir/transforms/transform_general_functions.cc b/paddle/fluid/pir/utils/general_functions.cc similarity index 91% rename from paddle/fluid/pir/transforms/transform_general_functions.cc rename to paddle/fluid/pir/utils/general_functions.cc index 2ef3d6d5b81dc..b061b3ae54cff 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.cc +++ b/paddle/fluid/pir/utils/general_functions.cc @@ -12,18 +12,20 @@ // See the License for the specific language governing permissions and // limitations under the License. 
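The new XPU fuse pass only rewrites when the RequireNativeCall constraint above holds, that is, both inputs of the add share the same rank, so the fused add_layernorm_xpu kernel sees compatible operands. Reduced to the shape check alone (toy TensorDesc, not the DRR API), it is just:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

struct TensorDesc { std::vector<std::int64_t> shape; };  // toy stand-in

// Mirrors the RequireNativeCall constraint: fuse only when both add inputs
// have the same rank.
bool CanFuseAddLayerNorm(const TensorDesc& x, const TensorDesc& y) {
  return x.shape.size() == y.shape.size();
}

int main() {
  TensorDesc x{{8, 128, 768}};
  TensorDesc y{{1, 128, 768}};
  TensorDesc bias1d{{768}};
  std::cout << CanFuseAddLayerNorm(x, y) << '\n';       // 1: same rank, pattern fires
  std::cout << CanFuseAddLayerNorm(x, bias1d) << '\n';  // 0: rank mismatch, skip fusion
  return 0;
}
```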
-#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include #include "paddle/common/ddim.h" +#include "paddle/common/enforce.h" +#include "paddle/common/errors.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/include/core/builtin_op.h" #include "paddle/pir/include/core/op_operand.h" -#include "paddle/pir/include/core/parameter.h" +#include "paddle/pir/include/core/operation.h" #include "paddle/pir/include/core/program.h" #include "paddle/pir/include/core/value.h" @@ -61,7 +63,7 @@ void GetUsedExternalValueImpl( namespace pir { -std::string GetParameterNameFromValue(pir::Value value) { +std::string GetParameterNameFromValue(const pir::Value& value) { pir::Operation* owner = value.defining_op(); std::string name; if (owner->isa()) { @@ -78,7 +80,7 @@ std::string GetParameterNameFromValue(pir::Value value) { return name; } -std::vector GetShapeFromValue(pir::Value value) { +std::vector GetShapeFromValue(const pir::Value& value) { if (value.type().isa()) { return phi::vectorize( value.type().dyn_cast().dims()); @@ -91,7 +93,7 @@ std::vector GetShapeFromValue(pir::Value value) { } } -pir::Type GetDataTypeFromValue(pir::Value value) { +pir::Type GetDataTypeFromValue(const pir::Value& value) { // TODO(dev): Support other types like DenseTensor. PADDLE_ENFORCE_EQ( value.type().isa(), @@ -139,13 +141,13 @@ std::vector GetUsedExternalValue(const pir::Block& block) { return used_values; } -bool ValueIsPersitable(pir::Value value) { +bool ValueIsPersistable(const pir::Value& value) { if (!value.defining_op()) { return false; } if (value.defining_op()->num_operands() > 0) { for (const auto& source_value : value.defining_op()->operands_source()) { - if (!ValueIsPersitable(source_value)) { + if (!ValueIsPersistable(source_value)) { return false; } } diff --git a/paddle/fluid/pir/transforms/transform_general_functions.h b/paddle/fluid/pir/utils/general_functions.h similarity index 82% rename from paddle/fluid/pir/transforms/transform_general_functions.h rename to paddle/fluid/pir/utils/general_functions.h index d34c6d6863802..e2c655804def5 100644 --- a/paddle/fluid/pir/transforms/transform_general_functions.h +++ b/paddle/fluid/pir/utils/general_functions.h @@ -14,44 +14,46 @@ #pragma once -#include "paddle/common/errors.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/pir/include/core/operation.h" -#include "paddle/pir/include/core/parameter.h" +#include +#include + #include "paddle/pir/include/core/type.h" -#include "paddle/pir/include/core/value.h" namespace pir { +class Operation; +class Block; +class Value; + /** * @brief Get the name of parameter from a value. * * @note The value must be a output of a ParameterOp or a ConstantTensorOp. * - * @param pir::Value + * @param const pir::Value& * * @return std::string */ -std::string GetParameterNameFromValue(pir::Value value); +std::string GetParameterNameFromValue(const pir::Value& value); /** * @brief Get tensor's shape from a value. * - * @param pir::Value + * @param const pir::Value& * * @return std::vector */ -std::vector GetShapeFromValue(pir::Value value); +std::vector GetShapeFromValue(const pir::Value& value); /** * @brief Get tensor's data type from a value. 
* - * @param pir::Value + * @param const pir::Value& * * @return pir::Type */ -pir::Type GetDataTypeFromValue(pir::Value value); +pir::Type GetDataTypeFromValue(const pir::Value& value); /** * @brief Get an operation that defines the specific input of the operation. @@ -99,10 +101,10 @@ std::vector GetUsedExternalValue(const Block& block); * @brief Determine whether a value comes from a weight or has no input op. That is to say, it is permissible. * - * @param pir::Value + * @param const pir::Value& * @return bool */ -bool ValueIsPersitable(pir::Value value); +bool ValueIsPersistable(const pir::Value& value); } // namespace pir diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc index 4ffcf53b1a574..e3be121820684 100644 --- a/paddle/fluid/platform/collective_helper.cc +++ b/paddle/fluid/platform/collective_helper.cc @@ -133,7 +133,7 @@ void NCCLCommContext::CreateAllNCCLComms(const std::vector& dev_ids, dev_ids.size())); const int kDevices = dev_ids.size(); - ncclComm_t comms[kDevices]; + ncclComm_t comms[kDevices]; // NOLINT PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclCommInitAll( comms, dev_ids.size(), dev_ids.data())); @@ -169,7 +169,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer( VLOG(1) << "Begin CreateNCCLCommMultiTrainer. device number: " << kDevices << ", ntrainers: " << ntrainers << ", train_id: " << train_id << ", rind_id: " << ring_id; - ncclComm_t comms[kDevices]; + ncclComm_t comms[kDevices]; // NOLINT { PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclGroupStart()); for (int i = 0; i < kDevices; i++) { @@ -183,7 +183,7 @@ void NCCLCommContext::CreateNCCLCommMultiTrainer( VLOG(1) << "ncclCommInitRank: " << i; } PADDLE_ENFORCE_GPU_SUCCESS(dynload::ncclGroupEnd()); - VLOG(1) << "nccl group end seccessss"; + VLOG(1) << "nccl group end success"; } PADDLE_ENFORCE_EQ(comm_map_.count(ring_id), 0, @@ -261,7 +261,7 @@ NCCLComm* NCCLCommContext::AssignNCCLComm( platform::CUDAPlace(dev_id))); dev_ctx->set_nccl_comm(comm); } - VLOG(4) << "add mccl comm: " << comm_map_[ring_id][dev_id].get() + VLOG(4) << "add nccl comm: " << comm_map_[ring_id][dev_id].get() << ", ring_id:" << ring_id << ", dev_id:" << dev_id; return comm_map_[ring_id][dev_id].get(); } diff --git a/paddle/fluid/platform/cpu_info_test.cc b/paddle/fluid/platform/cpu_info_test.cc index 604f203ae68db..181e249cd0842 100644 --- a/paddle/fluid/platform/cpu_info_test.cc +++ b/paddle/fluid/platform/cpu_info_test.cc @@ -17,7 +17,7 @@ #include "gtest/gtest.h" #include "paddle/common/flags.h" -#include "paddle/fluid/string/printf.h" +#include "paddle/utils/string/printf.h" COMMON_DECLARE_double(fraction_of_cpu_memory_to_use); diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc index 389276fb24f49..9d522d8b2f0fe 100644 --- a/paddle/fluid/platform/cuda_graph_with_memory_pool.cc +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.cc @@ -25,7 +25,7 @@ COMMON_DECLARE_bool(new_executor_use_cuda_graph); namespace paddle { namespace platform { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void InitCUDNNRelatedHandle(phi::GPUContext* dev_ctx) { dev_ctx->cudnn_workspace_handle().ResetWorkspace(); @@ -69,7 +69,7 @@ phi::DeviceContext* SelectCUDAGraphDeviceContext(phi::GPUPlace place, mutable_dev_ctx = phi::backends::gpu::CUDAGraphContextManager::Instance().Get( *pool_id, place, 0); - } else if (num_stream == 1) { + } else { VLOG(4) << "Use recorded stream to capture cuda graph. 
Used in " "single-stream scenarios with new executor."; mutable_dev_ctx = *(all_capturing_dev_ctxs.begin()); @@ -82,7 +82,7 @@ phi::DeviceContext* SelectCUDAGraphDeviceContext(phi::GPUPlace place, } void BeginCUDAGraphCapture(phi::GPUPlace place, - cudaStreamCaptureMode mode, + gpuStreamCaptureMode mode, int64_t pool_id) { auto* mutable_dev_ctx = SelectCUDAGraphDeviceContext(place, &pool_id); auto* dev_ctx = reinterpret_cast(mutable_dev_ctx); diff --git a/paddle/fluid/platform/cuda_graph_with_memory_pool.h b/paddle/fluid/platform/cuda_graph_with_memory_pool.h index c076d33c88682..a1eca67a9ee87 100644 --- a/paddle/fluid/platform/cuda_graph_with_memory_pool.h +++ b/paddle/fluid/platform/cuda_graph_with_memory_pool.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/common/macros.h" +#include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/enforce.h" @@ -23,17 +24,17 @@ namespace paddle { namespace platform { // NOTE: These APIs are not thread-safe. -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) using CUDAGraph = phi::backends::gpu::CUDAGraph; void BeginCUDAGraphCapture(phi::GPUPlace place, - cudaStreamCaptureMode mode, + gpuStreamCaptureMode mode, int64_t pool_id = CUDAGraph::kInvalidPoolID); std::unique_ptr EndCUDAGraphCapture(); #endif inline phi::GPUPlace CUDAGraphCapturingPlace() { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return CUDAGraph::CapturingPlace(); #else PADDLE_THROW(phi::errors::Unimplemented( @@ -52,8 +53,8 @@ class SkipCUDAGraphCaptureGuard { public: SkipCUDAGraphCaptureGuard() { -#ifdef PADDLE_WITH_CUDA -#if CUDA_VERSION >= 10010 +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_HIP) || CUDA_VERSION >= 10010 if (UNLIKELY(CUDAGraph::IsCapturing())) { CUDAGraph::EndSegmentCapture(); } @@ -62,8 +63,8 @@ class SkipCUDAGraphCaptureGuard { } ~SkipCUDAGraphCaptureGuard() { -#ifdef PADDLE_WITH_CUDA -#if CUDA_VERSION >= 10010 +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#if defined(PADDLE_WITH_HIP) || CUDA_VERSION >= 10010 if (UNLIKELY(CUDAGraph::IsCapturing())) { CUDAGraph::BeginSegmentCapture(); } diff --git a/paddle/fluid/platform/device/gpu/gpu_info.cc b/paddle/fluid/platform/device/gpu/gpu_info.cc index 211f937faa75c..36189cc7e4c90 100644 --- a/paddle/fluid/platform/device/gpu/gpu_info.cc +++ b/paddle/fluid/platform/device/gpu/gpu_info.cc @@ -30,11 +30,12 @@ limitations under the License. */ #include "paddle/fluid/platform/monitor.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler/mem_tracing.h" -#include "paddle/fluid/string/split.h" #include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/utils/string/split.h" #ifdef PADDLE_WITH_HIP #include "paddle/fluid/platform/dynload/miopen.h" +#include "paddle/phi/backends/gpu/rocm/hip_graph.h" #else #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/phi/backends/gpu/cuda/cuda_graph.h" @@ -44,6 +45,8 @@ limitations under the License. */ #if CUDA_VERSION >= 10020 #include "paddle/fluid/platform/dynload/cuda_driver.h" #endif +#else // PADDLE_WITH_HIP +#include "paddle/fluid/platform/dynload/rocm_driver.h" #endif COMMON_DECLARE_double(fraction_of_gpu_memory_to_use); @@ -256,6 +259,8 @@ class RecordedGpuMallocHelper { * would be clear. 
*/ gpuError_t MallocAsync(void **ptr, size_t size, gpuStream_t stream) { +#if defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_CUDA) && (CUDA_VERSION >= 11020) LockGuardPtr lock(mtx_); if (UNLIKELY(NeedRecord() && cur_size_.load() + size > limit_size_)) { return gpuErrorOutOfMemory; @@ -263,19 +268,35 @@ class RecordedGpuMallocHelper { CUDADeviceGuard guard(dev_id_); std::call_once(set_cudamempoolattr_once_flag_, [&]() { +#ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS( cudaDeviceGetDefaultMemPool(&memPool_, dev_id_)); +#else // PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS( + hipDeviceGetDefaultMemPool(&memPool_, dev_id_)); +#endif uint64_t thresholdVal = FLAGS_cuda_memory_async_pool_realease_threshold; VLOG(10) << "[cudaMallocAsync] set cudaMemPoolAttrReleaseThreshold to " << thresholdVal; +#ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_GPU_SUCCESS( cudaMemPoolSetAttribute(memPool_, cudaMemPoolAttrReleaseThreshold, reinterpret_cast(&thresholdVal))); +#else // PADDLE_WITH_HIP + PADDLE_ENFORCE_GPU_SUCCESS( + hipMemPoolSetAttribute(memPool_, + hipMemPoolAttrReleaseThreshold, + reinterpret_cast(&thresholdVal))); +#endif }); gpuError_t result; +#ifdef PADDLE_WITH_CUDA result = cudaMallocAsync(ptr, size, stream); +#else // PADDLE_WITH_HIP + result = hipMallocAsync(ptr, size, stream); +#endif VLOG(10) << "[cudaMallocAsync] ptr = " << (*ptr) << " size = " << static_cast(size) / (1 << 20) << " MB result = " << result << " stream = " << stream; @@ -298,6 +319,10 @@ class RecordedGpuMallocHelper { // return cudaErrorMemoryAllocation directly here. return gpuErrorOutOfMemory; } +#else + PADDLE_THROW(phi::errors::Unavailable( + "MallocAsync is not supported in this version of CUDA.")); +#endif } /** @@ -338,17 +363,23 @@ class RecordedGpuMallocHelper { } void FreeAsync(void *ptr, size_t size, gpuStream_t stream) { +#if defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_CUDA) && (CUDA_VERSION >= 11020) // Purposefully allow cudaErrorCudartUnloading, because // that is returned if you ever call cudaFree after the // driver has already shutdown. This happens only if the // process is terminating, in which case we don't care if // cudaFree succeeds. 
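MallocAsync and FreeAsync above select the CUDA or HIP entry point purely at preprocessing time and fall back to a throwing branch when neither is usable. The dispatch pattern, shrunk to a sketch that still builds and runs without either toolkit (DEMO_WITH_CUDA, DEMO_WITH_HIP and the CPU stub are illustrative, not Paddle's real build flags):

```cpp
#include <cstddef>
#include <cstdio>

#if defined(DEMO_WITH_CUDA)
#include <cuda_runtime.h>
#define gpuMallocAsyncImpl cudaMallocAsync
#elif defined(DEMO_WITH_HIP)
#include <hip/hip_runtime.h>
#define gpuMallocAsyncImpl hipMallocAsync
#else
// CPU-only stub so the sketch still compiles and runs without either toolkit.
static int gpuMallocAsyncImpl(void** ptr, std::size_t size, void* /*stream*/) {
  *ptr = ::operator new(size);  // not freed here; the demo exits immediately
  return 0;
}
#endif

int main() {
  void* p = nullptr;
  const int rc = gpuMallocAsyncImpl(&p, 256, nullptr);  // nullptr = default stream
  std::printf("alloc rc=%d ptr=%p\n", rc, p);
  return 0;
}
```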
CUDADeviceGuard guard(dev_id_); +#ifdef PADDLE_WITH_CUDA auto err = cudaFreeAsync(ptr, stream); +#else // PADDLE_WITH_HIP + auto err = hipFreeAsync(ptr, stream); +#endif VLOG(10) << "[cudaFreeAsync] ptr = " << ptr << " size =" << static_cast(size) / (1 << 20) << " MB result = " << err << " stream = " << stream; - if (err != cudaErrorCudartUnloading) { + if (err != gpuErrorCudartUnloading) { PADDLE_ENFORCE_GPU_SUCCESS(err); cur_size_.fetch_sub(size); DEVICE_MEMORY_STAT_UPDATE(Reserved, dev_id_, -size); @@ -364,8 +395,12 @@ class RecordedGpuMallocHelper { #ifdef PADDLE_WITH_TESTING gpu_ptrs.erase(ptr); #endif - } +#else + PADDLE_THROW(phi::errors::Unavailable( + "FreeAsync is not supported in this version of CUDA.")); +#endif + } void *GetBasePtr(void *ptr) { #ifdef PADDLE_WITH_TESTING auto it = gpu_ptrs.upper_bound(ptr); @@ -439,24 +474,54 @@ class RecordedGpuMallocHelper { } #endif +#else // PADDLE_WITH_HIP + hipError_t MemCreate(hipMemGenericAllocationHandle_t *handle, + size_t size, + const hipMemAllocationProp *prop, + unsigned long long flags) { // NOLINT + auto result = + paddle::platform::dynload::hipMemCreate(handle, size, prop, flags); + if (result == hipSuccess) { + cur_size_.fetch_add(size); + } + return result; + } + + hipError_t MemRelease(hipMemGenericAllocationHandle_t handle, size_t size) { + auto result = paddle::platform::dynload::hipMemRelease(handle); + if (result == hipSuccess) { + cur_size_.fetch_sub(size); + } + return result; + } + #endif private: const int dev_id_; const uint64_t limit_size_; std::atomic cur_size_{0}; + +#if defined(PADDLE_WITH_CUDA) && (CUDA_VERSION >= 11020) cudaMemPool_t memPool_; + static std::once_flag set_cudamempoolattr_once_flag_; +#endif +#if defined(PADDLE_WITH_HIP) + hipMemPool_t memPool_; + static std::once_flag set_cudamempoolattr_once_flag_; +#endif mutable std::unique_ptr mtx_; - static std::once_flag once_flag_; - static std::once_flag set_cudamempoolattr_once_flag_; - std::set gpu_ptrs; // just for testing }; // NOLINT std::once_flag RecordedGpuMallocHelper::once_flag_; + +#if defined(PADDLE_WITH_HIP) || \ + defined(PADDLE_WITH_CUDA) && (CUDA_VERSION >= 11020) std::once_flag RecordedGpuMallocHelper::set_cudamempoolattr_once_flag_; +#endif gpuError_t RecordedGpuMalloc(void **ptr, size_t size, @@ -502,6 +567,21 @@ CUresult RecordedGpuMemRelease(CUmemGenericAllocationHandle handle, return RecordedGpuMallocHelper::Instance(dev_id)->MemRelease(handle, size); } #endif +#else // PADDLE_WITH_HIP +hipError_t RecordedGpuMemCreate(hipMemGenericAllocationHandle_t *handle, + size_t size, + const hipMemAllocationProp *prop, + unsigned long long flags, // NOLINT + int dev_id) { + return RecordedGpuMallocHelper::Instance(dev_id)->MemCreate( + handle, size, prop, flags); +} + +hipError_t RecordedGpuMemRelease(hipMemGenericAllocationHandle_t handle, + size_t size, + int dev_id) { + return RecordedGpuMallocHelper::Instance(dev_id)->MemRelease(handle, size); +} #endif bool RecordedGpuMemGetInfo(size_t *avail, @@ -577,7 +657,7 @@ int GetGPUMaxThreadsPerBlock(int id) { int GetCurrentDeviceId() { return phi::backends::gpu::GetCurrentDeviceId(); } -std::array GetGpuMaxGridDimSize(int id) { +std::array GetGpuMaxGridDimSize(int id) { return phi::backends::gpu::GetGpuMaxGridDimSize(id); } diff --git a/paddle/fluid/platform/device/gpu/gpu_info.h b/paddle/fluid/platform/device/gpu/gpu_info.h index 2714cdd1e521f..c6582667f507f 100644 --- a/paddle/fluid/platform/device/gpu/gpu_info.h +++ b/paddle/fluid/platform/device/gpu/gpu_info.h @@ -56,7 +56,7 @@ int 
GetGPUMaxThreadsPerBlock(int id); TEST_API int GetCurrentDeviceId(); //! Get the maximum GridDim size for GPU buddy allocator. -std::array GetGpuMaxGridDimSize(int); +std::array GetGpuMaxGridDimSize(int); //! Get a list of device ids from environment variable or use all. std::vector GetSelectedDevices(); diff --git a/paddle/fluid/platform/device/gpu/gpu_types.h b/paddle/fluid/platform/device/gpu/gpu_types.h index c9afafdef7166..8a192ba919cad 100644 --- a/paddle/fluid/platform/device/gpu/gpu_types.h +++ b/paddle/fluid/platform/device/gpu/gpu_types.h @@ -1,5 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// Copyright (c) 2022 NVIDIA Corporation. All rights reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -33,11 +32,13 @@ namespace paddle { +// Note(qili93): CUDA Runtime API supported by HIP +// https://github.com/ROCm/HIPIFY/blob/master/doc/markdown/CUDA_Runtime_API_functions_supported_by_HIP.md + #ifdef PADDLE_WITH_HIP #define DECLARE_TYPE_FOR_GPU(GPU_TYPE, CUDA_TYPE, ROCM_TYPE) \ using GPU_TYPE = ROCM_TYPE; -#else // CDUA - +#else // PADDLE_WITH_CUDA #define DECLARE_TYPE_FOR_GPU(GPU_TYPE, CUDA_TYPE, ROCM_TYPE) \ using GPU_TYPE = CUDA_TYPE; #endif @@ -81,22 +82,22 @@ DECLARE_TYPE_FOR_GPU(dnnDropoutDescriptor_t, cudnnDropoutDescriptor_t, miopenDropoutDescriptor_t); DECLARE_TYPE_FOR_GPU(dnnHandle_t, cudnnHandle_t, miopenHandle_t); - +DECLARE_TYPE_FOR_GPU(gpuIpcMemHandle_t, cudaIpcMemHandle_t, hipIpcMemHandle_t); DECLARE_TYPE_FOR_GPU(blasHandle_t, cublasHandle_t, rocblas_handle); +DECLARE_TYPE_FOR_GPU(gpuStreamCaptureMode, + cudaStreamCaptureMode, + hipStreamCaptureMode); // TODO(Ming Huang): Since there is no blasLt handler, // use rocblas_handle for workround. 
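gpu_types.h builds its gpu* vocabulary with one alias macro per type and per constant, expanded against either the CUDA or the ROCm spelling. Stripped to its essence (the Demo* structs below are placeholders, not real runtime types):

```cpp
#include <iostream>
#include <type_traits>

struct DemoCudaStream {};  // placeholders, not the real driver types
struct DemoHipStream {};

#if defined(DEMO_WITH_HIP)
#define DECLARE_TYPE_FOR_GPU(GPU_TYPE, CUDA_TYPE, ROCM_TYPE) using GPU_TYPE = ROCM_TYPE;
#else  // CUDA-flavoured build is the default for this sketch
#define DECLARE_TYPE_FOR_GPU(GPU_TYPE, CUDA_TYPE, ROCM_TYPE) using GPU_TYPE = CUDA_TYPE;
#endif

DECLARE_TYPE_FOR_GPU(gpuStream_t, DemoCudaStream, DemoHipStream)
#undef DECLARE_TYPE_FOR_GPU

int main() {
  std::cout << std::boolalpha
            << std::is_same<gpuStream_t, DemoCudaStream>::value << '\n';  // true
  return 0;
}
```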
DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle); -using CUDAGraphID = unsigned long long; // NOLINT - #undef DECLARE_TYPE_FOR_GPU #ifdef PADDLE_WITH_HIP #define DECLARE_CONSTANT_FOR_GPU(GPU_CV, CUDA_CV, ROCM_CV) \ constexpr auto GPU_CV = ROCM_CV; -#else // CDUA - +#else // PADDLE_WITH_CUDA #define DECLARE_CONSTANT_FOR_GPU(GPU_CV, CUDA_CV, ROCM_CV) \ constexpr auto GPU_CV = CUDA_CV; #endif @@ -106,8 +107,64 @@ DECLARE_CONSTANT_FOR_GPU(gpuErrorOutOfMemory, hipErrorOutOfMemory); DECLARE_CONSTANT_FOR_GPU(gpuErrorNotReady, cudaErrorNotReady, hipErrorNotReady); DECLARE_CONSTANT_FOR_GPU(gpuSuccess, cudaSuccess, hipSuccess); +DECLARE_CONSTANT_FOR_GPU(gpuErrorCudartUnloading, + cudaErrorCudartUnloading, + hipErrorDeinitialized); +DECLARE_CONSTANT_FOR_GPU(gpuEventDisableTiming, + cudaEventDisableTiming, + hipEventDisableTiming); +DECLARE_CONSTANT_FOR_GPU(gpuStreamNonBlocking, + cudaStreamNonBlocking, + hipStreamNonBlocking); +DECLARE_CONSTANT_FOR_GPU(gpuIpcMemLazyEnablePeerAccess, + cudaIpcMemLazyEnablePeerAccess, + hipIpcMemLazyEnablePeerAccess); #undef DECLARE_CONSTANT_FOR_GPU -} // namespace paddle +#ifdef PADDLE_WITH_HIP +#define DECLARE_FUNCTION_FOR_GPU(GPU_FUNC, CUDA_FUNC, ROCM_FUNC) \ + const auto GPU_FUNC = ROCM_FUNC; +#else // PADDLE_WITH_CUDA +#define DECLARE_FUNCTION_FOR_GPU(GPU_FUNC, CUDA_FUNC, ROCM_FUNC) \ + const auto GPU_FUNC = CUDA_FUNC; #endif + +DECLARE_FUNCTION_FOR_GPU(gpuStreamCreateWithPriority, + cudaStreamCreateWithPriority, + hipStreamCreateWithPriority); +DECLARE_FUNCTION_FOR_GPU(gpuStreamBeginCapture, + cudaStreamBeginCapture, + hipStreamBeginCapture); +DECLARE_FUNCTION_FOR_GPU(gpuStreamEndCapture, + cudaStreamEndCapture, + hipStreamEndCapture); +DECLARE_FUNCTION_FOR_GPU(gpuStreamGetCaptureInfo, + cudaStreamGetCaptureInfo, + hipStreamGetCaptureInfo); +DECLARE_FUNCTION_FOR_GPU(gpuEventCreateWithFlags, + cudaEventCreateWithFlags, + hipEventCreateWithFlags); +DECLARE_FUNCTION_FOR_GPU(gpuEventRecord, cudaEventRecord, hipEventRecord); +DECLARE_FUNCTION_FOR_GPU(gpuEventDestroy, cudaEventDestroy, hipEventDestroy); +DECLARE_FUNCTION_FOR_GPU(gpuEventQuery, cudaEventQuery, hipEventQuery); +DECLARE_FUNCTION_FOR_GPU(gpuEventSynchronize, + cudaEventSynchronize, + hipEventSynchronize); +DECLARE_FUNCTION_FOR_GPU(gpuStreamSynchronize, + cudaStreamSynchronize, + hipStreamSynchronize); +DECLARE_FUNCTION_FOR_GPU(gpuIpcOpenMemHandle, + cudaIpcOpenMemHandle, + hipIpcOpenMemHandle); +DECLARE_FUNCTION_FOR_GPU(gpuIpcCloseMemHandle, + cudaIpcCloseMemHandle, + hipIpcCloseMemHandle); + +#undef DECLARE_FUNCTION_FOR_GPU + +using CUDAGraphID = unsigned long long; // NOLINT + +} // namespace paddle + +#endif // defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/paddle/fluid/platform/device/gpu/nccl_helper.h b/paddle/fluid/platform/device/gpu/nccl_helper.h index 8afcfc9f2b700..83026ade670f2 100644 --- a/paddle/fluid/platform/device/gpu/nccl_helper.h +++ b/paddle/fluid/platform/device/gpu/nccl_helper.h @@ -155,7 +155,8 @@ struct NCCLContext { int device_id() const { return ctx_->GetPlace().device; } }; -struct NCCLContextMap { +class NCCLContextMap { + public: std::unordered_map contexts_; std::vector order_; diff --git a/paddle/fluid/platform/device/xpu/xpu_info.cc b/paddle/fluid/platform/device/xpu/xpu_info.cc index 9be4031fed82a..cc7388df4c22f 100644 --- a/paddle/fluid/platform/device/xpu/xpu_info.cc +++ b/paddle/fluid/platform/device/xpu/xpu_info.cc @@ -171,6 +171,9 @@ class RecordedXPUMallocHelper { */ void Free(void* ptr, size_t size) { XPUDeviceGuard 
guard(dev_id_); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto* dev_ctx = pool.GetByPlace(XPUPlace(dev_id_)); + dev_ctx->Wait(); xpu_free(ptr); cur_size_.fetch_sub(size); } diff --git a/paddle/fluid/platform/device_event_base.cc b/paddle/fluid/platform/device_event_base.cc index cd2d31f1fbefb..6079691fe873c 100644 --- a/paddle/fluid/platform/device_event_base.cc +++ b/paddle/fluid/platform/device_event_base.cc @@ -66,9 +66,9 @@ void DeviceEventRecordCPU(DeviceEvent* event, const DeviceContext* context) { auto* wrapper = static_cast(event->GetEvent().get()); std::unique_lock lock(wrapper->mutex_); - // NOTE: As for CudaEvent_t, it can be used to Record() repeatly. CudaEvent_t - // internally reset its status from finished into initialized. - // So we simulate the process here. + // NOTE: As for CudaEvent_t, it can be used to Record() repeatedly. + // CudaEvent_t internally reset its status from finished into initialized. So + // we simulate the process here. if (wrapper->status_.load() == EventStatus::SUCCESS) { VLOG(3) << "Found EventStatus is SUCCESS before RecordCPU. Reset it into " "INITIALIZED."; diff --git a/paddle/fluid/platform/device_event_cpu.h b/paddle/fluid/platform/device_event_cpu.h index 9490d5f3ceec8..e6faeb5fd01a4 100644 --- a/paddle/fluid/platform/device_event_cpu.h +++ b/paddle/fluid/platform/device_event_cpu.h @@ -30,7 +30,7 @@ struct CPUDeviceEventWrapper { platform::is_cpu_place(place), true, platform::errors::PreconditionNotMet( - "Required device shall be CPUAPlace, but received %d. ", place)); + "Required device shall be CPUPlace, but received %d. ", place)); } std::mutex mutex_; std::condition_variable cv_completed_; diff --git a/paddle/fluid/platform/device_event_test.cc b/paddle/fluid/platform/device_event_test.cc index b2e3d3242d219..4eb0da7740f3a 100644 --- a/paddle/fluid/platform/device_event_test.cc +++ b/paddle/fluid/platform/device_event_test.cc @@ -63,7 +63,7 @@ TEST(DeviceEvent, CUDA) { status = event.Query(); ASSERT_EQ(status, false); // async - event.Wait(kCPU, context); // step 3. EventSynchornize + event.Wait(kCPU, context); // step 3. EventSynchronize status = event.Query(); ASSERT_EQ(status, true); // sync @@ -114,7 +114,7 @@ TEST(DeviceEvent, CUDA) { status = event.Query(); ASSERT_EQ(status, false); // async - event.Wait(kCPU, context); // step 3. EventSynchornize + event.Wait(kCPU, context); // step 3. 
EventSynchronize status = event.Query(); ASSERT_EQ(status, true); // sync diff --git a/paddle/fluid/platform/dynload/cudnn.cc b/paddle/fluid/platform/dynload/cudnn.cc index 05cacb74c8673..aa8fd62aa85cc 100644 --- a/paddle/fluid/platform/dynload/cudnn.cc +++ b/paddle/fluid/platform/dynload/cudnn.cc @@ -44,6 +44,18 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP); CUDNN_DNN_ROUTINE_EACH_R8(DEFINE_WRAP); #endif +#ifdef CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9 +CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(DEFINE_WRAP); +#endif + +#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9 +CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(DEFINE_WRAP); +#endif + +#ifdef CUDNN_DNN_ROUTINE_EACH_R9 +CUDNN_DNN_ROUTINE_EACH_R9(DEFINE_WRAP); +#endif + bool HasCUDNN() { return phi::dynload::HasCUDNN(); } } // namespace dynload diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 9af1e8065c49d..bf957554a3d75 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -90,13 +90,6 @@ extern bool HasCUDNN(); __macro(cudnnSetDropoutDescriptor); \ __macro(cudnnRestoreDropoutDescriptor); \ __macro(cudnnCreateRNNDescriptor); \ - __macro(cudnnGetRNNParamsSize); \ - __macro(cudnnGetRNNWorkspaceSize); \ - __macro(cudnnGetRNNTrainingReserveSize); \ - __macro(cudnnRNNForwardTraining); \ - __macro(cudnnRNNBackwardData); \ - __macro(cudnnRNNBackwardWeights); \ - __macro(cudnnRNNForwardInference); \ __macro(cudnnDestroyDropoutDescriptor); \ __macro(cudnnDestroyRNNDescriptor); \ __macro(cudnnSetTensorNdDescriptorEx); \ @@ -111,8 +104,7 @@ extern bool HasCUDNN(); __macro(cudnnCreateActivationDescriptor); \ __macro(cudnnSetActivationDescriptor); \ __macro(cudnnGetActivationDescriptor); \ - __macro(cudnnDestroyActivationDescriptor); \ - __macro(cudnnSetRNNDescriptor_v6); + __macro(cudnnDestroyActivationDescriptor); CUDNN_DNN_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000 @@ -147,12 +139,7 @@ CUDNN_DNN_ROUTINE_EACH_R7(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(__macro) \ __macro(cudnnCreateRNNDataDescriptor); \ __macro(cudnnDestroyRNNDataDescriptor); \ - __macro(cudnnSetRNNDataDescriptor); \ - __macro(cudnnSetRNNPaddingMode); \ - __macro(cudnnRNNForwardTrainingEx); \ - __macro(cudnnRNNBackwardDataEx); \ - __macro(cudnnRNNBackwardWeightsEx); \ - __macro(cudnnRNNForwardInferenceEx); + __macro(cudnnSetRNNDataDescriptor); CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif @@ -182,6 +169,39 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R8(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif +#if CUDNN_VERSION < 90000 +#define CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(__macro) \ + __macro(cudnnGetRNNParamsSize); \ + __macro(cudnnGetRNNWorkspaceSize); \ + __macro(cudnnGetRNNTrainingReserveSize); \ + __macro(cudnnSetRNNDescriptor_v6); \ + __macro(cudnnRNNForwardInference); \ + __macro(cudnnRNNForwardTraining); \ + __macro(cudnnRNNBackwardData); \ + __macro(cudnnRNNBackwardWeights); +CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) +#endif + +#if CUDNN_VERSION < 90000 && CUDNN_VERSION >= 7201 +#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(__macro) \ + __macro(cudnnSetRNNPaddingMode); \ + __macro(cudnnRNNForwardInferenceEx); \ + __macro(cudnnRNNForwardTrainingEx); \ + __macro(cudnnRNNBackwardDataEx); \ + __macro(cudnnRNNBackwardWeightsEx); 
+CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9( + PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) +#endif + +#if CUDNN_VERSION >= 90000 +#define CUDNN_DNN_ROUTINE_EACH_R9(__macro) \ + __macro(cudnnGetRNNWeightSpaceSize); \ + __macro(cudnnGetRNNTempSpaceSizes); \ + __macro(cudnnRNNForward); \ + __macro(cudnnRNNBackwardData_v8); \ + __macro(cudnnRNNBackwardWeights_v8); +CUDNN_DNN_ROUTINE_EACH_R9(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) +#endif } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/mklrt.h b/paddle/fluid/platform/dynload/mklrt.h index 0ee5b33b85d73..31cde5716f6e3 100644 --- a/paddle/fluid/platform/dynload/mklrt.h +++ b/paddle/fluid/platform/dynload/mklrt.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/phi/backends/dynload/dynamic_loader.h" #include "paddle/phi/backends/dynload/mklrt.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace paddle { namespace platform { diff --git a/paddle/fluid/platform/dynload/nccl.h b/paddle/fluid/platform/dynload/nccl.h index d9516c9f4de4e..2dba64af33206 100644 --- a/paddle/fluid/platform/dynload/nccl.h +++ b/paddle/fluid/platform/dynload/nccl.h @@ -31,6 +31,7 @@ namespace dynload { __macro(ncclCommInitAll); \ __macro(ncclGetUniqueId); \ __macro(ncclCommInitRank); \ + __macro(ncclCommInitRank2); \ __macro(ncclCommAbort); \ __macro(ncclCommDestroy); \ __macro(ncclCommCount); \ diff --git a/paddle/fluid/platform/dynload/rocm_driver.h b/paddle/fluid/platform/dynload/rocm_driver.h index 5c8e18611c40a..5295ffb07c1d1 100644 --- a/paddle/fluid/platform/dynload/rocm_driver.h +++ b/paddle/fluid/platform/dynload/rocm_driver.h @@ -39,13 +39,33 @@ extern bool HasCUDADriver(); __macro(hipModuleLoadData); \ __macro(hipModuleGetFunction); \ __macro(hipModuleUnload); \ - /*rocm3.5 not support the function*/ \ + /* DTK not support the function*/ \ /* __macro(hipOccupancyMaxActiveBlocksPerMultiprocessor);*/ \ __macro(hipModuleLaunchKernel); \ __macro(hipLaunchKernel); \ __macro(hipGetDevice); \ __macro(hipGetDeviceCount); \ - __macro(hipDevicePrimaryCtxGetState) + __macro(hipDevicePrimaryCtxGetState); \ + __macro(hipDeviceGetAttribute); \ + __macro(hipDeviceGet) + +#define ROCM_ROUTINE_EACH_VVM(__macro) \ + __macro(hipMemGetAllocationGranularity); \ + __macro(hipMemAddressReserve); \ + __macro(hipMemCreate); \ + __macro(hipMemMap); \ + __macro(hipMemSetAccess); \ + __macro(hipMemUnmap); \ + __macro(hipMemRelease); \ + __macro(hipMemAddressFree) + +#define ROCM_ROUTINE_EACH_GPU_GRAPH(__macro) \ + __macro(hipGraphNodeGetType); \ + __macro(hipGraphKernelNodeGetParams); \ + __macro(hipGraphExecKernelNodeSetParams) + +ROCM_ROUTINE_EACH_VVM(PLATFORM_DECLARE_DYNAMIC_LOAD_ROCM_WRAP); +ROCM_ROUTINE_EACH_GPU_GRAPH(PLATFORM_DECLARE_DYNAMIC_LOAD_ROCM_WRAP); ROCM_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_ROCM_WRAP); diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index dec1d971df004..03467d175c78f 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -65,9 +65,9 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/macros.h" -#include "paddle/fluid/string/printf.h" -#include "paddle/fluid/string/to_string.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" +#include "paddle/utils/string/printf.h" +#include "paddle/utils/string/to_string.h" #ifdef PADDLE_WITH_CUDA #include "paddle/phi/backends/dynload/cublas.h" diff --git a/paddle/fluid/platform/enforce_test.cc b/paddle/fluid/platform/enforce_test.cc index 9bad3f0bf1c41..e6838746fd6ac 100644 --- a/paddle/fluid/platform/enforce_test.cc +++ b/paddle/fluid/platform/enforce_test.cc @@ -594,7 +594,7 @@ TEST(enforce, cannot_to_string_type) { } TEST(GET_DATA_SAFELY_MACRO, SUCCESS) { - int* a = new int(10); + int* a = new int(10); // NOLINT GET_DATA_SAFELY(a, "Input", "X", "dummy"); } diff --git a/paddle/fluid/platform/float16_test.cu b/paddle/fluid/platform/float16_test.cu index 4575b54d48c9b..555f83d61675e 100644 --- a/paddle/fluid/platform/float16_test.cu +++ b/paddle/fluid/platform/float16_test.cu @@ -282,7 +282,7 @@ TEST(float16, compound_on_gpu) { TestDivAssign(6, 2, 3); } -TEST(float16, comparision_on_gpu) { +TEST(float16, comparison_on_gpu) { TestEqual(1, 1, true); TestEqual(1, 2, false); TestNotEqual(2, 3, true); diff --git a/paddle/fluid/platform/gen_comm_id_helper.cc b/paddle/fluid/platform/gen_comm_id_helper.cc index 40d80f8ef2cbc..7d16fc368d166 100644 --- a/paddle/fluid/platform/gen_comm_id_helper.cc +++ b/paddle/fluid/platform/gen_comm_id_helper.cc @@ -30,7 +30,7 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/common/flags.h" #include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/string/split.h" +#include "paddle/utils/string/split.h" #if defined(PADDLE_WITH_XPU_BKCL) #include "xpu/bkcl.h" #endif @@ -82,7 +82,7 @@ static int SocketSend(int fd, const char* buffer, int size) { int offset = 0; int bytes = 0; while (offset < size) { - bytes = send(fd, buffer + offset, size - offset, 0); + bytes = send(fd, buffer + offset, size - offset, 0); // NOLINT if (bytes == -1) { if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { // send failed @@ -100,7 +100,7 @@ static int SocketRecv(int fd, char* buffer, int size) { int offset = 0; int bytes = 0; while (offset < size) { - bytes = recv(fd, buffer + offset, size - offset, 0); + bytes = recv(fd, buffer + offset, size - offset, 0); // NOLINT if (bytes == 0) { // closed by client, maybe probing alive client return 0; diff --git a/paddle/fluid/platform/init.cc b/paddle/fluid/platform/init.cc index 5d0f5c3aa8d01..1fffa07a99974 100644 --- a/paddle/fluid/platform/init.cc +++ b/paddle/fluid/platform/init.cc @@ -16,8 +16,8 @@ limitations under the License. 
*/ #include #include "paddle/fluid/platform/cpu_helper.h" -#include "paddle/fluid/string/split.h" #include "paddle/phi/backends/cpu/cpu_info.h" +#include "paddle/utils/string/split.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" diff --git a/paddle/fluid/platform/place.cc b/paddle/fluid/platform/place.cc index 118ba7d6b782c..df66cc63e3986 100644 --- a/paddle/fluid/platform/place.cc +++ b/paddle/fluid/platform/place.cc @@ -62,8 +62,6 @@ bool is_same_place(const Place &p1, const Place &p2) { if (places_are_same_class(p1, p2)) { if (is_cpu_place(p1) || is_cuda_pinned_place(p1)) { return true; - } else if (is_xpu_place(p1) || is_ipu_place(p1) || is_custom_place(p1)) { - return p1 == p2; } else { return p1 == p2; } diff --git a/paddle/fluid/platform/profiler.cc b/paddle/fluid/platform/profiler.cc index 816ae57ff4c06..b0f8f329dde4f 100644 --- a/paddle/fluid/platform/profiler.cc +++ b/paddle/fluid/platform/profiler.cc @@ -56,7 +56,7 @@ std::mutex phi::ProfilerHelper::g_all_mem_event_lists_mutex; namespace paddle { namespace platform { -MemEvenRecorder MemEvenRecorder::recorder; +MemEventRecorder MemEventRecorder::recorder; RecordInstantEvent::RecordInstantEvent(const char *name, TracerEventType type, @@ -200,8 +200,8 @@ RecordMemEvent::RecordMemEvent(const void *ptr, RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][3]; RecordMemEvent::has_initialized["gpu"][place.GetDeviceId()] = true; } else { - current_allocated = - DEVICE_MEMORY_STAT_CURRENT_VALUE(Allocated, place.GetDeviceId()); + current_allocated = DEVICE_MEMORY_STAT_CURRENT_VALUE( + Allocated, place.GetDeviceId()); // NOLINT peak_allocated = DEVICE_MEMORY_STAT_PEAK_VALUE(Allocated, place.GetDeviceId()); RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][0] = @@ -214,14 +214,14 @@ RecordMemEvent::RecordMemEvent(const void *ptr, RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][3]; } } - platform::MemEvenRecorder::Instance().PushMemRecord(ptr, - place, - size, - type, - current_allocated, - current_reserved, - peak_allocated, - peak_reserved); + platform::MemEventRecorder::Instance().PushMemRecord(ptr, + place, + size, + type, + current_allocated, + current_reserved, + peak_allocated, + peak_reserved); } else if (type == TracerMemEventType::ReservedAllocate) { uint64_t current_reserved = 0; uint64_t peak_reserved = 0; @@ -283,10 +283,10 @@ RecordMemEvent::RecordMemEvent(const void *ptr, RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][3]; RecordMemEvent::has_initialized["gpu"][place.GetDeviceId()] = true; } else { - current_reserved = - DEVICE_MEMORY_STAT_CURRENT_VALUE(Reserved, place.GetDeviceId()); - peak_reserved = - DEVICE_MEMORY_STAT_PEAK_VALUE(Reserved, place.GetDeviceId()); + current_reserved = DEVICE_MEMORY_STAT_CURRENT_VALUE( + Reserved, place.GetDeviceId()); // NOLINT + peak_reserved = DEVICE_MEMORY_STAT_PEAK_VALUE( + Reserved, place.GetDeviceId()); // NOLINT RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][1] = current_reserved; RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][3] = @@ -297,14 +297,14 @@ RecordMemEvent::RecordMemEvent(const void *ptr, RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][2]; } } - platform::MemEvenRecorder::Instance().PushMemRecord(ptr, - place, - size, - type, - current_allocated, - current_reserved, - peak_allocated, - peak_reserved); + platform::MemEventRecorder::Instance().PushMemRecord(ptr, + place, + size, + type, + current_allocated, + 
current_reserved, + peak_allocated, + peak_reserved); } else if (type == TracerMemEventType::Free) { uint64_t current_allocated = 0; uint64_t peak_allocated = 0; @@ -366,10 +366,10 @@ RecordMemEvent::RecordMemEvent(const void *ptr, RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][3]; RecordMemEvent::has_initialized["gpu"][place.GetDeviceId()] = true; } else { - current_allocated = - DEVICE_MEMORY_STAT_CURRENT_VALUE(Allocated, place.GetDeviceId()); - peak_allocated = - DEVICE_MEMORY_STAT_PEAK_VALUE(Allocated, place.GetDeviceId()); + current_allocated = DEVICE_MEMORY_STAT_CURRENT_VALUE( + Allocated, place.GetDeviceId()); // NOLINT + peak_allocated = DEVICE_MEMORY_STAT_PEAK_VALUE( + Allocated, place.GetDeviceId()); // NOLINT RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][0] = current_allocated; RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][2] = @@ -380,14 +380,14 @@ RecordMemEvent::RecordMemEvent(const void *ptr, RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][3]; } } - platform::MemEvenRecorder::Instance().PopMemRecord(ptr, - place, - size, - type, - current_allocated, - current_reserved, - peak_allocated, - peak_reserved); + platform::MemEventRecorder::Instance().PopMemRecord(ptr, + place, + size, + type, + current_allocated, + current_reserved, + peak_allocated, + peak_reserved); } else if (type == TracerMemEventType::ReservedFree) { uint64_t current_reserved = 0; uint64_t peak_reserved = 0; @@ -449,10 +449,10 @@ RecordMemEvent::RecordMemEvent(const void *ptr, RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][3]; RecordMemEvent::has_initialized["gpu"][place.GetDeviceId()] = true; } else { - current_reserved = - DEVICE_MEMORY_STAT_CURRENT_VALUE(Reserved, place.GetDeviceId()); - peak_reserved = - DEVICE_MEMORY_STAT_PEAK_VALUE(Reserved, place.GetDeviceId()); + current_reserved = DEVICE_MEMORY_STAT_CURRENT_VALUE( + Reserved, place.GetDeviceId()); // NOLINT + peak_reserved = DEVICE_MEMORY_STAT_PEAK_VALUE( + Reserved, place.GetDeviceId()); // NOLINT RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][1] = current_reserved; RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][3] = @@ -463,20 +463,20 @@ RecordMemEvent::RecordMemEvent(const void *ptr, RecordMemEvent::size_cache["gpu"][place.GetDeviceId()][2]; } } - platform::MemEvenRecorder::Instance().PopMemRecord(ptr, - place, - size, - type, - current_allocated, - current_reserved, - peak_allocated, - peak_reserved); + platform::MemEventRecorder::Instance().PopMemRecord(ptr, + place, + size, + type, + current_allocated, + current_reserved, + peak_allocated, + peak_reserved); } } -void MemEvenRecorder::PushMemRecord(const void *ptr, - const Place &place, - size_t size) { +void MemEventRecorder::PushMemRecord(const void *ptr, + const Place &place, + size_t size) { if (phi::ProfilerHelper::g_state == ProfilerState::kDisabled) { return; } @@ -487,17 +487,17 @@ void MemEvenRecorder::PushMemRecord(const void *ptr, platform::errors::InvalidArgument( "The Place can't exist in the stage of PushMemRecord")); events.emplace( - ptr, std::make_unique(place, size)); + ptr, std::make_unique(place, size)); } -void MemEvenRecorder::PushMemRecord(const void *ptr, - const Place &place, - size_t size, - TracerMemEventType type, - uint64_t current_allocated, - uint64_t current_reserved, - uint64_t peak_allocated, - uint64_t peak_reserved) { +void MemEventRecorder::PushMemRecord(const void *ptr, + const Place &place, + size_t size, + TracerMemEventType type, + uint64_t current_allocated, + uint64_t current_reserved, + uint64_t 
peak_allocated, + uint64_t peak_reserved) { std::lock_guard guard(mtx_); if (FLAGS_enable_host_event_recorder_hook) { // new MemRecord HostEventRecorder::GetInstance().RecordEvent( @@ -523,10 +523,10 @@ void MemEvenRecorder::PushMemRecord(const void *ptr, platform::errors::InvalidArgument( "The Place can't exist in the stage of PushMemRecord")); events.emplace( - ptr, std::make_unique(place, size)); + ptr, std::make_unique(place, size)); } -void MemEvenRecorder::PopMemRecord(const void *ptr, const Place &place) { +void MemEventRecorder::PopMemRecord(const void *ptr, const Place &place) { if (phi::ProfilerHelper::g_state == ProfilerState::kDisabled) { return; } @@ -539,14 +539,14 @@ void MemEvenRecorder::PopMemRecord(const void *ptr, const Place &place) { } } -void MemEvenRecorder::PopMemRecord(const void *ptr, - const Place &place, - size_t size, - TracerMemEventType type, - uint64_t current_allocated, - uint64_t current_reserved, - uint64_t peak_allocated, - uint64_t peak_reserved) { +void MemEventRecorder::PopMemRecord(const void *ptr, + const Place &place, + size_t size, + TracerMemEventType type, + uint64_t current_allocated, + uint64_t current_reserved, + uint64_t peak_allocated, + uint64_t peak_reserved) { std::lock_guard guard(mtx_); if (FLAGS_enable_host_event_recorder_hook) { // new MemRecord HostEventRecorder::GetInstance().RecordEvent( @@ -574,13 +574,13 @@ void MemEvenRecorder::PopMemRecord(const void *ptr, } } -void MemEvenRecorder::Flush() { +void MemEventRecorder::Flush() { std::lock_guard guard(mtx_); address_memevent_.clear(); } -MemEvenRecorder::RecordMemEvent::RecordMemEvent(const Place &place, - size_t bytes) +MemEventRecorder::RecordMemEvent::RecordMemEvent(const Place &place, + size_t bytes) : place_(place), bytes_(bytes), start_ns_(PosixInNsec()), @@ -588,7 +588,7 @@ MemEvenRecorder::RecordMemEvent::RecordMemEvent(const Place &place, PushMemEvent(start_ns_, end_ns_, bytes_, place_, alloc_in_); } -MemEvenRecorder::RecordMemEvent::~RecordMemEvent() { // NOLINT +MemEventRecorder::RecordMemEvent::~RecordMemEvent() { // NOLINT phi::DeviceTracer *tracer = phi::GetDeviceTracer(); end_ns_ = PosixInNsec(); @@ -701,7 +701,7 @@ void EnableProfiler(ProfilerState state) { void ResetProfiler() { SynchronizeAllDevice(); phi::GetDeviceTracer()->Reset(); - MemEvenRecorder::Instance().Flush(); + MemEventRecorder::Instance().Flush(); std::lock_guard guard( phi::ProfilerHelper::g_all_event_lists_mutex); for (auto &all_event_list : phi::ProfilerHelper::g_all_event_lists) { @@ -720,7 +720,7 @@ void DisableProfiler(EventSortingKey sorted_key, const std::string &profile_path) { SynchronizeAllDevice(); auto thr_events = DockHostEventRecorderHostPart(); - MemEvenRecorder::Instance().Flush(); + MemEventRecorder::Instance().Flush(); std::lock_guard l(profiler_mu); if (phi::ProfilerHelper::g_state == ProfilerState::kDisabled) return; @@ -755,7 +755,7 @@ void CompleteProfilerEvents(phi::proto::Profile *tracer_profile, std::vector> *mem_events) { SynchronizeAllDevice(); auto thr_events = DockHostEventRecorderHostPart(); - MemEvenRecorder::Instance().Flush(); + MemEventRecorder::Instance().Flush(); std::lock_guard l(profiler_mu); if (phi::ProfilerHelper::g_state == ProfilerState::kDisabled) return; // Mark the profiling stop. 
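The profiler.cc hunks above are essentially a rename of `MemEvenRecorder` to `MemEventRecorder`, with the call sites in `RecordMemEvent`, `ResetProfiler`, and `DisableProfiler` reflowed to the new name. The recorder keeps its existing shape: a process-wide singleton whose push/pop calls bracket an allocation and are serialized by a mutex. The sketch below is a simplified illustration of that pattern, not the real class: it uses a function-local static instead of the static member declared in profiler.h, and the per-allocation record is reduced to a byte count.

```cpp
// Simplified sketch of the singleton recorder pattern used above.
// Not Paddle's MemEventRecorder: the instance is a function-local static
// here, and the Record payload is reduced to a byte count.
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>

class MemEventRecorderSketch {
 public:
  static MemEventRecorderSketch& Instance() {
    static MemEventRecorderSketch recorder;
    return recorder;
  }

  void PushMemRecord(const void* ptr, std::size_t size) {
    std::lock_guard<std::mutex> guard(mtx_);
    events_.emplace(ptr, std::make_unique<Record>(Record{size}));
  }

  void PopMemRecord(const void* ptr) {
    std::lock_guard<std::mutex> guard(mtx_);
    events_.erase(ptr);  // in the real code, the record's destructor emits the event
  }

  void Flush() {
    std::lock_guard<std::mutex> guard(mtx_);
    events_.clear();
  }

 private:
  struct Record {
    std::size_t bytes;
  };
  MemEventRecorderSketch() = default;

  std::map<const void*, std::unique_ptr<Record>> events_;
  std::mutex mtx_;
};
```

Call sites then read `MemEventRecorderSketch::Instance().PushMemRecord(ptr, size)` on allocation and `Instance().PopMemRecord(ptr)` on free, mirroring the `Instance().PushMemRecord(...)` / `Instance().PopMemRecord(...)` calls in the hunks above.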
diff --git a/paddle/fluid/platform/profiler.h b/paddle/fluid/platform/profiler.h index 4d6bc9cc242d4..27c2bc8f77f7d 100644 --- a/paddle/fluid/platform/profiler.h +++ b/paddle/fluid/platform/profiler.h @@ -69,7 +69,7 @@ enum class EventSortingKey { kGPUTime }; -struct MemoryProfierReport { +struct MemoryProfilerReport { size_t alloc_times{0}; size_t alloc_size{0}; size_t free_times{0}; @@ -101,7 +101,7 @@ struct OverHead { std::vector sub_memcpy_items; }; -struct MemEvenRecorder { +struct MemEventRecorder { public: void PushMemRecord(const void* ptr, const Place& place, size_t size); void PopMemRecord(const void* ptr, const Place& place); @@ -122,7 +122,7 @@ struct MemEvenRecorder { uint64_t peak_allocated, uint64_t peak_reserved); void Flush(); - static MemEvenRecorder& Instance() { return recorder; } + static MemEventRecorder& Instance() { return recorder; } private: struct RecordMemEvent { @@ -137,13 +137,13 @@ struct MemEvenRecorder { std::string free_in_; }; - static MemEvenRecorder recorder; + static MemEventRecorder recorder; std::map>> address_memevent_; std::mutex mtx_; - MemEvenRecorder() {} - DISABLE_COPY_AND_ASSIGN(MemEvenRecorder); + MemEventRecorder() {} + DISABLE_COPY_AND_ASSIGN(MemEventRecorder); }; struct RecordBlock { diff --git a/paddle/fluid/platform/profiler/chrometracing_logger.cc b/paddle/fluid/platform/profiler/chrometracing_logger.cc index de8fd01a1e59d..87fbe61979876 100644 --- a/paddle/fluid/platform/profiler/chrometracing_logger.cc +++ b/paddle/fluid/platform/profiler/chrometracing_logger.cc @@ -788,7 +788,7 @@ void ChromeTracingLogger::RefineDisplayName( "name": "process_name", "pid": %lld, "tid": %lld, "ph": "M", "args": { - "name": "Deivce %lld (%s)" + "name": "Device %lld (%s)" } }, { diff --git a/paddle/fluid/platform/profiler/chrometracing_logger.h b/paddle/fluid/platform/profiler/chrometracing_logger.h index 37323d1450bf2..89808bee842df 100644 --- a/paddle/fluid/platform/profiler/chrometracing_logger.h +++ b/paddle/fluid/platform/profiler/chrometracing_logger.h @@ -57,7 +57,7 @@ class ChromeTracingLogger : public BaseLogger { void RefineDisplayName(std::unordered_map); std::string filename_; std::ofstream output_file_stream_; - static const char* categary_name_[]; + static const char* category_name_[]; std::set> pid_tid_set_; std::set> deviceid_streamid_set_; uint64_t start_time_; diff --git a/paddle/fluid/platform/profiler/cpu_utilization.cc b/paddle/fluid/platform/profiler/cpu_utilization.cc index e84256f49f078..d373ac32ea6aa 100644 --- a/paddle/fluid/platform/profiler/cpu_utilization.cc +++ b/paddle/fluid/platform/profiler/cpu_utilization.cc @@ -24,6 +24,7 @@ // limitations under the License. 
#include "paddle/fluid/platform/profiler/cpu_utilization.h" +#include namespace paddle { namespace platform { @@ -53,16 +54,16 @@ void CpuUtilization::RecordBeginTimeInfo() { #elif defined(__linux__) start_ = times(&process_tms_start_); #define proc_path_size 1024 - static char proc_stat_path[proc_path_size] = "/proc/stat"; + static char proc_stat_path[proc_path_size] = "/proc/stat"; // NOLINTf FILE *stat_file = fopen(proc_stat_path, "r"); if (stat_file != nullptr) { - char temp_str[200]; + std::array temp_str; uint64_t temp_lu; int retval = fscanf(stat_file, "%s %" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64, - temp_str, + temp_str.data(), &system_tms_start_.tms_utime, &nice_time_start_, &system_tms_start_.tms_stime, @@ -98,16 +99,16 @@ void CpuUtilization::RecordEndTimeInfo() { #elif defined(__linux__) end_ = times(&process_tms_end_); #define proc_path_size 1024 - static char proc_stat_path[proc_path_size] = "/proc/stat"; + static char proc_stat_path[proc_path_size] = "/proc/stat"; // NOLINT FILE *stat_file = fopen(proc_stat_path, "r"); if (stat_file != nullptr) { - char temp_str[200]; + std::array temp_str; uint64_t temp_lu; int retval = fscanf(stat_file, "%s %" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64 "%" PRIu64, - temp_str, + temp_str.data(), &system_tms_end_.tms_utime, &nice_time_end_, &system_tms_end_.tms_stime, diff --git a/paddle/fluid/platform/profiler/dump/deserialization_reader.cc b/paddle/fluid/platform/profiler/dump/deserialization_reader.cc index 329c9f6871461..f02496ed5d082 100644 --- a/paddle/fluid/platform/profiler/dump/deserialization_reader.cc +++ b/paddle/fluid/platform/profiler/dump/deserialization_reader.cc @@ -44,12 +44,12 @@ std::unique_ptr DeserializationReader::Parse() { return nullptr; } // restore extra info - ExtraInfo extrainfo; + ExtraInfo extra_info; for (auto indx = 0; indx < node_trees_proto_->extra_info_size(); indx++) { ExtraInfoMap extra_info_map = node_trees_proto_->extra_info(indx); - extrainfo.AddExtraInfo(extra_info_map.key(), - std::string("%s"), - extra_info_map.value().c_str()); + extra_info.AddExtraInfo(extra_info_map.key(), + std::string("%s"), + extra_info_map.value().c_str()); } // restore NodeTrees @@ -139,10 +139,10 @@ std::unique_ptr DeserializationReader::Parse() { RestoreDeviceProperty(device_property_proto); } ProfilerResult* profiler_result_ptr = - new ProfilerResult(std::move(tree), extrainfo, device_property_map); + new ProfilerResult(std::move(tree), extra_info, device_property_map); #else ProfilerResult* profiler_result_ptr = - new ProfilerResult(std::move(tree), extrainfo); + new ProfilerResult(std::move(tree), extra_info); #endif // restore version and span indx profiler_result_ptr->SetVersion(node_trees_proto_->version()); diff --git a/paddle/fluid/platform/profiler/dump/serialization_logger.cc b/paddle/fluid/platform/profiler/dump/serialization_logger.cc index 17c3d42ec5e86..e7889a6727199 100644 --- a/paddle/fluid/platform/profiler/dump/serialization_logger.cc +++ b/paddle/fluid/platform/profiler/dump/serialization_logger.cc @@ -103,37 +103,33 @@ void SerializationLogger::LogNodeTrees(const NodeTrees& node_trees) { current_thread_node_tree_proto_ = node_trees_proto_->add_thread_trees(); // add ThreadNodeTreeProto current_thread_node_tree_proto_->set_thread_id(event_node.first); - for (auto hostnode = event_node.second.begin(); - hostnode != event_node.second.end(); - ++hostnode) { + for (auto hostnode : 
event_node.second) { HostTraceEventNodeProto* host_node_proto = current_thread_node_tree_proto_ ->add_host_nodes(); // add HostTraceEventNodeProto - host_node_proto->set_id(node_index_map[(*hostnode)]); - host_node_proto->set_parentid(node_parent_map[(*hostnode)]); + host_node_proto->set_id(node_index_map[hostnode]); + host_node_proto->set_parentid(node_parent_map[hostnode]); current_host_trace_event_node_proto_ = - host_node_proto; // set current HostTraceEventNodeProto - (*hostnode)->LogMe(this); // fill detail information + host_node_proto; // set current HostTraceEventNodeProto + hostnode->LogMe(this); // fill detail information - for (auto runtimenode : (*hostnode)->GetRuntimeTraceEventNodes()) { + for (auto runtimenode : hostnode->GetRuntimeTraceEventNodes()) { CudaRuntimeTraceEventNodeProto* runtime_node_proto = current_host_trace_event_node_proto_ ->add_runtime_nodes(); // add CudaRuntimeTraceEventNodeProto current_runtime_trace_event_node_proto_ = runtime_node_proto; // set current CudaRuntimeTraceEventNodeProto runtimenode->LogMe(this); // fill detail information - for (auto devicenode = runtimenode->GetDeviceTraceEventNodes().begin(); - devicenode != runtimenode->GetDeviceTraceEventNodes().end(); - ++devicenode) { + for (auto devicenode : runtimenode->GetDeviceTraceEventNodes()) { DeviceTraceEventNodeProto* device_node_proto = current_runtime_trace_event_node_proto_ ->add_device_nodes(); // add DeviceTraceEventNodeProto current_device_trace_event_node_proto_ = - device_node_proto; // set current DeviceTraceEventNodeProto - (*devicenode)->LogMe(this); // fill detail information + device_node_proto; // set current DeviceTraceEventNodeProto + devicenode->LogMe(this); // fill detail information } } - for (auto memnode : (*hostnode)->GetMemTraceEventNodes()) { + for (auto memnode : hostnode->GetMemTraceEventNodes()) { MemTraceEventNodeProto* mem_node_proto = current_host_trace_event_node_proto_->add_mem_nodes(); current_mem_trace_event_node_proto_ = mem_node_proto; diff --git a/paddle/fluid/platform/profiler/dump/serialization_logger.h b/paddle/fluid/platform/profiler/dump/serialization_logger.h index 80d5413106ded..e61ed701cd798 100644 --- a/paddle/fluid/platform/profiler/dump/serialization_logger.h +++ b/paddle/fluid/platform/profiler/dump/serialization_logger.h @@ -21,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace platform { -// Dump a NodeTrees into a profobuf file. +// Dump a NodeTrees into a protobuf file. // A SerializationLogger object can only dump a NodeTrees object, // creates a file in the constructor and closes the file in the destructor. // Should only call LogNodeTrees and LogMetaInfo. 
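The serialization_logger.cc hunk above replaces explicit iterator loops over host, device, and memory trace event nodes with range-based for loops; behavior is unchanged, only the `(*it)->` dereferencing noise goes away. A generic before/after of the same refactor, with placeholder names (`Node` and `LogAll` are not Paddle types):

```cpp
#include <vector>

struct Node {
  void LogMe() const {}
};

void LogAll(const std::vector<Node*>& nodes) {
  // Before: explicit iterators and a double dereference at each use.
  // for (auto it = nodes.begin(); it != nodes.end(); ++it) {
  //   (*it)->LogMe();
  // }

  // After: range-based for over the same container, identical behavior.
  for (auto* node : nodes) {
    node->LogMe();
  }
}
```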
diff --git a/paddle/fluid/platform/profiler/dump/test_serialization_logger.cc b/paddle/fluid/platform/profiler/dump/test_serialization_logger.cc index bc9407684bcd8..4872d7bb42353 100644 --- a/paddle/fluid/platform/profiler/dump/test_serialization_logger.cc +++ b/paddle/fluid/platform/profiler/dump/test_serialization_logger.cc @@ -152,21 +152,21 @@ TEST(SerializationLoggerTest, dump_case0) { EXPECT_EQ(nodes[11].size(), 2u); std::vector thread1_nodes = nodes[10]; std::vector thread2_nodes = nodes[11]; - for (auto it = thread1_nodes.begin(); it != thread1_nodes.end(); it++) { - if ((*it)->Name() == "root node") { - EXPECT_EQ((*it)->GetChildren().size(), 3u); + for (auto& thread1_node : thread1_nodes) { + if (thread1_node->Name() == "root node") { + EXPECT_EQ(thread1_node->GetChildren().size(), 3u); } - if ((*it)->Name() == "op1") { - EXPECT_EQ((*it)->GetChildren().size(), 0u); - EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); - EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u); - EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr); + if (thread1_node->Name() == "op1") { + EXPECT_EQ(thread1_node->GetChildren().size(), 0u); + EXPECT_EQ(thread1_node->GetRuntimeTraceEventNodes().size(), 2u); + EXPECT_EQ(thread1_node->GetMemTraceEventNodes().size(), 2u); + EXPECT_NE(thread1_node->GetOperatorSupplementEventNode(), nullptr); } } - for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { - if ((*it)->Name() == "op3") { - EXPECT_EQ((*it)->GetChildren().size(), 0u); - EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); + for (auto& thread2_node : thread2_nodes) { + if (thread2_node->Name() == "op3") { + EXPECT_EQ(thread2_node->GetChildren().size(), 0u); + EXPECT_EQ(thread2_node->GetRuntimeTraceEventNodes().size(), 2u); } } tree.LogMe(&logger); @@ -247,15 +247,15 @@ TEST(SerializationLoggerTest, dump_case1) { EXPECT_EQ(nodes[11].size(), 1u); std::vector thread1_nodes = nodes[10]; std::vector thread2_nodes = nodes[11]; - for (auto it = thread1_nodes.begin(); it != thread1_nodes.end(); it++) { - if ((*it)->Name() == "root node") { - EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 3u); + for (auto& thread1_node : thread1_nodes) { + if (thread1_node->Name() == "root node") { + EXPECT_EQ(thread1_node->GetRuntimeTraceEventNodes().size(), 3u); } } - for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { - if ((*it)->Name() == "root node") { - EXPECT_EQ((*it)->GetChildren().size(), 0u); - EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); + for (auto& thread2_node : thread2_nodes) { + if (thread2_node->Name() == "root node") { + EXPECT_EQ(thread2_node->GetChildren().size(), 0u); + EXPECT_EQ(thread2_node->GetRuntimeTraceEventNodes().size(), 2u); } } tree.LogMe(&logger); @@ -272,21 +272,21 @@ TEST(DeserializationReaderTest, restore_case0) { EXPECT_EQ(nodes[11].size(), 2u); std::vector thread1_nodes = nodes[10]; std::vector thread2_nodes = nodes[11]; - for (auto it = thread1_nodes.begin(); it != thread1_nodes.end(); it++) { - if ((*it)->Name() == "root node") { - EXPECT_EQ((*it)->GetChildren().size(), 3u); + for (auto& thread1_node : thread1_nodes) { + if (thread1_node->Name() == "root node") { + EXPECT_EQ(thread1_node->GetChildren().size(), 3u); } - if ((*it)->Name() == "op1") { - EXPECT_EQ((*it)->GetChildren().size(), 0u); - EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); - EXPECT_EQ((*it)->GetMemTraceEventNodes().size(), 2u); - EXPECT_NE((*it)->GetOperatorSupplementEventNode(), nullptr); + if (thread1_node->Name() == "op1") { + 
EXPECT_EQ(thread1_node->GetChildren().size(), 0u); + EXPECT_EQ(thread1_node->GetRuntimeTraceEventNodes().size(), 2u); + EXPECT_EQ(thread1_node->GetMemTraceEventNodes().size(), 2u); + EXPECT_NE(thread1_node->GetOperatorSupplementEventNode(), nullptr); } } - for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { - if ((*it)->Name() == "op3") { - EXPECT_EQ((*it)->GetChildren().size(), 0u); - EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); + for (auto& thread2_node : thread2_nodes) { + if (thread2_node->Name() == "op3") { + EXPECT_EQ(thread2_node->GetChildren().size(), 0u); + EXPECT_EQ(thread2_node->GetRuntimeTraceEventNodes().size(), 2u); } } } @@ -301,15 +301,15 @@ TEST(DeserializationReaderTest, restore_case1) { EXPECT_EQ(nodes[11].size(), 1u); std::vector thread1_nodes = nodes[10]; std::vector thread2_nodes = nodes[11]; - for (auto it = thread1_nodes.begin(); it != thread1_nodes.end(); it++) { - if ((*it)->Name() == "root node") { - EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 3u); + for (auto& thread1_node : thread1_nodes) { + if (thread1_node->Name() == "root node") { + EXPECT_EQ(thread1_node->GetRuntimeTraceEventNodes().size(), 3u); } } - for (auto it = thread2_nodes.begin(); it != thread2_nodes.end(); it++) { - if ((*it)->Name() == "root node") { - EXPECT_EQ((*it)->GetChildren().size(), 0u); - EXPECT_EQ((*it)->GetRuntimeTraceEventNodes().size(), 2u); + for (auto& thread2_node : thread2_nodes) { + if (thread2_node->Name() == "root node") { + EXPECT_EQ(thread2_node->GetChildren().size(), 0u); + EXPECT_EQ(thread2_node->GetRuntimeTraceEventNodes().size(), 2u); } } } diff --git a/paddle/fluid/platform/profiler/event_node.cc b/paddle/fluid/platform/profiler/event_node.cc index c92ae133814f3..caceb82ec4622 100644 --- a/paddle/fluid/platform/profiler/event_node.cc +++ b/paddle/fluid/platform/profiler/event_node.cc @@ -340,7 +340,6 @@ HostTraceEventNode* NodeTrees::BuildTreeRelationship( // build relationship between host event node and op supplement node for (auto it = post_order_nodes.begin(); it < post_order_nodes.end(); ++it) { - int op_supplement_count = 0; // NOLINT bool hasenter = false; std::vector::iterator firstposition; std::vector::iterator lastposition = @@ -355,7 +354,6 @@ HostTraceEventNode* NodeTrees::BuildTreeRelationship( hasenter = true; } (*it)->SetOperatorSupplementNode(*op_supplement_it); - op_supplement_count += 1; } else { if ((*op_supplement_it)->TimeStampNs() > (*it)->EndNs()) { lastposition = op_supplement_it; @@ -434,10 +432,8 @@ void NodeTrees::HandleTrees( } for (auto event_node : (*hostnode)->GetRuntimeTraceEventNodes()) { runtime_event_node_handle(event_node); - for (auto devicenode = event_node->GetDeviceTraceEventNodes().begin(); - devicenode != event_node->GetDeviceTraceEventNodes().end(); - ++devicenode) { - device_event_node_handle(*devicenode); + for (auto devicenode : event_node->GetDeviceTraceEventNodes()) { + device_event_node_handle(devicenode); } } for (auto event_node : (*hostnode)->GetMemTraceEventNodes()) { diff --git a/paddle/fluid/platform/profiler/event_python.cc b/paddle/fluid/platform/profiler/event_python.cc index c01b4abcfbbd3..551cdd2182323 100644 --- a/paddle/fluid/platform/profiler/event_python.cc +++ b/paddle/fluid/platform/profiler/event_python.cc @@ -63,20 +63,18 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) { runtime_python_node->correlation_id = runtimenode->CorrelationId(); host_python_node->runtime_node_ptrs.push_back(runtime_python_node); // copy DeviceTraceEventNode - for 
(auto devicenode = runtimenode->GetDeviceTraceEventNodes().begin(); - devicenode != runtimenode->GetDeviceTraceEventNodes().end(); - ++devicenode) { + for (auto devicenode : runtimenode->GetDeviceTraceEventNodes()) { DevicePythonNode* device_python_node = new DevicePythonNode(); - device_python_node->name = (*devicenode)->Name(); - device_python_node->type = (*devicenode)->Type(); - device_python_node->start_ns = (*devicenode)->StartNs(); - device_python_node->end_ns = (*devicenode)->EndNs(); - device_python_node->device_id = (*devicenode)->DeviceId(); - device_python_node->context_id = (*devicenode)->ContextId(); - device_python_node->stream_id = (*devicenode)->StreamId(); - device_python_node->correlation_id = (*devicenode)->CorrelationId(); + device_python_node->name = devicenode->Name(); + device_python_node->type = devicenode->Type(); + device_python_node->start_ns = devicenode->StartNs(); + device_python_node->end_ns = devicenode->EndNs(); + device_python_node->device_id = devicenode->DeviceId(); + device_python_node->context_id = devicenode->ContextId(); + device_python_node->stream_id = devicenode->StreamId(); + device_python_node->correlation_id = devicenode->CorrelationId(); if (device_python_node->type == TracerEventType::Kernel) { - KernelEventInfo kernel_info = (*devicenode)->KernelInfo(); + KernelEventInfo kernel_info = devicenode->KernelInfo(); device_python_node->block_x = kernel_info.block_x; device_python_node->block_y = kernel_info.block_y; device_python_node->block_z = kernel_info.block_z; @@ -91,10 +89,10 @@ HostPythonNode* ProfilerResult::CopyTree(HostTraceEventNode* root) { device_python_node->warps_per_sm = kernel_info.warps_per_sm; device_python_node->occupancy = kernel_info.occupancy; } else if (device_python_node->type == TracerEventType::Memcpy) { - MemcpyEventInfo memcpy_info = (*devicenode)->MemcpyInfo(); + MemcpyEventInfo memcpy_info = devicenode->MemcpyInfo(); device_python_node->num_bytes = memcpy_info.num_bytes; } else if (device_python_node->type == TracerEventType::Memset) { - MemsetEventInfo memset_info = (*devicenode)->MemsetInfo(); + MemsetEventInfo memset_info = devicenode->MemsetInfo(); device_python_node->num_bytes = memset_info.num_bytes; device_python_node->value = memset_info.value; } diff --git a/paddle/fluid/platform/profiler/event_tracing.h b/paddle/fluid/platform/profiler/event_tracing.h index 08890f1369733..b427a9ba55210 100644 --- a/paddle/fluid/platform/profiler/event_tracing.h +++ b/paddle/fluid/platform/profiler/event_tracing.h @@ -28,7 +28,7 @@ namespace platform { // Chrome Trace Viewer Format: Instant Event struct RecordInstantEvent { /** - * @param name: It is the caller's reponsibility to manage the underlying + * @param name: It is the caller's responsibility to manage the underlying * storage. RecordInstantEvent stores the pointer. * @param type: Classification which is used to instruct the profiling * data statistics. 
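In the event_python.cc hunk above, `CopyTree` now walks `runtimenode->GetDeviceTraceEventNodes()` with a range-based for loop, copies the common fields of each device node into a `DevicePythonNode`, and branches on the event type for the type-specific payload (kernel launch geometry for `Kernel`, `num_bytes` for `Memcpy`/`Memset`). A condensed sketch of that copy, with stand-in types rather than the real Paddle classes:

```cpp
// Stand-in types only; DeviceNode and DevicePyNode are not Paddle classes.
#include <cstdint>
#include <string>
#include <vector>

enum class EventType { Kernel, Memcpy, Memset };

struct DeviceNode {     // stand-in for DeviceTraceEventNode
  std::string name;
  EventType type;
  std::uint64_t start_ns;
  std::uint64_t end_ns;
  std::uint64_t num_bytes;   // meaningful for Memcpy / Memset only
};

struct DevicePyNode {   // stand-in for DevicePythonNode
  std::string name;
  EventType type = EventType::Kernel;
  std::uint64_t start_ns = 0;
  std::uint64_t end_ns = 0;
  std::uint64_t num_bytes = 0;
};

std::vector<DevicePyNode> CopyDeviceNodes(
    const std::vector<DeviceNode*>& device_nodes) {
  std::vector<DevicePyNode> out;
  out.reserve(device_nodes.size());
  for (auto* devicenode : device_nodes) {  // range-for, as in the diff
    DevicePyNode py;
    py.name = devicenode->name;            // common fields, copied unconditionally
    py.type = devicenode->type;
    py.start_ns = devicenode->start_ns;
    py.end_ns = devicenode->end_ns;
    if (py.type == EventType::Memcpy || py.type == EventType::Memset) {
      py.num_bytes = devicenode->num_bytes;  // type-specific payload
    }
    out.push_back(py);
  }
  return out;
}
```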
diff --git a/paddle/fluid/platform/profiler/profiler.cc b/paddle/fluid/platform/profiler/profiler.cc index bcb35f5b7bd35..c9d458b1d250a 100644 --- a/paddle/fluid/platform/profiler/profiler.cc +++ b/paddle/fluid/platform/profiler/profiler.cc @@ -148,19 +148,19 @@ std::unique_ptr Profiler::Stop() { collector.MemEvents(), collector.OperatorSupplementEvents())); cpu_utilization_.RecordEndTimeInfo(); - ExtraInfo extrainfo; - extrainfo.AddExtraInfo(std::string("System Cpu Utilization"), - std::string("%f"), - cpu_utilization_.GetCpuUtilization()); - extrainfo.AddExtraInfo(std::string("Process Cpu Utilization"), - std::string("%f"), - cpu_utilization_.GetCpuCurProcessUtilization()); + ExtraInfo extra_info; + extra_info.AddExtraInfo(std::string("System Cpu Utilization"), + std::string("%f"), + cpu_utilization_.GetCpuUtilization()); + extra_info.AddExtraInfo(std::string("Process Cpu Utilization"), + std::string("%f"), + cpu_utilization_.GetCpuCurProcessUtilization()); const std::unordered_map thread_names = collector.ThreadNames(); for (const auto& kv : thread_names) { - extrainfo.AddExtraInfo(string_format(std::string("%llu"), kv.first), - std::string("%s"), - kv.second.c_str()); + extra_info.AddExtraInfo(string_format(std::string("%llu"), kv.first), + std::string("%s"), + kv.second.c_str()); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) std::map device_property_map; @@ -170,10 +170,10 @@ std::unique_ptr Profiler::Stop() { device_property_map[device_id] = device_property; } ProfilerResult* profiler_result_ptr = new platform::ProfilerResult( - std::move(tree), extrainfo, device_property_map); + std::move(tree), extra_info, device_property_map); #else ProfilerResult* profiler_result_ptr = - new platform::ProfilerResult(std::move(tree), extrainfo); + new platform::ProfilerResult(std::move(tree), extra_info); #endif profiler_result_ptr->SetVersion(std::string(version)); profiler_result_ptr->SetSpanIndx(span_indx); diff --git a/paddle/fluid/platform/profiler/utils.cc b/paddle/fluid/platform/profiler/utils.cc index 46a94e7fcb23c..236c77cec5b22 100644 --- a/paddle/fluid/platform/profiler/utils.cc +++ b/paddle/fluid/platform/profiler/utils.cc @@ -106,7 +106,8 @@ float CalculateEstOccupancy(uint32_t DeviceId, float occupancy = 0.0; std::vector device_ids = GetSelectedDevices(); if (DeviceId < device_ids.size()) { - const gpuDeviceProp& device_property = GetDeviceProperties(DeviceId); + const gpuDeviceProp& device_property = + GetDeviceProperties(static_cast(DeviceId)); cudaOccFuncAttributes occFuncAttr; occFuncAttr.maxThreadsPerBlock = INT_MAX; occFuncAttr.numRegs = RegistersPerThread; @@ -127,11 +128,13 @@ float CalculateEstOccupancy(uint32_t DeviceId, blockSize, dynamicSmemSize); if (status == CUDA_OCC_SUCCESS) { - if (occ_result.activeBlocksPerMultiprocessor < BlocksPerSm) { - BlocksPerSm = occ_result.activeBlocksPerMultiprocessor; + if (static_cast(occ_result.activeBlocksPerMultiprocessor) < + BlocksPerSm) { + BlocksPerSm = + static_cast(occ_result.activeBlocksPerMultiprocessor); } occupancy = - BlocksPerSm * blockSize / + BlocksPerSm * static_cast(blockSize) / static_cast(device_property.maxThreadsPerMultiProcessor); } else { LOG(WARNING) << "Failed to calculate estimated occupancy, status = " @@ -145,16 +148,16 @@ float CalculateEstOccupancy(uint32_t DeviceId, #endif // PADDLE_WITH_CUPTI const char* StringTracerMemEventType(TracerMemEventType type) { - static const char* categary_name_[] = {// NOLINT + static const char* category_name_[] = {// NOLINT "Allocate", "Free", 
"ReservedAllocate", "ReservedFree"}; - return categary_name_[static_cast(type)]; + return category_name_[static_cast(type)]; } const char* StringTracerEventType(TracerEventType type) { - static const char* categary_name_[] = {"Operator", // NOLINT + static const char* category_name_[] = {"Operator", // NOLINT "Dataloader", "ProfileStep", "CudaRuntime", @@ -169,7 +172,7 @@ const char* StringTracerEventType(TracerEventType type) { "Communication", "PythonOp", "PythonUserDefined"}; - return categary_name_[static_cast(type)]; + return category_name_[static_cast(type)]; } } // namespace platform diff --git a/paddle/fluid/platform/profiler_helper.h b/paddle/fluid/platform/profiler_helper.h index 8ce6fee8a5f6e..634d670c575bb 100644 --- a/paddle/fluid/platform/profiler_helper.h +++ b/paddle/fluid/platform/profiler_helper.h @@ -132,7 +132,7 @@ static double ToMegaBytes(size_t bytes) { // Print results void PrintMemProfiler( - const std::map> + const std::map> &annotation_report, const size_t name_width, const size_t data_width) { @@ -200,7 +200,7 @@ void PrintMemProfiler( void ParseMemEvents(const std::vector> &events) { if (phi::ProfilerHelper::g_state == ProfilerState::kDisabled) return; // place, annotation, alloc times, alloc size - std::map> + std::map> annotation_report; for (auto &tmp : events) { @@ -740,7 +740,7 @@ void AnalyzeEvent( size_t *max_name_width, OverHead *overhead, bool merge_thread) { - // In oreder to deal with special event in main thread + // In order to deal with special event in main thread std::set main_thread_event_name; for (size_t i = 0; i < (*analyze_events).size(); i++) { for (size_t j = 0; j < (*analyze_events)[i].size(); j++) { diff --git a/paddle/fluid/platform/stream_callback_manager.cc b/paddle/fluid/platform/stream_callback_manager.cc index c55bcb71a7d43..6719a1b6e97bc 100644 --- a/paddle/fluid/platform/stream_callback_manager.cc +++ b/paddle/fluid/platform/stream_callback_manager.cc @@ -83,7 +83,7 @@ void StreamCallbackManager::Wait() const { } #ifdef PADDLE_WITH_CUDA -template struct StreamCallbackManager; +template class StreamCallbackManager; #endif #ifdef PADDLE_WITH_HIP template struct StreamCallbackManager; diff --git a/paddle/fluid/platform/timer.h b/paddle/fluid/platform/timer.h index ab029577fbdd1..b0ece1be3c868 100644 --- a/paddle/fluid/platform/timer.h +++ b/paddle/fluid/platform/timer.h @@ -15,7 +15,7 @@ limitations under the License. 
*/ #pragma once #include -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" #include "paddle/utils/test_macros.h" #ifdef _WIN32 diff --git a/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py b/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py index 7c1cb550f893b..704ef988b7f50 100644 --- a/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py +++ b/paddle/fluid/prim/api/auto_code_generated/tensor_operants_gen.py @@ -95,11 +95,11 @@ class TEST_API EagerTensorOperants : public TensorOperantsBase { namespace prim { Tensor EagerTensorOperants::add(const Tensor& x, const Scalar& y) { - return ::add_ad_func(x, ::full_like_ad_func(x, y)); + return ::scale_ad_func(x, 1.0f, y, true); } Tensor EagerTensorOperants::subtract(const Tensor& x, const Scalar& y) { - return ::subtract_ad_func(x, ::full_like_ad_func(x, y)); + return ::scale_ad_func(x, 1.0f, -y, true); } Tensor EagerTensorOperants::multiply(const Tensor& x, const Scalar& y) { @@ -111,11 +111,11 @@ class TEST_API EagerTensorOperants : public TensorOperantsBase { } Tensor EagerTensorOperants::add(const Scalar& x, const Tensor& y) { - return ::add_ad_func(::full_like_ad_func(y, x), y); + return ::scale_ad_func(y, 1.0f, x, true); } Tensor EagerTensorOperants::subtract(const Scalar& x, const Tensor& y) { - return ::subtract_ad_func(::full_like_ad_func(y, x), y); + return ::scale_ad_func(y, -1.0f, x, true); } Tensor EagerTensorOperants::multiply(const Scalar& x, const Tensor& y) { @@ -131,7 +131,7 @@ class TEST_API EagerTensorOperants : public TensorOperantsBase { } Tensor EagerTensorOperants::pow(const Tensor& x, const Scalar& y) { - return ::elementwise_pow_ad_func(x, ::full_like_ad_func(x, y)); + return ::pow_ad_func(x, y); } """ diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 7131d37dd5496..169d41d9763e5 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -33,6 +33,19 @@ using Tensor = paddle::Tensor; using IntArray = paddle::experimental::IntArrayBase; // This function should have as same signature as phi, which defined in // paddle/phi/api/backward/backward_api.h +template +void pow_grad(const Tensor& x, + const Tensor& out_grad, + const Scalar& y, + Tensor* x_grad) { + // dx = y * x^(y-1) * out_grad + if (x_grad) { + auto y_value = y.to(); + auto dx_res = y_value * x.pow(y_value - 1) * out_grad; + set_output(dx_res, x_grad); + } // indicate we will compute dx +} + template void hardswish_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) { if (x_grad) { @@ -220,9 +233,9 @@ void subtract_grad(const Tensor& x, Tensor* dy) { if (dy) { auto scale_out_grad = scale(out_grad, -1.0, 0.0, true); - if (x.dims() != y.dims()) { + if (out_grad.dims() != y.dims()) { // Maybe need reduce here - phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + phi::DDim reduce_dim = get_reduce_dims(y.dims(), out_grad.dims()); if (!reduce_dim.size()) { by_pass(scale_out_grad, dy); } else { @@ -236,9 +249,9 @@ void subtract_grad(const Tensor& x, } } if (dx) { - if (y.dims() != x.dims()) { + if (out_grad.dims() != x.dims()) { // Maybe need reduce here - auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + auto reduce_dim = get_reduce_dims(x.dims(), out_grad.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dx); } else { @@ -261,9 +274,9 @@ void add_grad(const Tensor& x, 
Tensor* dx, Tensor* dy) { if (dy) { - if (x.dims() != y.dims()) { + if (out_grad.dims() != y.dims()) { // Maybe need reduce here - phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + phi::DDim reduce_dim = get_reduce_dims(y.dims(), out_grad.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dy); } else { @@ -277,9 +290,9 @@ void add_grad(const Tensor& x, } } if (dx) { - if (y.dims() != x.dims()) { + if (out_grad.dims() != x.dims()) { // Maybe need reduce here - auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + auto reduce_dim = get_reduce_dims(x.dims(), out_grad.dims()); if (!reduce_dim.size()) { by_pass(out_grad, dx); } else { @@ -371,9 +384,9 @@ void divide_grad(const Tensor& x, if (dx) { // dx = (1/y) * dout = dout / y auto dx_res = out_grad / y; - if (y.dims() != x.dims()) { + if (out_grad.dims() != x.dims()) { // Maybe need reduce here - auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + auto reduce_dim = get_reduce_dims(x.dims(), out_grad.dims()); if (!reduce_dim.size()) { set_output(dx_res, dx); } else { @@ -399,9 +412,9 @@ void elementwise_pow_grad(const Tensor& x, auto lnx = log(x); auto x_pow_y = elementwise_pow(x, y); auto dy_res = lnx * x_pow_y * out_grad; - if (x.dims() != y.dims()) { + if (out_grad.dims() != y.dims()) { // Maybe need reduce here - phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + phi::DDim reduce_dim = get_reduce_dims(y.dims(), out_grad.dims()); if (!reduce_dim.size()) { set_output(dy_res, dy); } else { @@ -419,9 +432,9 @@ void elementwise_pow_grad(const Tensor& x, auto tmp_z = y - 1.0; auto x_pow_z = elementwise_pow(x, tmp_z); auto dx_res = y * x_pow_z * out_grad; - if (y.dims() != x.dims()) { + if (out_grad.dims() != x.dims()) { // Maybe need reduce here - auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + auto reduce_dim = get_reduce_dims(x.dims(), out_grad.dims()); if (!reduce_dim.size()) { set_output(dx_res, dx); } else { @@ -831,8 +844,6 @@ void group_norm_grad(const Tensor& x, tmp1.sum(std::vector({0}), scale_ptr->dtype(), false), IntArray(std::vector({C}))); set_output(scale_grad_tmp, scale_grad); - } else { - scale_grad = nullptr; } } @@ -841,8 +852,6 @@ void group_norm_grad(const Tensor& x, auto bias_grad_tmp = sum_y_grad.sum(std::vector({0}), bias_ptr->dtype(), false); set_output(bias_grad_tmp, bias_grad); - } else { - bias_grad = nullptr; } } } @@ -934,8 +943,6 @@ void layer_norm_grad(const Tensor& x, scale_grad_tmp = cast(scale_grad_tmp, scale_ptr->dtype()); } set_output(scale_grad_tmp, scale_grad); - } else { - scale_grad = nullptr; } } @@ -949,8 +956,6 @@ void layer_norm_grad(const Tensor& x, bias_grad_tmp = cast(bias_grad_tmp, bias_ptr->dtype()); } set_output(bias_grad_tmp, bias_grad); - } else { - bias_grad = nullptr; } } } @@ -1146,9 +1151,9 @@ void maximum_grad(const Tensor& x, if (x_grad) { auto x_tmp = cast(greater_than(x, y), out_grad.dtype()); auto dx_res = out_grad * x_tmp; - if (y.dims() != x.dims()) { + if (out_grad.dims() != x.dims()) { // Maybe need reduce here - auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + auto reduce_dim = get_reduce_dims(x.dims(), out_grad.dims()); if (!reduce_dim.size()) { set_output(dx_res, x_grad); } else { @@ -1165,9 +1170,9 @@ void maximum_grad(const Tensor& x, if (y_grad) { auto y_tmp = cast(less_equal(x, y), out_grad.dtype()); auto dy_res = out_grad * y_tmp; - if (x.dims() != y.dims()) { + if (out_grad.dims() != y.dims()) { // Maybe need reduce here - phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + phi::DDim reduce_dim = 
get_reduce_dims(y.dims(), out_grad.dims()); if (!reduce_dim.size()) { set_output(dy_res, y_grad); } else { @@ -1600,9 +1605,9 @@ void minimum_grad(const Tensor& x, if (x_grad) { auto x_tmp = cast(less_than(x, y), out_grad.dtype()); auto dx_res = out_grad * x_tmp; - if (y.dims() != x.dims()) { + if (out_grad.dims() != x.dims()) { // Maybe need reduce here - auto reduce_dim = get_reduce_dims(x.dims(), y.dims()); + auto reduce_dim = get_reduce_dims(x.dims(), out_grad.dims()); if (!reduce_dim.size()) { set_output(dx_res, x_grad); } else { @@ -1619,9 +1624,9 @@ void minimum_grad(const Tensor& x, if (y_grad) { auto y_tmp = cast(greater_equal(x, y), out_grad.dtype()); auto dy_res = out_grad * y_tmp; - if (x.dims() != y.dims()) { + if (out_grad.dims() != y.dims()) { // Maybe need reduce here - phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims()); + phi::DDim reduce_dim = get_reduce_dims(y.dims(), out_grad.dims()); if (!reduce_dim.size()) { set_output(dy_res, y_grad); } else { diff --git a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h index 02bd7e29443c0..7e7ccfaf170b3 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_double_backward_api.h @@ -53,6 +53,111 @@ void tanh_double_grad(const Tensor& out, } } +template +void sin_double_grad(const Tensor& x, + const Tensor& grad_out, + const Tensor& grad_x_grad, + Tensor* x_grad, + Tensor* grad_out_grad) { + // sin grad grad : ddout = cosx * ddx, dx = -dy * sinx * ddx + if (x_grad) { + auto x_grad_tmp = -(grad_out * sin(x) * grad_x_grad); + set_output(x_grad_tmp, x_grad); + } + + if (grad_out_grad) { + auto grad_out_grad_tmp = cos(x) * grad_x_grad; + set_output(grad_out_grad_tmp, grad_out_grad); + } +} + +template +void cos_double_grad(const Tensor& x, + const Tensor& grad_out, + const Tensor& grad_x_grad, + Tensor* x_grad, + Tensor* grad_out_grad) { + // cos grad grad : ddout = -sinx * ddx, dx = -dy * cosx * ddx + if (x_grad) { + auto x_grad_tmp = -(grad_out * cos(x) * grad_x_grad); + set_output(x_grad_tmp, x_grad); + } + + if (grad_out_grad) { + auto grad_out_grad_tmp = -sin(x) * grad_x_grad; + set_output(grad_out_grad_tmp, grad_out_grad); + } +} + +template +void minimum_double_grad(const Tensor& x, + const Tensor& y, + const paddle::optional& grad_x_grad, + const paddle::optional& grad_y_grad, + Tensor* grad_out_grad) { + if (grad_out_grad) { + if (grad_x_grad && grad_y_grad) { + auto x_mask = cast(less_than(x, y), grad_x_grad.get().dtype()); + auto ddout = + grad_x_grad.get() * x_mask + grad_y_grad.get() * (1 - x_mask); + set_output(ddout, grad_out_grad); + } else if (grad_x_grad) { + auto x_mask = cast(less_than(x, y), grad_x_grad.get().dtype()); + auto ddout = grad_x_grad.get() * x_mask; + set_output(ddout, grad_out_grad); + } else if (grad_y_grad) { + auto y_mask = cast(greater_equal(x, y), grad_y_grad.get().dtype()); + auto ddout = grad_y_grad.get() * y_mask; + set_output(ddout, grad_out_grad); + } + } +} +template +void pow_double_grad(const Tensor& x, + const Tensor& grad_out, + const Tensor& grad_x_grad, + const Scalar& y, + Tensor* x_grad, + Tensor* grad_out_grad) { + // pow grad grad : ddout = y * pow(x, y-1) * ddx, dx = y * (y-1) * pow(x, y-2) + // * dout * ddx + auto y_value = y.to(); + if (grad_out_grad) { + auto grad_out_grad_tmp = y_value * x.pow(y_value - 1) * grad_x_grad; + set_output(grad_out_grad_tmp, grad_out_grad); + } + + if 
(x_grad) { + auto x_grad_tmp = + y_value * (y_value - 1) * x.pow(y_value - 2) * grad_out * grad_x_grad; + set_output(x_grad_tmp, x_grad); + } +} + +template +void maximum_double_grad(const Tensor& x, + const Tensor& y, + const paddle::optional& grad_x_grad, + const paddle::optional& grad_y_grad, + Tensor* grad_out_grad) { + if (grad_out_grad) { + if (grad_x_grad && grad_y_grad) { + auto x_mask = cast(greater_than(x, y), grad_x_grad.get().dtype()); + auto ddout = + grad_x_grad.get() * x_mask + grad_y_grad.get() * (1 - x_mask); + set_output(ddout, grad_out_grad); + } else if (grad_x_grad) { + auto x_mask = cast(greater_than(x, y), grad_x_grad.get().dtype()); + auto ddout = grad_x_grad.get() * x_mask; + set_output(ddout, grad_out_grad); + } else if (grad_y_grad) { + auto y_mask = cast(less_equal(x, y), grad_y_grad.get().dtype()); + auto ddout = grad_y_grad.get() * y_mask; + set_output(ddout, grad_out_grad); + } + } +} + template void tanh_triple_grad(const Tensor& out, const Tensor& grad_out_forward, @@ -62,63 +167,122 @@ void tanh_triple_grad(const Tensor& out, Tensor* out_grad, Tensor* grad_out_forward_grad, Tensor* grad_x_grad_forward_grad) { - if (out_grad) { - if (grad_out_grad_grad) { - if (grad_out_new_grad) { - auto out_grad_tmp = - (-2 * out * grad_x_grad_forward * grad_out_grad_grad.get()) - - (2 * grad_out_forward * grad_x_grad_forward * - grad_out_new_grad.get()); - set_output(out_grad_tmp, out_grad); - } else { - auto out_grad_tmp = - -2 * out * grad_x_grad_forward * grad_out_grad_grad.get(); - set_output(out_grad_tmp, out_grad); - } - } else { - if (grad_out_new_grad) { - auto out_grad_tmp = -(2 * grad_out_forward * grad_x_grad_forward * - grad_out_new_grad.get()); - set_output(out_grad_tmp, out_grad); - } else { - auto out_grad_tmp = 0 * out; - set_output(out_grad_tmp, out_grad); - } + if (grad_out_new_grad && grad_out_grad_grad) { + /* + dy = -2 * dy * ddx * ddy - 2 * y * ddx * dddy + ddy = -2 * y * ddx * ddy + dddx = -2 * y * dy * ddy + (1 - y^2) * dddy + */ + /* precompute '-2 * y' to prevent duplicated computation*/ + Tensor neg_2_out; + if (grad_out_forward_grad || grad_x_grad_forward_grad) { + neg_2_out = scale(out, -2.0); + } + /* precompute 'dy(prev) * ddy' to prevent duplicated computation*/ + Tensor grad_out_forward_mul_grad_out_new_grad; + if (out_grad || grad_x_grad_forward_grad) { + grad_out_forward_mul_grad_out_new_grad = + grad_out_forward * grad_out_new_grad.get(); } - } - if (grad_out_forward_grad) { - if (grad_out_new_grad) { + if (out_grad) { + auto out_grad_tmp = (scale(grad_x_grad_forward, -2.0) * + (grad_out_forward_mul_grad_out_new_grad + + out * grad_out_grad_grad.get())); + set_output(out_grad_tmp, out_grad); + } + if (grad_out_forward_grad) { auto grad_out_forward_grad_tmp = - -2 * out * grad_x_grad_forward * grad_out_new_grad.get(); + (neg_2_out * grad_x_grad_forward * grad_out_new_grad.get()); set_output(grad_out_forward_grad_tmp, grad_out_forward_grad); - } else { - auto grad_out_forward_grad_tmp = 0 * out; + } + if (grad_x_grad_forward_grad) { + auto grad_x_grad_forward_grad_tmp = + (scale(out * out, -1.0, 1.0) * grad_out_grad_grad.get() + + neg_2_out * grad_out_forward_mul_grad_out_new_grad); + set_output(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad); + } + + } else if (grad_out_new_grad) { + /* + dy = -2 * dy * ddx * ddy + ddy = -2 * y * ddx * ddy + dddx = -2 * y * dy * ddy + */ + // regard 'grad_out_grad_grad' is zero + /* precompute '-2 * y' to prevent duplicated computation*/ + Tensor neg_2_out; + if (grad_out_forward_grad || 
grad_x_grad_forward_grad) { + neg_2_out = scale(out, -2.0); + } + /* precompute 'dy(prev) * ddy' to prevent duplicated computation*/ + Tensor grad_out_forward_mul_grad_out_new_grad; + if (out_grad || grad_x_grad_forward_grad) { + grad_out_forward_mul_grad_out_new_grad = + grad_out_forward * grad_out_new_grad.get(); + } + + if (out_grad) { + auto out_grad_tmp = (scale(grad_x_grad_forward, -2.0) * + (grad_out_forward_mul_grad_out_new_grad)); + set_output(out_grad_tmp, out_grad); + } + if (grad_out_forward_grad) { + auto grad_out_forward_grad_tmp = + (neg_2_out * grad_x_grad_forward * grad_out_new_grad.get()); set_output(grad_out_forward_grad_tmp, grad_out_forward_grad); } - } + if (grad_x_grad_forward_grad) { + auto grad_x_grad_forward_grad_tmp = + (neg_2_out * grad_out_forward_mul_grad_out_new_grad); + set_output(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad); + } - if (grad_x_grad_forward_grad) { - if (grad_out_grad_grad) { - if (grad_out_new_grad) { - auto grad_x_grad_forward_grad_tmp = - (1 - (out * out)) * grad_out_grad_grad.get() - - 2 * out * grad_out_forward * grad_out_new_grad.get(); - set_output(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad); - } else { - auto grad_x_grad_forward_grad_tmp = - (1 - (out * out)) * grad_out_grad_grad.get(); - set_output(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad); - } - } else { - if (grad_out_new_grad) { - auto grad_x_grad_forward_grad_tmp = - -(2 * out * grad_out_forward * grad_out_new_grad.get()); - set_output(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad); - } else { - auto grad_x_grad_forward_grad_tmp = 0 * grad_x_grad_forward; - set_output(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad); - } + } else if (grad_out_grad_grad) { + /* + dy = -2 * y * ddx * dddy + ddy = 0 + dddx = (1 - y^2) * dddy + */ + // regard 'grad_out_new_grad' is zero + if (out_grad) { + auto out_grad_tmp = (scale(grad_x_grad_forward, -2.0) * + (out * grad_out_grad_grad.get())); + set_output(out_grad_tmp, out_grad); + } + if (grad_out_forward_grad) { + auto grad_out_forward_grad_tmp = + full(common::vectorize(out.dims()), 0, out.dtype()); + set_output(grad_out_forward_grad_tmp, grad_out_forward_grad); + } + if (grad_x_grad_forward_grad) { + auto grad_x_grad_forward_grad_tmp = + (scale(out * out, -1.0, 1.0) * grad_out_grad_grad.get()); + set_output(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad); + } + + } else { + /* + dy = 0 + ddy = 0 + dddx = 0 + */ + if (out_grad) { + auto out_grad_tmp = + full(common::vectorize(out.dims()), 0, out.dtype()); + set_output(out_grad_tmp, out_grad); + } + if (grad_out_forward_grad) { + auto grad_out_forward_grad_tmp = + full(common::vectorize(out.dims()), 0, out.dtype()); + set_output(grad_out_forward_grad_tmp, grad_out_forward_grad); + } + if (grad_x_grad_forward_grad) { + auto grad_x_grad_forward_grad_tmp = + full(common::vectorize(grad_x_grad_forward.dims()), + 0, + grad_x_grad_forward.dtype()); + set_output(grad_x_grad_forward_grad_tmp, grad_x_grad_forward_grad); } } } @@ -440,15 +604,17 @@ void silu_double_grad(const Tensor& x, const Tensor& grad_x_grad, Tensor* grad_x, Tensor* grad_out_grad) { - auto sigmoid = 1 / (1 + exp(-x)); - auto tmp1 = 1 - sigmoid; - auto tmp2 = 1 + tmp1 * x; + auto sigmoid = 1 / (scale(exp(scale(x, -1.0)), 1.0, 1.0)); + auto tmp1 = scale(sigmoid, -1.0, 1.0); + auto tmp2 = scale(tmp1 * x, 1.0, 1.0); + auto grad_x_grad_mul_sigmoid = grad_x_grad * sigmoid; if (grad_out_grad) { - auto ddout = grad_x_grad * sigmoid * tmp2; + auto ddout = grad_x_grad_mul_sigmoid * 
tmp2; set_output(ddout, grad_out_grad); } if (grad_x) { - auto dx = sigmoid * grad_x_grad * out_grad * (1 + (tmp2 - out)) * tmp1; + auto dx = grad_x_grad_mul_sigmoid * out_grad * + (scale(tmp2 - out, 1.0, 1.0)) * tmp1; set_output(dx, grad_x); } } @@ -533,16 +699,15 @@ void add_double_grad(const Tensor& y, Tensor* grad_out_grad) { if (grad_out_grad) { // ddout = ddx + ddy - if (!grad_x_grad && !grad_y_grad) { - Tensor ddout = - full(common::vectorize(grad_out.dims()), 0.0, y.dtype()); - set_output(ddout, grad_out_grad); - } else if (grad_x_grad && !grad_y_grad) { - set_output(grad_x_grad.get(), grad_out_grad); - } else if (grad_y_grad && !grad_x_grad) { - set_output(grad_y_grad.get(), grad_out_grad); - } else { + if (grad_x_grad && grad_y_grad) { set_output(grad_x_grad.get() + grad_y_grad.get(), grad_out_grad); + } else if (grad_x_grad) { + by_pass(grad_x_grad.get(), grad_out_grad); + } else if (grad_y_grad) { + by_pass(grad_y_grad.get(), grad_out_grad); + } else { + set_output(full(common::vectorize(grad_out.dims()), 0.0, y.dtype()), + grad_out_grad); } } } @@ -572,8 +737,6 @@ void add_triple_grad(const paddle::optional& grad_grad_x, } else { by_pass(grad_grad_out_grad, grad_grad_y_grad); } - } else { - grad_grad_y_grad = nullptr; } } if (grad_grad_x_grad) { @@ -594,8 +757,6 @@ void add_triple_grad(const paddle::optional& grad_grad_x, } else { by_pass(grad_grad_out_grad, grad_grad_x_grad); } - } else { - grad_grad_x_grad = nullptr; } } } @@ -612,11 +773,13 @@ void subtract_double_grad(const Tensor& y, if (grad_x_grad && grad_y_grad) { set_output(grad_x_grad.get() - grad_y_grad.get(), grad_out_grad); } else if (grad_x_grad) { - set_output(grad_x_grad.get(), grad_out_grad); + by_pass(grad_x_grad.get(), grad_out_grad); } else if (grad_y_grad) { - set_output(-grad_y_grad.get(), grad_out_grad); + by_pass(-grad_y_grad.get(), grad_out_grad); } else { - grad_out_grad = nullptr; + set_output( + full(common::vectorize(grad_out.dims()), 0, grad_out.dtype()), + grad_out_grad); } } } diff --git a/paddle/fluid/prim/api/manual_prim/utils/utils.h b/paddle/fluid/prim/api/manual_prim/utils/utils.h index 90a25f8bf1e1f..cbbe846671114 100644 --- a/paddle/fluid/prim/api/manual_prim/utils/utils.h +++ b/paddle/fluid/prim/api/manual_prim/utils/utils.h @@ -29,7 +29,7 @@ namespace prim { // We put some api like utils here template Tensor empty(const paddle::experimental::IntArray& shape, - phi::DataType dype, + phi::DataType dtype, const paddle::Place& place); template @@ -37,7 +37,7 @@ Tensor empty_like(const Tensor& x, phi::DataType dtype, const paddle::Place& place); -// copy tensor for output ptr, in static need use assigh op +// copy tensor for output ptr, in static need use assign op template void by_pass(const Tensor& x, Tensor* out); @@ -48,28 +48,31 @@ void set_output(const Tensor& x_tmp, Tensor* x); // These method don't need to be specified static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims, const phi::DDim& in_dims) { - std::vector result; int bat = dout_dims.size() - in_dims.size(); - for (int i = 0; i < bat; ++i) { - result.push_back(i); - } + std::vector result(bat); + std::iota(result.begin(), result.end(), 0); + for (int i = 0; i < in_dims.size(); ++i) { if (in_dims[i] == 1) { - result.push_back(i + bat); + if (dout_dims[i + bat] > 1) { + // no need to reduce when dout_dims[i + bat] == 1 though in_dims[i] == 1 + result.push_back(i + bat); + } } else { PADDLE_ENFORCE_EQ( in_dims[i], dout_dims[i + bat], platform::errors::InvalidArgument( "ReduceDims dimension mismatch. 
Operands could " - "not be broadcast together with the shape of dout = [%s] and " - "the shape of in_dims = [%s]. Received [%d] in X is not equal to " - "[%d] in Y at i:%d.", + "not be broadcast together with the shape of X = [%s] and " + "the shape of Y = [%s]. X.shape[%d](%d) is not equal to " + "Y.shape[%d](%d).", dout_dims, in_dims, + i + bat, dout_dims[i + bat], - in_dims[i], - i)); + i, + in_dims[i])); } } return common::make_ddim(result); @@ -77,6 +80,17 @@ static phi::DDim get_reduce_dims_from_out(const phi::DDim& dout_dims, static phi::DDim get_reduce_dims(const phi::DDim& x_dims, const phi::DDim& y_dims) { + /* + @brief Computing reduction dim(s) from z=f(x, y) to x with right-alignment + broadcast rule. + + * x_dims = [10, 1, 4, 1, 5] + * y_dims = [2, 1, 6, 1] <-- shaped are right-aligned for comparison + * <-- broadcast --> + * z_dims = [10, 2, 4, 6, 5] + * ==> reduce_dims_from_z_to_x = [1, 3] + * ==> reduce_dims_from_z_to_y = [0, 2, 4] + */ auto out_dims = paddle::operators::details::BroadcastTwoDims(x_dims, y_dims); return get_reduce_dims_from_out(out_dims, x_dims); } @@ -114,7 +128,7 @@ static std::vector unsafe_vector_cast(const std::vector& src) { return dst; } -// This fucction compute unsqueeze dims for reshape to replace unsqueeze. +// This function compute unsqueeze dims for reshape to replace unsqueeze. static std::vector get_unsqueeze_dims( const Tensor& origin, const std::vector& axis) { auto origin_dims = origin.shape(); diff --git a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h index 0dd5d6fd4115c..d471b5277e029 100644 --- a/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h +++ b/paddle/fluid/prim/utils/static/composite_grad_desc_maker.h @@ -72,7 +72,7 @@ class CompositeGradOpMakerBase { virtual ~CompositeGradOpMakerBase() = default; virtual std::vector> operator()() { - VLOG(3) << "Runing Composite Grad func for " << fwd_op_.Type() << "_grad "; + VLOG(3) << "Running Composite Grad func for " << fwd_op_.Type() << "_grad "; this->Apply(); std::vector> ops; // TODO(jiabin): Support multiple blocks later diff --git a/paddle/fluid/prim/utils/static/desc_tensor.h b/paddle/fluid/prim/utils/static/desc_tensor.h index 94150a76a3e3e..1adabc7b4e86d 100644 --- a/paddle/fluid/prim/utils/static/desc_tensor.h +++ b/paddle/fluid/prim/utils/static/desc_tensor.h @@ -54,7 +54,7 @@ class DescTensor : public phi::ExtendedTensor, // TODO(jiabin): override more operators here. private: - // VarDesc's lifetime is holded by block and it's program, so we just conceal + // VarDesc's lifetime is held by block and it's program, so we just conceal // its funcs instead of its life. framework::VarDesc* desc_ptr_; // TODO(jiabin): This is really ugly, but we have to hold a dims here so that diff --git a/paddle/fluid/primitive/base/decomp_trans.cc b/paddle/fluid/primitive/base/decomp_trans.cc index f46bcf31248a2..c71da029b4e37 100644 --- a/paddle/fluid/primitive/base/decomp_trans.cc +++ b/paddle/fluid/primitive/base/decomp_trans.cc @@ -13,6 +13,7 @@ // limitations under the License. 
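The get_reduce_dims_from_out change above tightens the broadcast-reduction rule: a size-1 input axis is only added to the reduce list when the matching output axis is actually larger than 1. The worked example in the new get_reduce_dims comment can be reproduced with a small Python sketch (illustrative only; broadcast_two_dims below is a stand-in for paddle::operators::details::BroadcastTwoDims, not a Paddle API):

def broadcast_two_dims(x_dims, y_dims):
    # right-aligned broadcast of two shapes
    rx, ry = x_dims[::-1], y_dims[::-1]
    out = []
    for i in range(max(len(rx), len(ry))):
        a = rx[i] if i < len(rx) else 1
        b = ry[i] if i < len(ry) else 1
        assert a == b or a == 1 or b == 1, "shapes do not broadcast"
        out.append(max(a, b))
    return out[::-1]

def reduce_dims_from_out(dout_dims, in_dims):
    bat = len(dout_dims) - len(in_dims)
    result = list(range(bat))                  # leading broadcast axes are always reduced
    for i, d in enumerate(in_dims):
        if d == 1 and dout_dims[i + bat] > 1:  # new guard: skip when both extents are 1
            result.append(i + bat)
    return result

x_dims, y_dims = [10, 1, 4, 1, 5], [2, 1, 6, 1]
z_dims = broadcast_two_dims(x_dims, y_dims)    # [10, 2, 4, 6, 5]
print(reduce_dims_from_out(z_dims, x_dims))    # [1, 3]
print(reduce_dims_from_out(z_dims, y_dims))    # [0, 2, 4]

These are exactly the reduce_dims_from_z_to_x and reduce_dims_from_z_to_y values quoted in the new comment block.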
#include "paddle/fluid/primitive/base/decomp_trans.h" +#include #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" @@ -25,6 +26,7 @@ COMMON_DECLARE_bool(prim_skip_dynamic); COMMON_DECLARE_bool(prim_check_ops); +COMMON_DECLARE_string(prim_forward_blacklist); using paddle::dialect::DenseTensorType; using paddle::dialect::SelectedRowsType; @@ -44,6 +46,26 @@ std::unordered_set decomp_op_contain_none = {"pd_op.squeeze", std::unordered_set dynamic_shape_blacklist = {"pd_op.squeeze", "pd_op.unsqueeze"}; +namespace { +std::set StringSplit(const std::string& str) { + std::istringstream iss(str); + std::set tokens; + std::string token; + + while (std::getline(iss, token, ';')) { + size_t startpos = token.find_first_not_of(" "); + size_t endpos = token.find_last_not_of(" "); + if ((startpos != std::string::npos) && (endpos != std::string::npos)) { + token = token.substr(startpos, endpos - startpos + 1); + } else if (startpos != std::string::npos) { + token = token.substr(startpos); + } + tokens.insert(token); + } + return tokens; +} +} // namespace + static bool has_dynamic_shape(const phi::DDim& dims) { std::vector vec = common::vectorize(dims); if (std::find(vec.begin(), vec.end(), -1) != vec.end()) { @@ -124,8 +146,8 @@ void DecompProgram::check_ops() { auto primitives_set = GetPrimitiveOpNames(); std::set undecomposed_set; for (const auto& element : decomposed_prog_ops_set_) { - auto iter = primitives_set.find(element); - if (iter == primitives_set.end()) { + if (primitives_set.find(element) == primitives_set.end() && + blacklist_.find(element) == blacklist_.end()) { undecomposed_set.insert(element); } } @@ -173,7 +195,8 @@ void DecompProgram::check_decomp_outputs( decomp_op_contain_none.find(op_name) != decomp_op_contain_none.end(); for (size_t i = 0; i < orig_outs.size(); i++) { if (skip_invalid_op_check && - paddle::dialect::IsEmptyValue(decomp_outs[i])) { + (paddle::dialect::IsEmptyValue(orig_outs[i]) || + paddle::dialect::IsEmptyValue(decomp_outs[i]))) { VLOG(4) << "[Prim] Decomp op skip check of " << i << "-index output of op " << op_name; } else { @@ -314,11 +337,11 @@ bool DecompProgram::enable_decomp_by_filter(const std::string& op_name) { flag = false; } } - if (blacklist_.size() > 0) { - if (blacklist_.find(op_name) != blacklist_.end()) { - flag = false; - } - } + auto from_flag_blacklist = StringSplit(FLAGS_prim_forward_blacklist); + if (from_flag_blacklist.size() > 0) + blacklist_.insert(from_flag_blacklist.begin(), from_flag_blacklist.end()); + if (blacklist_.size() > 0 && blacklist_.find(op_name) != blacklist_.end()) + flag = false; return flag; } diff --git a/paddle/fluid/primitive/base/primitive_ops.h b/paddle/fluid/primitive/base/primitive_ops.h index 29d93498723e3..aa52907f8f7fe 100644 --- a/paddle/fluid/primitive/base/primitive_ops.h +++ b/paddle/fluid/primitive/base/primitive_ops.h @@ -43,8 +43,10 @@ const std::set& GetPrimitiveOpNames() { "pd_op.sum", "pd_op.abs", "pd_op.assign", + "pd_op.assign_value", "pd_op.concat", "pd_op.elementwise_pow", + "pd_op.rsqrt", "pd_op.floor", "pd_op.gather", "pd_op.gather_nd", @@ -57,6 +59,8 @@ const std::set& GetPrimitiveOpNames() { "pd_op.min", "pd_op.maximum", "pd_op.minimum", + "pd_op.argmax", + "pd_op.argmin", "pd_op.prod", "pd_op.roll", "pd_op.scatter", @@ -99,11 +103,15 @@ const std::set& GetPrimitiveOpNames() { "pd_op.data", "builtin.shadow_output", /* skip some special ops */ + "pd_op.conv2d", + 
"pd_op.pad3d", + "pd_op.nearest_interp", "pd_op.squeeze", "pd_op.unsqueeze", "pd_op.select_input", "pd_op.top_p_sampling", "pd_op.tril", + "pd_op.triu", "cf.yield", "pd_op.increment_", }; diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index fb1579968423a..e4d0e50e60877 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -53,6 +53,7 @@ "embedding_grad", "full", "partial_send", + "push_dense", ] # prim op with one input and one output, with no attribute diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index b5191d62afec6..63cec678eb8ae 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -31,6 +31,13 @@ static Tensor get_slice(const Tensor& x, int64_t idx) { return slice(x, {0}, {idx}, {idx + 1}, {1}, {}); } +template +static Tensor get_slice_vec(const Tensor& x, + int64_t start_idx, + int64_t end_idx) { + return slice(x, {0}, {start_idx}, {end_idx}, {1}, {}); +} + template Tensor any_decomp(const Tensor& x, const IntArray& axis, bool keepdim) { auto org_dtype = x.dtype(); @@ -287,7 +294,11 @@ Tensor log_softmax_decomp(const Tensor& x, const int& axis) { x_tmp = cast(x, DataType::FLOAT32); } - auto res = log(softmax_decomp(x_tmp, axis)); + auto max_tmp = max(x_tmp, {axis}, true); + auto sub = x_tmp - max_tmp; + auto molecular = exp(sub); + auto res = sub - log(sum(molecular, {axis}, molecular.dtype(), true)); + if (need_cast) { return cast(res, org_dtype); } else { @@ -353,22 +364,10 @@ Tensor relu_decomp(const Tensor& x) { } template -Tensor rsqrt_decomp(const Tensor& x) { - auto org_dtype = x.dtype(); - Tensor x_cast = x; - - bool need_cast = is_half_dtype(org_dtype); - if (need_cast) { - x_cast = cast(x, DataType::FLOAT32); - } - - auto ans = - elementwise_pow(x_cast, full(empty_shape, -0.5, x_cast.dtype())); - if (need_cast) { - return cast(ans, org_dtype); - } else { - return ans; - } +Tensor relu6_decomp(const Tensor& x) { + auto tmp = maximum(x, full(empty_shape, 0.0, x.dtype())); + auto res = minimum(tmp, full(empty_shape, 6.0, x.dtype())); + return res; } template @@ -406,6 +405,62 @@ std::tuple layer_norm_decomp( const paddle::optional& bias, float epsilon, int begin_norm_axis) { + if (has_dynamic_shape(x.shape())) { + std::vector axis; + auto org_dtype = x.dtype(); + Tensor x_cast = x; + + bool need_cast = is_half_dtype(org_dtype); + + // cast dtype to float32 if dtype =float16 or bfloat16 + if (need_cast) { + x_cast = cast(x_cast, DataType::FLOAT32); + } + + auto x_dim = x.shape(); + for (size_t i = begin_norm_axis; i < x_dim.size(); i++) { + axis.push_back(static_cast(i)); + } + auto mean_ = mean_decomp(x_cast, axis, true); + auto difference = x_cast - mean_; + auto var_tmp1 = difference * difference; + auto variance = mean_decomp(var_tmp1, axis, true); + auto var_tmp3 = variance + full(empty_shape, epsilon, variance.dtype()); + auto rsqrt_var = rsqrt(var_tmp3); + auto out = difference * rsqrt_var; + + Tensor slice_shape_l = get_slice_vec(shape(x), 0, begin_norm_axis); + Tensor slice_shape_r = + get_slice_vec(shape(x), begin_norm_axis, x_dim.size()); + Tensor scale_cast; + if (scale) { + scale_cast = backend::reshape_with_tensor(scale.get(), slice_shape_r); + if (need_cast) { + scale_cast = cast(scale_cast, DataType::FLOAT32); + } + out = out * scale_cast; + } + Tensor bias_cast; + if (bias) { + bias_cast = backend::reshape_with_tensor(bias.get(), slice_shape_r); + if 
(need_cast) { + bias_cast = cast(bias_cast, DataType::FLOAT32); + } + out = out + bias_cast; + } + mean_ = backend::reshape_with_tensor(mean_, slice_shape_l); + variance = backend::reshape_with_tensor(variance, slice_shape_l); + + // same as LayerNormInferMeta + // x: float32 --> out: float32, mean: float32, variance: float32 + // x: float16 --> out: float16, mean: float32, variance: float32 + if (need_cast) { + out = cast(out, org_dtype); + } + + return std::make_tuple(out, mean_, variance); + } + std::vector axis; auto org_dtype = x.dtype(); Tensor x_cast = x; @@ -426,13 +481,9 @@ std::tuple layer_norm_decomp( auto var_tmp1 = difference * difference; auto variance = mean_decomp(var_tmp1, axis, true); auto var_tmp3 = variance + epsilon; - auto rsqrt_var = elementwise_pow( - var_tmp3, full(empty_shape, -0.5, var_tmp3.dtype())); + auto rsqrt_var = rsqrt(var_tmp3); auto out = difference * rsqrt_var; - auto scale_ptr = scale.get_ptr(); - auto bias_ptr = bias.get_ptr(); - std::vector slice_shape_l; std::vector slice_shape_r; for (int64_t i = 0; i < static_cast(x_dim.size()); i++) { @@ -443,24 +494,16 @@ std::tuple layer_norm_decomp( } } Tensor scale_cast; - if (scale_ptr) { - if (slice_shape_r != scale_ptr->shape()) { - scale_cast = reshape(*scale_ptr, slice_shape_r); - } else { - scale_cast = *scale_ptr; - } + if (scale) { + scale_cast = reshape(scale.get(), slice_shape_r); if (need_cast) { scale_cast = cast(scale_cast, DataType::FLOAT32); } out = out * scale_cast; } Tensor bias_cast; - if (bias_ptr) { - if (slice_shape_r != bias_ptr->shape()) { - bias_cast = reshape(*bias_ptr, slice_shape_r); - } else { - bias_cast = *bias_ptr; - } + if (bias) { + bias_cast = reshape(bias.get(), slice_shape_r); if (need_cast) { bias_cast = cast(bias_cast, DataType::FLOAT32); } @@ -559,8 +602,7 @@ Tensor sqrt_decomp(const Tensor& x) { x_cast = cast(x, DataType::FLOAT32); } - auto ans = - elementwise_pow(x_cast, full(empty_shape, 0.5, x_cast.dtype())); + auto ans = 1.0 / rsqrt(x_cast); if (need_cast) { return cast(ans, org_dtype); } else { @@ -667,34 +709,23 @@ std::tuple instance_norm_decomp( auto var_tmp1 = difference * difference; auto variance = mean_decomp(var_tmp1, axis, true); auto var_tmp3 = variance + epsilon; - auto rsqrt_var = - elementwise_pow(var_tmp3, full(empty_shape, 0.5, var_tmp3.dtype())); - auto out = difference / rsqrt_var; + auto rsqrt_var = rsqrt(var_tmp3); + auto out = difference * rsqrt_var; - auto scale_ptr = scale.get_ptr(); - auto bias_ptr = bias.get_ptr(); std::vector slice_shape(x_dim.size(), 1); slice_shape[1] = x_dim[1]; Tensor scale_cast; - if (scale_ptr) { - if (slice_shape != scale_ptr->shape()) { - scale_cast = reshape(*scale_ptr, slice_shape); - } else { - scale_cast = *scale_ptr; - } + if (scale) { + scale_cast = reshape(scale.get(), slice_shape); if (need_cast) { scale_cast = cast(scale_cast, DataType::FLOAT32); } out = out * scale_cast; } Tensor bias_cast; - if (bias_ptr) { - if (slice_shape != bias_ptr->shape()) { - bias_cast = reshape(*bias_ptr, slice_shape); - } else { - bias_cast = *bias_ptr; - } + if (bias) { + bias_cast = reshape(bias.get(), slice_shape); if (need_cast) { bias_cast = cast(bias_cast, DataType::FLOAT32); } @@ -703,7 +734,7 @@ std::tuple instance_norm_decomp( std::vector res_shape(1, -1); auto mean_out = reshape(mean_, res_shape); - auto variance_out = reshape(1 / rsqrt_var, res_shape); + auto variance_out = reshape(rsqrt_var, res_shape); Tensor res; if (need_cast) { @@ -729,31 +760,65 @@ std::tuple flatten_decomp(const Tensor& x, "end_axis must be 
greater than or equal to start_axis.")); } - std::vector tmp_shape(x_dim); - tmp_shape.insert(tmp_shape.begin(), 0); - auto xshape = full(tmp_shape, 0.0, DataType::FLOAT32); - if (x_dim.size() == 0) { - std::vector res_shape(1, 1); - return std::make_tuple(reshape(x, res_shape), xshape); - } - if (end_axis == start_axis) { - return std::make_tuple(reshape(x, x_dim), xshape); - } + if (has_dynamic_shape(x.shape())) { + auto x_shape = shape(x); + Tensor x_shape_tensor = full({1}, 0, x_shape.dtype()); + std::vector tmp_shape; + tmp_shape.push_back(x_shape_tensor); + for (size_t i = 0; i < x_dim.size(); i++) { + tmp_shape.push_back(get_slice(x_shape, i)); + } + x_shape_tensor = concat(tmp_shape); + x_shape_tensor = + backend::full_with_tensor(x_shape_tensor, 0.0, DataType::FLOAT32); + if (end_axis == start_axis) { + return std::make_tuple(backend::reshape(x, x_shape), x_shape_tensor); + } + std::vector out_shape; + + for (size_t i = 0; i < x_dim.size();) { + if (i == static_cast(start_axis)) { + Tensor flat = + slice(x_shape, {0}, {start_axis}, {end_axis + 1}, {1}, {}); + flat = prod(flat, {0}, false, false); + out_shape.push_back(reshape(flat, {1})); + i = end_axis + 1; + } else { + out_shape.push_back(get_slice(x_shape, i)); + i++; + } + } - int slice_numel = 1; - for (int i = start_axis; i <= end_axis; ++i) { - slice_numel *= x_dim[i]; - } - std::vector out_shape; - for (int i = 0; i < start_axis; ++i) { - out_shape.push_back(x_dim[i]); - } - out_shape.push_back(slice_numel); - for (size_t i = end_axis + 1; i < x_dim.size(); ++i) { - out_shape.push_back(x_dim[i]); - } + Tensor out_shape_tensor = concat(out_shape); + return std::make_tuple(backend::reshape(x, out_shape_tensor), + x_shape_tensor); + } else { + std::vector tmp_shape(x_dim); + tmp_shape.insert(tmp_shape.begin(), 0); + auto xshape = full(tmp_shape, 0.0, DataType::FLOAT32); + if (x_dim.size() == 0) { + std::vector res_shape(1, 1); + return std::make_tuple(reshape(x, res_shape), xshape); + } + if (end_axis == start_axis) { + return std::make_tuple(reshape(x, x_dim), xshape); + } - return std::make_tuple(reshape(x, out_shape), xshape); + int slice_numel = 1; + for (int i = start_axis; i <= end_axis; ++i) { + slice_numel *= x_dim[i]; + } + std::vector out_shape; + for (int i = 0; i < start_axis; ++i) { + out_shape.push_back(x_dim[i]); + } + out_shape.push_back(slice_numel); + for (size_t i = end_axis + 1; i < x_dim.size(); ++i) { + out_shape.push_back(x_dim[i]); + } + + return std::make_tuple(reshape(x, out_shape), xshape); + } } template @@ -774,10 +839,21 @@ std::tuple group_norm_decomp( const float epsilon, const int groups, const std::string& data_format) { - if (data_format != "NCHW") { - // TODO(chengyanfu): support NHWC data format - PADDLE_THROW(phi::errors::Unimplemented("Only support NCHW format.")); + std::vector c_axis; + if (data_format == "NCHW") { + c_axis = {1}; + } else if (data_format == "NHWC") { + c_axis = {1, 3}; + } else { + PADDLE_THROW( + phi::errors::Unimplemented("Only support NCHW and NHWC format.")); + } + size_t rank = x.shape().size(); + if (rank < 3 || rank > 5) { + PADDLE_THROW(phi::errors::Unimplemented( + "Only support NCHW and NHWC format in rank {3, 4, 5}.")); } + auto org_dtype = x.dtype(); Tensor x_cast = x; @@ -786,30 +862,62 @@ std::tuple group_norm_decomp( x_cast = cast(x, DataType::FLOAT32); } - auto x_dim = x.shape(); - std::vector one_axis(1, 1); - - std::vector x_shape{x_dim[0] * groups, -1}; - x_cast = reshape(x_cast, x_shape); - auto mean_ = mean_decomp(x_cast, IntArray(one_axis), 
true); - auto var_tmp_ = - mean_decomp(x_cast * x_cast, IntArray(one_axis), true) - mean_ * mean_; - auto var_ = - maximum(var_tmp_, full(var_tmp_.shape(), 0, var_tmp_.dtype())); - auto var_inv = 1 / sqrt_decomp(var_ + epsilon); - auto res = (x_cast - mean_) * var_inv; - auto out = reshape(res, x_dim); - - auto scale_ptr = scale.get_ptr(); - auto bias_ptr = bias.get_ptr(); - - std::vector slice_bias_shape{-1, 1, 1}; + Tensor x_dim_t; + Tensor out, mean_, var_; + if (has_dynamic_shape(x_cast.shape())) { + x_dim_t = shape(x_cast); + Tensor tar_shape; + if (data_format == "NCHW") { + tar_shape = get_slice(x_dim_t, 0) * groups; + Tensor dim_1 = full({1}, -1, x_dim_t.type()); + tar_shape = concat({tar_shape, dim_1}); + } else { + Tensor N_shape = get_slice(x_dim_t, 0); + Tensor dim_1 = full({1}, -1, x_dim_t.type()); + Tensor C_shape = get_slice(x_dim_t, rank - 1); + Tensor dim_g = full({1}, groups, x_dim_t.type()); + Tensor dim_c_div_g = cast(C_shape / dim_g, x_dim_t.type()); + tar_shape = concat({N_shape, dim_1, dim_g, dim_c_div_g}); + } + x_cast = backend::reshape(x_cast, tar_shape); + mean_ = mean_decomp(x_cast, c_axis, true); + Tensor var_tmp_ = + mean_decomp(x_cast * x_cast, c_axis, true) - mean_ * mean_; + var_ = maximum( + var_tmp_, + backend::full_with_tensor(shape(var_tmp_), 0, var_tmp_.dtype())); + Tensor var_inv = + rsqrt(var_ + full(empty_shape, epsilon, var_.dtype())); + Tensor res = (x_cast - mean_) * var_inv; + out = backend::reshape(res, x_dim_t); + } else { + auto x_dim = x_cast.shape(); + if (data_format == "NCHW") { + x_cast = reshape(x_cast, {x_dim[0] * groups, -1}); + } else { + int c_div_g = x_dim[rank - 1] / groups; + x_cast = reshape(x_cast, {x_dim[0], -1, groups, c_div_g}); + } + mean_ = mean_decomp(x_cast, c_axis, true); + auto var_tmp_ = + mean_decomp(x_cast * x_cast, c_axis, true) - mean_ * mean_; + var_ = maximum(var_tmp_, full(var_tmp_.shape(), 0, var_tmp_.dtype())); + auto var_inv = rsqrt(var_ + full(empty_shape, epsilon, var_.dtype())); + auto res = (x_cast - mean_) * var_inv; + out = reshape(res, x_dim); + } + + std::vector slice_bias_shape; + slice_bias_shape = {-1}; + for (size_t i = 0; i < rank - 2; i++) { + slice_bias_shape.push_back(1); + } Tensor scale_cast; - if (scale_ptr) { - if (slice_bias_shape != scale_ptr->shape()) { - scale_cast = reshape(*scale_ptr, slice_bias_shape); + if (scale) { + if (data_format == "NCHW") { + scale_cast = reshape(scale.get(), slice_bias_shape); } else { - scale_cast = *scale_ptr; + scale_cast = scale.get(); } if (need_cast) { scale_cast = cast(scale_cast, DataType::FLOAT32); @@ -817,22 +925,29 @@ std::tuple group_norm_decomp( out = out * scale_cast; } Tensor bias_cast; - if (bias_ptr) { - if (slice_bias_shape != bias_ptr->shape()) { - bias_cast = reshape(*bias_ptr, slice_bias_shape); + if (bias) { + if (data_format == "NCHW") { + bias_cast = reshape(bias.get(), slice_bias_shape); } else { - bias_cast = *bias_ptr; + bias_cast = bias.get(); } if (need_cast) { bias_cast = cast(bias_cast, DataType::FLOAT32); } out = out + bias_cast; } - - std::vector res_shape{x_dim[0], groups}; - auto mean_out = reshape(mean_, res_shape); - auto var_out = reshape(var_, res_shape); - + Tensor mean_out, var_out; + if (has_dynamic_shape(x_cast.shape())) { + Tensor x_shape = get_slice(x_dim_t, 0); + Tensor dim_1 = full({1}, groups, x_shape.type()); + x_shape = concat({x_shape, dim_1}); + mean_out = backend::reshape(mean_, x_shape); + var_out = backend::reshape(var_, x_shape); + } else { + std::vector res_shape{x.shape().at(0), groups}; + mean_out = 
reshape(mean_, res_shape); + var_out = reshape(var_, res_shape); + } if (need_cast) { out = cast(out, org_dtype); } @@ -1017,10 +1132,8 @@ template Tensor index_sample_decomp(const Tensor& x, const Tensor& index) { std::vector tmp_shape{-1, 1}; auto index_dim = get_slice(shape(index), 0); - auto start = - backend::full_with_tensor(shape(index_dim), 0, index_dim.dtype()); - auto step = - backend::full_with_tensor(shape(index_dim), 1, index_dim.dtype()); + auto start = full({1}, 0, index_dim.dtype()); + auto step = full({1}, 1, index_dim.dtype()); auto arange_tmp = reshape( backend::arange_with_tensor(start, index_dim, step, index.dtype()), tmp_shape); @@ -1038,6 +1151,26 @@ Tensor index_sample_decomp(const Tensor& x, const Tensor& index) { } } +template +Tensor elu_decomp(const Tensor& x, const float alpha) { + auto org_dtype = x.dtype(); + auto x_cast = x; + + bool need_cast = is_half_dtype(org_dtype); + if (need_cast) { + x_cast = cast(x, DataType::FLOAT32); + } + + const Tensor zero = full(x_cast.shape(), 0, x_cast.type()); + auto tmp_res = alpha * (exp(x_cast) - 1); + auto ans = where(x_cast > zero, x_cast, tmp_res); + if (need_cast) { + return cast(ans, org_dtype); + } else { + return ans; + } +} + } // namespace details } // namespace primitive diff --git a/paddle/fluid/primitive/primitive.yaml b/paddle/fluid/primitive/primitive.yaml index 23ec199fdf0f0..58c3ac09b782a 100644 --- a/paddle/fluid/primitive/primitive.yaml +++ b/paddle/fluid/primitive/primitive.yaml @@ -2,57 +2,121 @@ - subtract - multiply - divide -- less_equal -- less_than +- elementwise_pow +- rsqrt +- sin +- sinh +- asin +- asinh +- cos +- cosh +- acos +- acosh +- tan +- tanh +- atan +- atanh +- abs +- sign +- exp +- expm1 +- log +- log1p +- logit +- erf +- erfinv +- ceil +- floor +- frac +- round +- trunc - equal +- angle +- as_complex +- as_real +- complex +- real +- imag +- conj - not_equal - greater_equal - greater_than +- less_equal +- less_than - bitwise_and - bitwise_not - bitwise_or - bitwise_xor -- exp +- isinf +- isnan +- remainder - scale - matmul -- expand -- sum -- abs - assign -- concat -- elementwise_pow -- floor -- gather -- gather_nd -- log - max - min - maximum - minimum +- argmax +- argmin +- cummax +- cummin +- fmax +- fmin - prod - roll +- gather +- gather_nd - scatter +- scatter_nd - scatter_nd_add -- tile -- transpose +- put_along_axis +- take_along_axis - pad +- sum +- cumprod - cumsum -- put_along_axis -- equal -- greater_than -- less_equal -- sin -- cos +- einsum +- logsumexp +- logcumsumexp +- kron +- masked_select - where -- split +- concat +- repeat_interleave +- unbind +- expand +- shape - reshape -- erf -- tanh +- squeeze +- unsqueeze +- transpose +- tile - cast -- sign - slice -- uniform -- shape +- split +- as_strided +- flip +- roll - full_int_array -- squeeze -- unsqueeze +- empty +- linspace +- logspace +- digamma +- lgamma +- diagonal +- diag_embed +- topk +- kthvalue +- searchsorted +- tril_indices +- triu_indices +- argsort +- sort +- gaussian +- bernoulli +- dirichlet +- poisson +- randint +- uniform +- unique_consecutive diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index f67a74bf3f8ae..ecf95eb234972 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -151,7 +151,12 @@ set(PYBIND_SRCS auto_parallel_py.cc eval_frame_tools.cc cpython_internals.c - eval_frame.c) + eval_frame.c + op_callstack_utils.cc) + +#ifdef PADDLE_WITH_DISTRIBUTE +set(PYBIND_SRCS ${PYBIND_SRCS} dist_api.cc) +#endif if(NOT 
WITH_SHARED_IR) # Note: We want to compile pir source into paddle.so directly, because @@ -263,7 +268,7 @@ endif() if(WITH_PYTHON) # generate op pybind functions automatically for dygraph. - set(OP_FUNCTION_GENERETOR_DEPS + set(OP_FUNCTION_GENERATOR_DEPS pybind proto_desc executor @@ -272,23 +277,23 @@ if(WITH_PYTHON) engine imperative_profiler imperative_flag) - list(APPEND OP_FUNCTION_GENERETOR_DEPS ${GLOB_OP_LIB}) - list(APPEND OP_FUNCTION_GENERETOR_DEPS ${GLOB_OPERATOR_DEPS}) + list(APPEND OP_FUNCTION_GENERATOR_DEPS ${GLOB_OP_LIB}) + list(APPEND OP_FUNCTION_GENERATOR_DEPS ${GLOB_OPERATOR_DEPS}) if(WITH_NCCL OR WITH_RCCL) - list(APPEND OP_FUNCTION_GENERETOR_DEPS nccl_context) + list(APPEND OP_FUNCTION_GENERATOR_DEPS nccl_context) endif() if(WITH_XPU_BKCL) - list(APPEND OP_FUNCTION_GENERETOR_DEPS bkcl_context) + list(APPEND OP_FUNCTION_GENERATOR_DEPS bkcl_context) endif() if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) - list(APPEND OP_FUNCTION_GENERETOR_DEPS ${PYTHON_LIBRARIES}) + list(APPEND OP_FUNCTION_GENERATOR_DEPS ${PYTHON_LIBRARIES}) endif() if(WITH_CUSTOM_DEVICE) - set(OP_FUNCTION_GENERETOR_DEPS ${OP_FUNCTION_GENERETOR_DEPS} + set(OP_FUNCTION_GENERATOR_DEPS ${OP_FUNCTION_GENERATOR_DEPS} custom_device_common_op_registry) endif() @@ -303,7 +308,7 @@ if(WITH_PYTHON) if(NOT WIN32) add_executable(kernel_signature_generator kernel_signature_generator.cc) target_link_libraries(kernel_signature_generator - ${OP_FUNCTION_GENERETOR_DEPS}) + ${OP_FUNCTION_GENERATOR_DEPS}) endif() get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) @@ -435,7 +440,7 @@ if(WITH_PYTHON) else() # If there are no *.so in /usr/lib or LD_LIBRARY_PATH, # copy these *.so to current directory and append current directory to - # LD_LIBRARY_PATH. This is different with Windows platformm, which search + # LD_LIBRARY_PATH. This is different with Windows platform, which search # *.dll in current directory automatically. 
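The log_softmax_decomp rewrite in composite.h earlier in this patch replaces log(softmax(x)) with the fused form (x - max(x)) - log(sum(exp(x - max(x)))), so the log is never taken of a probability that has underflowed to zero. A NumPy sketch of the same arithmetic (illustrative only, not the Paddle kernel):

import numpy as np

def log_softmax_decomp(x, axis=-1):
    # subtract the per-axis max so exp() stays in range, then fold the log of the
    # normalizer back in instead of calling log() on the softmax output
    x_max = np.max(x, axis=axis, keepdims=True)
    shifted = x - x_max
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))

x = np.array([0.0, -1000.0])
with np.errstate(divide="ignore"):
    naive = np.log(np.exp(x - x.max()) / np.sum(np.exp(x - x.max())))
print(naive)                  # [  0. -inf]   exp(-1000) underflows, log(0) gives -inf
print(log_softmax_decomp(x))  # [  0. -1000.] the fused form stays finite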
if(WITH_ONNXRUNTIME) set(PADDLE2ONNX_PYBIND_OUT diff --git a/paddle/fluid/pybind/auto_parallel_py.cc b/paddle/fluid/pybind/auto_parallel_py.cc index 8a044b678d79b..87895d6b4df31 100644 --- a/paddle/fluid/pybind/auto_parallel_py.cc +++ b/paddle/fluid/pybind/auto_parallel_py.cc @@ -17,6 +17,8 @@ #include #include +#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" +#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" #include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/var_desc.h" @@ -24,24 +26,18 @@ #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/op_function_common.h" #include "paddle/fluid/pybind/pybind_variant_caster.h" +#include "paddle/phi/api/lib/data_transform.h" +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/common/reduce_type.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/distributed/auto_parallel/device_mesh.h" #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" #include "paddle/phi/core/distributed/auto_parallel/placement_types.h" #include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" -#include "paddle/utils/optional.h" -#include "paddle/utils/pybind.h" - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h" -#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h" -#include "paddle/phi/api/lib/data_transform.h" -#include "paddle/phi/backends/context_pool.h" -#include "paddle/phi/common/reduce_type.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/reshard/p_to_r_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.h" @@ -53,6 +49,8 @@ #include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h" #include "paddle/phi/core/distributed/auto_parallel/reshard/x_to_r_reshard_function.h" #include "paddle/phi/core/enforce.h" +#include "paddle/utils/optional.h" +#include "paddle/utils/pybind.h" #ifdef PADDLE_WITH_DISTRIBUTE #include "paddle/phi/infermeta/spmd_rules/rules.h" @@ -74,8 +72,6 @@ static bool PyCheckInteger(PyObject *obj) { using paddle::distributed::auto_parallel::DistTensorSpec; using paddle::distributed::auto_parallel::kDefault; using paddle::distributed::auto_parallel::OperatorDistAttr; -using paddle::distributed::auto_parallel::SPMDRuleBase; -using paddle::distributed::auto_parallel::SPMDRuleMap; using paddle::framework::BlockDesc; using paddle::framework::OpDesc; using paddle::framework::VarDesc; @@ -590,17 +586,6 @@ void BindAutoParallel(py::module *m) { }) .def("_clean_partial_status", &TensorDistAttr::clean_partial_status); - py::class_(*m, "SPMDRuleBase") - .def("infer_forward", &SPMDRuleBase::InferForward) - .def("infer_backward", - static_cast, - std::vector> (SPMDRuleBase::*)( - const std::vector &, - const std::vector &, - const paddle::framework::AttributeMap &)>( - &SPMDRuleBase::InferBackward)); - // 
.def("infer_backward", &SPMDRuleBase::InferBackward) [revert in future] - py::class_(*m, "SpmdRule") .def("infer_forward", &infer_forward) .def("infer_backward", &infer_backward); @@ -750,15 +735,7 @@ void BindAutoParallel(py::module *m) { "contains_spmd_rule", [](const std::string op_type) { return phi::distributed::SpmdRuleFactory::Instance().ContainsSpmdRule( - op_type) || - SPMDRuleMap::Instance().Has(op_type); // TODO(ljz): unify here - }, - py::return_value_policy::reference); - - m->def( - "get_spmd_rule", - [](const std::string op_type) { - return SPMDRuleMap::Instance().Get(op_type); + op_type); }, py::return_value_policy::reference); diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc index 391dbabb1a210..5e202a2b79d2e 100644 --- a/paddle/fluid/pybind/communication.cc +++ b/paddle/fluid/pybind/communication.cc @@ -58,6 +58,7 @@ void BindCommContextManager(py::module *m) { py::arg("size"), py::arg("hash_key") = "", py::arg("p2p_opt") = nullptr, + py::arg("nccl_comm_init_option") = 0, py::call_guard()) #endif #if defined(PADDLE_WITH_XPU_BKCL) diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc index 535edcfef8853..f342103a8aeb1 100644 --- a/paddle/fluid/pybind/control_flow_api.cc +++ b/paddle/fluid/pybind/control_flow_api.cc @@ -24,7 +24,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/fluid/pir/utils/general_functions.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/place.h" diff --git a/paddle/fluid/pybind/cpython_internals.c b/paddle/fluid/pybind/cpython_internals.c index 0e5329d6f1287..af7ede116e4b2 100644 --- a/paddle/fluid/pybind/cpython_internals.c +++ b/paddle/fluid/pybind/cpython_internals.c @@ -109,7 +109,7 @@ static void Internal_clear_thread_frame(PyThreadState *tstate, tstate->datastack_top); tstate->c_recursion_remaining--; assert(frame->frame_obj == NULL || frame->frame_obj->f_frame == frame); - Internal_PyFrame_Clear(frame); // see _PyFrame_ClearExceptCode + Internal_PyFrame_ClearExceptCode(frame); Py_DECREF(frame->f_code); tstate->c_recursion_remaining++; Internal_PyThreadState_PopFrame(tstate, frame); @@ -125,7 +125,7 @@ static void Internal_clear_gen_frame(PyThreadState *tstate, gen->gi_exc_state.previous_item = NULL; tstate->c_recursion_remaining--; assert(frame->frame_obj == NULL || frame->frame_obj->f_frame == frame); - Internal_PyFrame_Clear(frame); // see _PyFrame_ClearExceptCode + Internal_PyFrame_ClearExceptCode(frame); tstate->c_recursion_remaining++; frame->previous = NULL; } @@ -584,7 +584,11 @@ static void Internal_take_ownership(PyFrameObject *f, } // Call on 3.11 _PyFrame_Clear is called on 3.12+ _PyFrame_ClearExceptCode +#if PY_VERSION_HEX >= 0x030c0000 +void Internal_PyFrame_ClearExceptCode(_PyInterpreterFrame *frame) { +#else void Internal_PyFrame_Clear(_PyInterpreterFrame *frame) { +#endif /* It is the responsibility of the owning generator/coroutine * to have cleared the enclosing generator, if any. 
*/ assert(frame->owner != FRAME_OWNED_BY_GENERATOR || diff --git a/paddle/fluid/pybind/cpython_internals.h b/paddle/fluid/pybind/cpython_internals.h index 941279b88f870..fe8330312dc9e 100644 --- a/paddle/fluid/pybind/cpython_internals.h +++ b/paddle/fluid/pybind/cpython_internals.h @@ -43,6 +43,7 @@ void Internal_PyEvalFrameClearAndPop(PyThreadState *tstate, _PyInterpreterFrame *frame); _PyInterpreterFrame *Internal_PyThreadState_PushFrame(PyThreadState *tstate, size_t size); +void Internal_PyFrame_ClearExceptCode(_PyInterpreterFrame *frame); #endif #endif diff --git a/paddle/fluid/pybind/dist_api.cc b/paddle/fluid/pybind/dist_api.cc new file mode 100644 index 0000000000000..44feb061438e8 --- /dev/null +++ b/paddle/fluid/pybind/dist_api.cc @@ -0,0 +1,125 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "pybind11/stl.h" + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/fluid/pybind/dist_api.h" +#include "paddle/fluid/pybind/dist_static_op_function.h" +#include "paddle/phi/core/enforce.h" + +namespace py = pybind11; + +namespace pybind11 { +namespace detail { +template +struct type_caster> + : map_caster, + Key, + Value> {}; +} // namespace detail +} // namespace pybind11 + +using paddle::dialect::OperationDistAttribute; +using paddle::dialect::TensorDistAttribute; + +namespace paddle { +namespace pybind { + +void BindOperationDistAttribute(py::module *m) { + py::class_ dist_attr(*m, "OperationDistAttribute"); + dist_attr + .def("__str__", + [](OperationDistAttribute &self) { + std::ostringstream print_stream; + print_stream << self; + return print_stream.str(); + }) + .def_property_readonly("process_mesh", + [](OperationDistAttribute &self) { + return self.process_mesh_attr().process_mesh(); + }) + .def("num_operand_dist_attrs", + &OperationDistAttribute::num_operand_dist_attrs) + .def("operand_dist_attrs", &OperationDistAttribute::operand_dist_attrs) + .def("operand_dist_attr", &OperationDistAttribute::operand_dist_attr) + .def("num_result_dist_attrs", + &OperationDistAttribute::num_result_dist_attrs) + .def("result_dist_attrs", &OperationDistAttribute::result_dist_attrs) + .def("result_dist_attr", &OperationDistAttribute::result_dist_attr); +} + +void BindTensorDistAttribute(py::module *m) { + py::class_ dist_attr(*m, "TensorDistAttribute"); + dist_attr + .def("__str__", + [](TensorDistAttribute &self) { + std::ostringstream print_stream; + print_stream << self; + return print_stream.str(); + }) + .def("__eq__", + [](TensorDistAttribute &self, const TensorDistAttribute &other) { + return self == other; + }) + .def_property_readonly("process_mesh", + [](TensorDistAttribute &self) { + return self.process_mesh_attr().process_mesh(); + }) + .def_property_readonly( + "dims_mapping", + [](TensorDistAttribute &self) { return self.dims_mapping(); }) + .def_property_readonly( + "partial_status", + 
[](TensorDistAttribute &self) { return self.partial_status(); }) + .def_property_readonly("partial_dims", [](TensorDistAttribute &self) { + return self.partial_dims(); + }); +} + +void BindDistOpsAPI(pybind11::module *module) { + { + if (PyModule_AddFunctions(module->ptr(), DistOpsAPI) < 0) { + { + PADDLE_THROW( + phi::errors::Fatal("Add C++ DistOpsAPI to core.ops failed!")); + } + } + } +} + +void BindOpsFunction(py::module *m) { + m->def("reshard_v2", + [](const pir::Value &x, const TensorDistAttribute &dist_attr) { + return reshard(x, dist_attr); + }); +} + +void BindDistApi(pybind11::module *module) { + auto ir_module = module->def_submodule("pir"); + BindOperationDistAttribute(&ir_module); + BindTensorDistAttribute(&ir_module); + auto ops_modules = ir_module.def_submodule("ops"); + BindDistOpsAPI(&ops_modules); + BindOpsFunction(&ops_modules); +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/string/string_helper.h b/paddle/fluid/pybind/dist_api.h similarity index 72% rename from paddle/fluid/string/string_helper.h rename to paddle/fluid/pybind/dist_api.h index 08a715bfbc764..1dafe467207e5 100644 --- a/paddle/fluid/string/string_helper.h +++ b/paddle/fluid/pybind/dist_api.h @@ -1,4 +1,4 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,4 +14,10 @@ #pragma once -#include "paddle/utils/string/string_helper.h" +#include + +namespace paddle { +namespace pybind { +void BindDistApi(pybind11::module *m); +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/dist_static_op_function.h b/paddle/fluid/pybind/dist_static_op_function.h new file mode 100644 index 0000000000000..afd71b7521567 --- /dev/null +++ b/paddle/fluid/pybind/dist_static_op_function.h @@ -0,0 +1,96 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
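The TensorDistAttribute bound just above pairs a process_mesh with a dims_mapping: dims_mapping[i] names the mesh axis along which tensor axis i is sharded, and -1 leaves that axis replicated. A small Python sketch of the local shard shape implied by that convention (conceptual only, assumes even divisibility, and is not a Paddle API):

def local_shard_shape(global_shape, mesh_shape, dims_mapping):
    # dims_mapping[i] == -1 -> axis i is replicated on every rank
    # dims_mapping[i] == k  -> axis i is split across mesh axis k
    local = list(global_shape)
    for i, mesh_axis in enumerate(dims_mapping):
        if mesh_axis >= 0:
            local[i] //= mesh_shape[mesh_axis]
    return local

# a [8, 1024] tensor on a 2 x 4 mesh, sharded over mesh axis 1 on its last dim
print(local_shard_shape([8, 1024], [2, 4], [-1, 1]))   # [8, 256]

The shard_tensor and reshard wrappers defined below take exactly this (process_mesh, dims_mapping) pair when inserting the corresponding ops into a PIR program.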
+ +#pragma once + +#include "paddle/fluid/pir/dialect/distributed/ir/dist_api.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" +#include "paddle/fluid/pybind/eager_utils.h" +#include "paddle/fluid/pybind/exception.h" +#include "paddle/phi/core/enforce.h" + +namespace paddle { + +namespace pybind { + +static PyObject *static_api_shard_tensor(PyObject *self, + PyObject *args, + PyObject *kwargs) { + try { + VLOG(6) << "Add shard_tensor op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + + // Get Value from args + PyObject *input_obj = PyTuple_GET_ITEM(args, 0); + auto input = CastPyArg2Value(input_obj, "shard_tensor", 0); + + PyObject *process_mesh_obj = PyTuple_GET_ITEM(args, 1); + auto process_mesh = CastPyArg2ProcessMesh(process_mesh_obj, 1); + + PyObject *dims_mapping_obj = PyTuple_GET_ITEM(args, 2); + auto dims_mapping = CastPyArg2VectorOfInt64(dims_mapping_obj, 2); + + // Call ir static api + auto static_api_out = + paddle::dialect::shard_tensor(input, process_mesh, dims_mapping); + + return ToPyObject(static_api_out); + } catch (...) { + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + +static PyObject *static_api_reshard(PyObject *self, + PyObject *args, + PyObject *kwargs) { + try { + VLOG(6) << "Add reshard op into program"; + VLOG(8) << "args count: " << (PyTuple_Size(args) / 2); + + // Get Value from args + PyObject *input_obj = PyTuple_GET_ITEM(args, 0); + auto input = CastPyArg2Value(input_obj, "reshard", 0); + + PyObject *process_mesh_obj = PyTuple_GET_ITEM(args, 1); + auto process_mesh = CastPyArg2ProcessMesh(process_mesh_obj, 1); + + PyObject *dims_mapping_obj = PyTuple_GET_ITEM(args, 2); + auto dims_mapping = CastPyArg2VectorOfInt64(dims_mapping_obj, 2); + + // Call ir static api + auto static_api_out = + paddle::dialect::reshard(input, process_mesh, dims_mapping); + + return ToPyObject(static_api_out); + } catch (...) 
{ + ThrowExceptionToPython(std::current_exception()); + return nullptr; + } +} + +static PyMethodDef DistOpsAPI[] = { + {"shard_tensor", + (PyCFunction)(void (*)(void))static_api_shard_tensor, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for shard_tensor."}, + {"reshard", + (PyCFunction)(void (*)(void))static_api_reshard, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for reshard."}, + {nullptr, nullptr, 0, nullptr}}; + +} // namespace pybind + +} // namespace paddle diff --git a/paddle/fluid/pybind/distributed_py.cc b/paddle/fluid/pybind/distributed_py.cc index 4577171fd77bb..a3af17451dc54 100644 --- a/paddle/fluid/pybind/distributed_py.cc +++ b/paddle/fluid/pybind/distributed_py.cc @@ -1235,6 +1235,7 @@ void BindDistributed(py::module *m) { py::arg("world_size"), py::arg("group_id") = 0, py::arg("timeout") = 30 * 60 * 1000, + py::arg("nccl_comm_init_option") = 0, py::call_guard()) .def_static("group_start", distributed::ProcessGroupNCCL::GroupStart) .def_static("group_end", distributed::ProcessGroupNCCL::GroupEnd); @@ -1272,7 +1273,11 @@ void BindDistributed(py::module *m) { py::arg("world_size"), py::arg("group_id") = 0, py::return_value_policy::reference_internal, - py::call_guard()); + py::call_guard()) + .def("get_comm_name", + &distributed::ProcessGroupCustom::GetCommName, + py::arg("rank"), + py::call_guard()); #endif diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index 3cb3ccf964ec8..00b6ba994233f 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -442,7 +442,7 @@ Placements ParsePlacementsArgs( Placements placements; const std::string& placements_key = "placements"; - if (kw_order_map[placements_key] <= args_num) { + if (kw_order_map[placements_key] <= args_num) { // NOLINT placements = CastPyArg2VectorOfPlacement( PyTuple_GET_ITEM(args, kw_order_map[placements_key] - 1), kw_order_map[placements_key] - 1); diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 0a72208f36ccc..66ffa2ba23d12 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -567,12 +567,12 @@ PyObject* eager_api_run_custom_op(PyObject* self, VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add un-initialized tensor " "because the optional input is None"; - ctx.EmplaceBackInput(std::move(paddle::Tensor())); + ctx.EmplaceBackInput(paddle::Tensor()); continue; } if (paddle::framework::detail::IsDuplicableVar(input)) { std::vector tensors = - std::move(CastPyArg2VectorOfTensor(obj, i + 1)); // NOLINT + CastPyArg2VectorOfTensor(obj, i + 1); ctx.EmplaceBackInputs(std::move(tensors)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add vector size = " @@ -600,12 +600,12 @@ PyObject* eager_api_run_custom_op(PyObject* self, VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. Add un-initialized tensor " "because the optional input is None"; - ctx.EmplaceBackInput(std::move(paddle::Tensor())); + ctx.EmplaceBackInput(paddle::Tensor()); continue; } if (paddle::framework::detail::IsDuplicableVar(input)) { std::vector tensors = - std::move(CastPyArg2VectorOfTensor(obj, i + 1, mesh)); // NOLINT + CastPyArg2VectorOfTensor(obj, i + 1, mesh); ctx.EmplaceBackInputs(std::move(tensors)); VLOG(7) << "Custom operator add input " << input << " to CustomOpKernelContext. 
Add vector size = " @@ -644,7 +644,7 @@ PyObject* eager_api_run_custom_op(PyObject* self, } else if (attr_type_str == "std::string") { ctx.EmplaceBackAttr( CastPyArg2AttrString(obj, attr_start_idx + i)); // NOLINT - } else if (attr_type_str == "std::vector") { + } else if (attr_type_str == "std::vector") { // NOLINT ctx.EmplaceBackAttr(CastPyArg2VectorOfInt(obj, attr_start_idx + i)); } else if (attr_type_str == "std::vector") { ctx.EmplaceBackAttr(CastPyArg2VectorOfFloat(obj, attr_start_idx + i)); @@ -684,7 +684,7 @@ PyObject* eager_api_run_custom_op(PyObject* self, VLOG(7) << "Custom operator add output " << output << " to CustomOpKernelContext. Add un-initialized tensor " "because the inplace optional input is None"; - ctx.EmplaceBackOutput(std::move(paddle::Tensor())); + ctx.EmplaceBackOutput(paddle::Tensor()); continue; } /// inplace vector, initialized tensor. @@ -706,7 +706,7 @@ PyObject* eager_api_run_custom_op(PyObject* self, << " to CustomOpKernelContext. Add initialized Tensor because " "using general or inplace mechanism"; // general Tensor or inplace Tensor, initialized tensor. - ctx.EmplaceBackOutput(std::move(InitializedEmptyTensor())); + ctx.EmplaceBackOutput(InitializedEmptyTensor()); } VLOG(7) << "Run Kernel of Custom Op: " << op_type; diff --git a/paddle/fluid/pybind/eager_legacy_op_function_generator.cc b/paddle/fluid/pybind/eager_legacy_op_function_generator.cc index e7c9c62e01661..835680f38fa53 100644 --- a/paddle/fluid/pybind/eager_legacy_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_legacy_op_function_generator.cc @@ -29,7 +29,7 @@ #include "paddle/fluid/operators/custom_device_common_op_registry.h" #include "paddle/fluid/pybind/eager_generator.h" #include "paddle/fluid/pybind/pybind.h" -#include "paddle/fluid/string/string_helper.h" +#include "paddle/utils/string/string_helper.h" // phi #include "paddle/phi/kernels/declarations.h" @@ -212,7 +212,6 @@ std::string GenerateOpFunctionsBody( std::string outs_initializer_with_null = ""; std::string return_str = ""; - int outs_num = 0; for (auto& output : op_proto->outputs()) { auto& out_name = output.name(); @@ -287,10 +286,6 @@ std::string GenerateOpFunctionsBody( } outs_initializer += ","; } - - // return_str += paddle::string::Sprintf(return_template, out_name); - // return_str += ","; - outs_num += 1; } call_api_str += "attrs);"; if (outs_initializer.back() == ',') { diff --git a/paddle/fluid/pybind/eager_math_op_patch.cc b/paddle/fluid/pybind/eager_math_op_patch.cc index 21fd549cb0b2d..17b36e9237e78 100644 --- a/paddle/fluid/pybind/eager_math_op_patch.cc +++ b/paddle/fluid/pybind/eager_math_op_patch.cc @@ -818,10 +818,10 @@ static PyObject* tensor__rdiv__method(TensorObject* self, bool has_other_double = false; if (PyFloat_Check(other_obj) || PyCheckInteger(other_obj) || IsNumpyType(other_obj)) { - if (PyFloat_Check(other_obj)) { + if (PyFloat_Check(other_obj)) { // NOLINT other_double = CastPyArg2Double(other_obj, "__rdiv__", 0); has_other_double = true; - } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) { + } else if (PyCheckInteger(other_obj) || IsNumpyType(other_obj)) { // NOLINT other_double = CastPyArg2Double(other_obj, "__rdiv__", 0); has_other_double = true; } diff --git a/paddle/fluid/pybind/eager_method.cc b/paddle/fluid/pybind/eager_method.cc index 6fe07282a2223..d096119235b4c 100644 --- a/paddle/fluid/pybind/eager_method.cc +++ b/paddle/fluid/pybind/eager_method.cc @@ -603,7 +603,7 @@ static PyObject* tensor_method__copy_to(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } 
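tensor__rdiv__method above is the C++ implementation behind Python's reflected division: it runs when the Tensor sits on the right-hand side of '/', and both the float and integer branches funnel the left-hand scalar through a single double cast. A toy pure-Python illustration of the protocol it implements (not Paddle code):

class Toy:
    def __init__(self, value):
        self.value = value

    def __rtruediv__(self, other):
        # called for `scalar / Toy(...)`; `other` may be an int, a float, or a
        # numpy scalar, which is why both C++ branches cast it to double
        return Toy(float(other) / self.value)

print((2 / Toy(4.0)).value)   # 0.5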
-PyDoc_STRVAR(tensor_reconstruct_from___doc__, +PyDoc_STRVAR(tensor_reconstruct_from___doc__, // NOLINT R"DOC(reconstruct_from_($self, other/) -- @@ -786,7 +786,7 @@ Enables this Tensor to have their grad populated during backward(). It is a no-o >>> print(y.grad) Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=False, [1., 1., 1.]) -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_retain_grads(TensorObject* self, PyObject* args, @@ -1219,7 +1219,7 @@ static PyObject* tensor_method_detach_(TensorObject* self, Py_INCREF(reinterpret_cast(self)); return reinterpret_cast(self); EAGER_CATCH_AND_THROW_RETURN_NULL -} +} // NOLINT PyDoc_STRVAR(tensor_method_get_tensor__doc__, R"DOC(get_tensor($self, /) -- @@ -1243,7 +1243,7 @@ Returns the underline tensor in the origin Tensor. - layout: NCHW - dtype: float32 - data: [1] -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_get_underline_tensor(TensorObject* self, PyObject* args, @@ -1449,10 +1449,41 @@ static PyObject* tensor__getitem_from_offset(TensorObject* self, PyObject* kwargs) { EAGER_TRY phi::DenseTensor* ptr = nullptr; + phi::DenseTensor tensor_after_reshard; if (self->tensor.is_selected_rows()) { auto* selected_rows = static_cast(self->tensor.impl().get()); ptr = static_cast(selected_rows->mutable_value()); + } else if (self->tensor.is_dist_tensor()) { +#ifdef PADDLE_WITH_DISTRIBUTE + auto* dist_tensor = + static_cast(self->tensor.impl().get()); + PADDLE_ENFORCE( + dist_tensor->initialized(), + paddle::platform::errors::Fatal( + "The input dist tensor can't be uninitialized for we don't " + "know the correct mesh to be reshard.")); + const auto& placements = dist_tensor->placements(); + bool need_reshard = false; + for (const auto& placement : placements) { + if (!placement->is_replicated()) { + need_reshard = true; + break; + } + } + if (need_reshard) { + tensor_after_reshard = ReshardXToReplicated(dist_tensor); + ptr = &tensor_after_reshard; + } else { + ptr = dist_tensor->unsafe_mutable_value(); + } +#else + PADDLE_THROW(platform::errors::Unavailable( + "The `_getitem_from_offset` method of (Dist)Tensor is not supported " + "in the current PaddlePaddle, please recompile and install " + "PaddlePaddle " + "with the option of `WITH_DISTRIBUTE=ON`.")); +#endif } else { ptr = static_cast(self->tensor.impl().get()); } @@ -1797,10 +1828,11 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self, .is_contiguous()) ? paddle::Tensor( std::make_shared( - std::move(paddle::experimental::Trans2Contiguous( + paddle::experimental::Trans2Contiguous( *(std::dynamic_pointer_cast( - transback_sub_tensor.impl()))))), - transback_sub_tensor.mutable_autograd_meta()) + transback_sub_tensor.impl())))), + transback_sub_tensor.mutable_autograd_meta(), + transback_sub_tensor.name()) : transback_sub_tensor; grad_node = std::shared_ptr( @@ -1955,7 +1987,7 @@ This hook will be called every time the gradient of current Tensor has been full There are two differences with `_register_grad_hook`: 1. This backward hook will be executed after the gradient accumulation completed across batches, - but the hook registered by `_register_grad_hook` will be executed the gradient accumulation + but the hook registered by `_register_grad_hook` will be executed before the gradient accumulation completed in current batch. 2. 
This backward hook function should have the following signature: @@ -2197,7 +2229,7 @@ Returns the total number of non zero elements in input SparseCooTensor/SparseCsr >>> coo.nnz() 3 -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_get_non_zero_nums(TensorObject* self, PyObject* args, @@ -2247,7 +2279,7 @@ Returns the indices of non zero elements in input SparseCooTensor. [[0, 1, 2], [1, 2, 0]]) -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_get_non_zero_indices(TensorObject* self, PyObject* args, @@ -2290,7 +2322,7 @@ Returns the values of non zero elements in input SparseCooTensor. Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1., 2., 3.]) -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_get_non_zero_elements(TensorObject* self, PyObject* args, @@ -2344,7 +2376,7 @@ Returns the compressed row index of non zero elements in input SparseCsrTensor. Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True, [0, 2, 3, 5]) -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_get_non_zero_crows(TensorObject* self, PyObject* args, @@ -2388,7 +2420,7 @@ Returns the column index of non zero elements in input SparseCsrTensor. Tensor(shape=[5], dtype=int64, place=Place(gpu:0), stop_gradient=True, [1, 3, 2, 0, 1]) -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_get_non_zero_cols(TensorObject* self, PyObject* args, @@ -2422,7 +2454,7 @@ Whether the Tensor is a Dense Tensor. >>> x = paddle.to_tensor([1.0], stop_gradient=False) >>> print(x.is_dense()) True -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_is_dense(TensorObject* self, PyObject* args, @@ -2452,7 +2484,7 @@ Whether the Tensor is a Distributed Tensor. >>> x = paddle.to_tensor([1.0], stop_gradient=False) >>> print(x.is_dist()) False -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_is_dist(TensorObject* self, PyObject* args, @@ -2489,7 +2521,8 @@ When input is SparseCooTensor/SparseCsrTensor, will return True. When input is D >>> coo.is_sparse() True -)DOC"); +)DOC"); // NOLINT + static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args, PyObject* kwargs) { @@ -2526,7 +2559,7 @@ When input is SparseCooTensor, will return True. When input is DenseTensor/Spars >>> coo.is_sparse_coo() True -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_is_sparse_coo(TensorObject* self, PyObject* args, @@ -2564,7 +2597,7 @@ When input is SparseCsrTensor, will return True. When input is DenseTensor/Spars >>> csr.is_sparse_csr() True -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_is_sparse_csr(TensorObject* self, PyObject* args, @@ -2607,7 +2640,7 @@ When input is SparseCooTensor, will convert `COO` to `CSR` . When input is Dense cols=[1, 2, 0], values=[1., 2., 3.]) -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_to_sparse_csr(TensorObject* self, PyObject* args, @@ -2654,7 +2687,7 @@ Any two type Tensor among DenseTensor/SparseCooTensor/SparseCsrTensor are suppor >>> x.is_same_shape(z) False -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_is_same_shape(TensorObject* self, PyObject* args, @@ -2957,7 +2990,7 @@ Returns the address of the first element of current Tensor. >>> # doctest: +SKIP('return the address') 93220864 >>> # doctest: -SKIP -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_data_ptr(TensorObject* self, PyObject* args, @@ -3019,7 +3052,7 @@ Returns the strides of current Tensor. 
>>> y = x[1] >>> print(y.get_strides()) [] -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_method_strides(TensorObject* self, PyObject* args, @@ -3061,7 +3094,7 @@ If self tensor is already contiguous, this function returns the current Tensor. >>> y = y.contiguous() >>> print(y) Tensor(shape=[], dtype=int64, place=Place(cpu), stop_gradient=True, 2) -)DOC"); +)DOC"); // NOLINT static PyObject* tensor_contiguous(TensorObject* self, PyObject* args, @@ -3110,7 +3143,8 @@ Whether the Tensor is contiguous. >>> x = paddle.to_tensor([1, 2, 3]) >>> y = x[1] >>> print(y.is_contiguous()) -)DOC"); +)DOC"); // NOLINT + static PyObject* tensor_is_contiguous(TensorObject* self, PyObject* args, PyObject* kwargs) { diff --git a/paddle/fluid/pybind/eager_properties.cc b/paddle/fluid/pybind/eager_properties.cc index 2a2b94b715abd..ba857e9cdbfbd 100644 --- a/paddle/fluid/pybind/eager_properties.cc +++ b/paddle/fluid/pybind/eager_properties.cc @@ -35,12 +35,14 @@ limitations under the License. */ #pragma GCC diagnostic ignored "-Wwrite-strings" +COMMON_DECLARE_bool(enable_pir_api); + namespace paddle { namespace pybind { extern PyTypeObject* p_tensor_type; -PyDoc_STRVAR(tensor_name__doc__, +PyDoc_STRVAR(tensor_name__doc__, // NOLINT R"DOC(name Tensor's name. @@ -75,7 +77,7 @@ PyObject* tensor_properties_get_name(TensorObject* self, void* closure) { EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_type__doc__, +PyDoc_STRVAR(tensor_type__doc__, // NOLINT R"DOC(type Tensor's type. @@ -165,7 +167,7 @@ int tensor_properties_set_name(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NEG } -PyDoc_STRVAR(tensor_stop_gradient__doc__, +PyDoc_STRVAR(tensor_stop_gradient__doc__, // NOLINT R"DOC(stop_gradient Tensor's stop_gradient. @@ -195,7 +197,7 @@ PyObject* tensor_properties_get_stop_gradient(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_data__doc__, +PyDoc_STRVAR(tensor_data__doc__, // NOLINT R"DOC(data Tensor's self. @@ -258,7 +260,7 @@ int tensor_properties_set_data(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NEG } -PyDoc_STRVAR(tensor_grad__doc__, +PyDoc_STRVAR(tensor_grad__doc__, // NOLINT R"DOC(grad Tensor's grad Tensor. @@ -356,7 +358,7 @@ int tensor_properties_set_stop_gradient(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NEG } -PyDoc_STRVAR(tensor_persistable__doc__, +PyDoc_STRVAR(tensor_persistable__doc__, // NOLINT R"DOC(persistable Tensor's persistable. @@ -395,7 +397,7 @@ int tensor_properties_set_persistable(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NEG } -PyDoc_STRVAR(tensor_process_mesh__doc__, +PyDoc_STRVAR(tensor_process_mesh__doc__, // NOLINT R"DOC(process_mesh Get process_mesh property from shard tensor. @@ -441,7 +443,7 @@ PyObject* tensor_properties_get_process_mesh(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_placements__doc__, +PyDoc_STRVAR(tensor_placements__doc__, // NOLINT R"DOC(placements Get placements property from shard tensor. @@ -487,7 +489,7 @@ PyObject* tensor_properties_get_placements(TensorObject* self, void* closure) { EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_num_shard__doc__, +PyDoc_STRVAR(tensor_num_shard__doc__, // NOLINT R"DOC(num_shard Tensor's num_shard. @@ -553,7 +555,7 @@ PyObject* tensor_properties_get_local_shape(TensorObject* self, void* closure) { EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_shape__doc__, +PyDoc_STRVAR(tensor_shape__doc__, // NOLINT R"DOC(shape Tensor's shape. 
@@ -640,7 +642,7 @@ PyObject* tensor_properties_get_shape(TensorObject* self, void* closure) { EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_strides__doc__, +PyDoc_STRVAR(tensor_strides__doc__, // NOLINT R"DOC(strides Tensor's strides. @@ -679,7 +681,7 @@ PyObject* tensor_properties_get_strides(TensorObject* self, void* closure) { EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_offset__doc__, +PyDoc_STRVAR(tensor_offset__doc__, // NOLINT R"DOC(offset The address of the first element relative to the offset of the video memory. @@ -726,7 +728,7 @@ PyObject* tensor_properties_get_offset(TensorObject* self, void* closure) { EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_layout__doc__, +PyDoc_STRVAR(tensor_layout__doc__, // NOLINT R"DOC(layout Tensor's memory layout. @@ -761,7 +763,7 @@ PyObject* tensor_properties_get_layout(TensorObject* self, void* closure) { EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_place__doc__, +PyDoc_STRVAR(tensor_place__doc__, // NOLINT R"DOC(place The device Tensor's memory locate. @@ -828,7 +830,7 @@ PyObject* tensor_properties_get_placements_str(TensorObject* self, EAGER_CATCH_AND_THROW_RETURN_NULL } -PyDoc_STRVAR(tensor_dtype__doc__, +PyDoc_STRVAR(tensor_dtype__doc__, // NOLINT R"DOC(dtype Tensor's data type. @@ -847,25 +849,47 @@ Tensor's data type. )DOC"); PyObject* tensor_properties_get_dtype(TensorObject* self, void* closure) { EAGER_TRY - if (!self->tensor.defined()) { - // be same to old dygraph - return ToPyObject(framework::proto::VarType::FP32); - } - if (egr::IsVariableCompatTensor(self->tensor)) { - auto* var_tensor = static_cast( - self->tensor.impl().get()); - if (var_tensor->IsType()) { - return ToPyObject(framework::proto::VarType::RAW); - } else if (var_tensor->IsType()) { - return ToPyObject(framework::proto::VarType::STRING); + if (FLAGS_enable_pir_api) { + if (!self->tensor.defined()) { + // be same to old dygraph + return ToPyObject(phi::DataType::FLOAT32); + } + if (egr::IsVariableCompatTensor(self->tensor)) { + auto* var_tensor = static_cast( + self->tensor.impl().get()); + if (var_tensor->IsType()) { + return ToPyObject(phi::DataType::UNDEFINED); + } else if (var_tensor->IsType()) { + return ToPyObject(phi::DataType::PSTRING); + } else { + PADDLE_THROW(paddle::platform::errors::Unavailable( + "VariableCompatTensor only support get shape from Vocab or " + "Strings.")); + } } else { - PADDLE_THROW(paddle::platform::errors::Unavailable( - "VariableCompatTensor only support get shape from Vocab or " - "Strings.")); + return ToPyObject(self->tensor.type()); } } else { - return ToPyObject( - paddle::framework::TransToProtoVarType(self->tensor.type())); + if (!self->tensor.defined()) { + // be same to old dygraph + return ToPyObject(framework::proto::VarType::FP32); + } + if (egr::IsVariableCompatTensor(self->tensor)) { + auto* var_tensor = static_cast( + self->tensor.impl().get()); + if (var_tensor->IsType()) { + return ToPyObject(framework::proto::VarType::RAW); + } else if (var_tensor->IsType()) { + return ToPyObject(framework::proto::VarType::STRING); + } else { + PADDLE_THROW(paddle::platform::errors::Unavailable( + "VariableCompatTensor only support get shape from Vocab or " + "Strings.")); + } + } else { + return ToPyObject( + paddle::framework::TransToProtoVarType(self->tensor.type())); + } } EAGER_CATCH_AND_THROW_RETURN_NULL } diff --git a/paddle/fluid/pybind/eager_py_layer.cc b/paddle/fluid/pybind/eager_py_layer.cc index daaac0c20e780..fb4235f619e99 100644 --- 
a/paddle/fluid/pybind/eager_py_layer.cc +++ b/paddle/fluid/pybind/eager_py_layer.cc @@ -478,9 +478,11 @@ PyObject* pylayer_method_apply(PyObject* cls, for (size_t i = 0; i < inputs_autograd_meta.size(); i++) { if (ctx->forward_input_tensor_is_duplicable[i]) { + std::vector tmp; for (auto t : inputs_tensor[i]) { - grad_node->SetGradOutMeta(*t, i); + tmp.push_back(t); } + grad_node->SetGradOutMeta(tmp, i); } else { grad_node->SetGradOutMeta(*inputs_tensor[i][0], i); } @@ -490,9 +492,7 @@ PyObject* pylayer_method_apply(PyObject* cls, if (ctx->forward_output_tensor_is_duplicable[i]) { egr::EagerUtils::SetOutRankWithSlot(&outputs_autograd_meta[i], i); egr::EagerUtils::SetHistory(&outputs_autograd_meta[i], grad_node); - for (auto t : outputs_tensor[i]) { - grad_node->SetGradInMeta(*t, i); - } + grad_node->SetGradInMeta(outputs_tensor[i], i); } else { egr::EagerUtils::SetOutRankWithSlot(outputs_autograd_meta[i][0], i); egr::EagerUtils::SetHistory(outputs_autograd_meta[i][0], grad_node); diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index d613c008b4958..aba7c99662bbe 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -518,7 +518,7 @@ std::vector CastPyArg2VectorOfInt64(PyObject* obj, size_t arg_pos) { } else if (obj == Py_None) { return {}; } else if (PyObject_CheckLongOrConvertToLong(&obj)) { - return {static_cast(PyLong_AsLong(obj))}; + return {static_cast(PyLong_AsLong(obj))}; // NOLINT } else { PADDLE_THROW(platform::errors::InvalidType( "argument (position %d) must be " @@ -566,7 +566,7 @@ std::vector CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos) { } else if (obj == Py_None) { return {}; } else if (PyObject_CheckLongOrConvertToLong(&obj)) { - return {PyLong_AsSize_t(obj)}; + return {PyLong_AsSize_t(obj)}; // NOLINT } else { PADDLE_THROW(platform::errors::InvalidType( "argument (position %d) must be " @@ -614,7 +614,7 @@ std::vector CastPyArg2VectorOfFloat(PyObject* obj, size_t arg_pos) { } else if (obj == Py_None) { return {}; } else if (PyObject_CheckFloatOrConvertToFloat(&obj)) { - return {static_cast(PyFloat_AsDouble(obj))}; + return {static_cast(PyFloat_AsDouble(obj))}; // NOLINT } else { PADDLE_THROW(platform::errors::InvalidType( "argument (position %d) must be " @@ -647,7 +647,7 @@ std::vector> CastPyArg2VectorOfVectorOfSize_t( platform::Place CastPyArg2Place(PyObject* obj, ssize_t arg_pos) { platform::Place place; - if (PyObject_TypeCheck(obj, g_place_pytype)) { + if (PyObject_TypeCheck(obj, g_place_pytype)) { // NOLINT place = ::pybind11::handle(obj).cast(); } else if (PyObject_TypeCheck(obj, g_cudaplace_pytype)) { place = ::pybind11::handle(obj).cast(); @@ -761,7 +761,8 @@ std::vector CastPyArg2VectorOfTensorBase(PyObject* obj, i)); } } - } else if (PyObject_TypeCheck(obj, g_framework_lodtensorarray_pytype)) { + } else if (PyObject_TypeCheck(obj, + g_framework_lodtensorarray_pytype)) { // NOLINT for (auto& tensor : (::pybind11::handle(obj).cast())) { result.emplace_back(tensor); @@ -788,7 +789,7 @@ using phi::distributed::Shard; Placements CastPyArg2VectorOfPlacement(PyObject* obj, ssize_t arg_pos) { Placements result; auto check_and_emplace = [&](PyObject* item, ssize_t i) { - if (PyObject_TypeCheck(item, g_placement_shard_pytype)) { + if (PyObject_TypeCheck(item, g_placement_shard_pytype)) { // NOLINT result.emplace_back( std::make_shared(::pybind11::handle(item).cast())); } else if (PyObject_TypeCheck(item, g_placement_replicated_pytype)) { @@ -1076,6 +1077,12 @@ PyObject* ToPyObject(const 
phi::DenseTensor* value) { return obj.ptr(); } +PyObject* ToPyObject(const phi::DataType& dtype) { + auto obj = ::pybind11::cast(dtype); + obj.inc_ref(); + return obj.ptr(); +} + PyObject* ToPyObject(const pir::Value& value) { auto obj = ::pybind11::cast(value); obj.inc_ref(); @@ -2409,9 +2416,11 @@ paddle::DataType CastPyArg2DataType(PyObject* obj, if (obj == Py_None) { return phi::DataType::UNDEFINED; } - - framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos); - return framework::TransToPhiDataType(type); + if (PyObject_TypeCheck(obj, g_vartype_pytype)) { + framework::proto::VarType::Type type = CastPyArg2ProtoType(obj, arg_pos); + return framework::TransToPhiDataType(type); + } + return CastPyArg2DataTypeDirectly(obj, op_type, arg_pos); } paddle::Tensor PyTensorHook::operator()(const paddle::Tensor& var) { diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 2511ddb57dbb5..e56741aa90776 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -148,6 +148,7 @@ PyObject* ToPyObject(const phi::distributed::Placements& value); PyObject* ToPyObject(const phi::SelectedRows* value); PyObject* ToPyObject(const paddle::framework::proto::VarType::Type& dtype); PyObject* ToPyObject(const paddle::framework::proto::VarType& type); +PyObject* ToPyObject(const phi::DataType& type); PyObject* ToPyObject(const void* value); PyObject* ToPyObject(const std::unordered_map& value); PyObject* ToPyObject( diff --git a/paddle/fluid/pybind/eval_frame.c b/paddle/fluid/pybind/eval_frame.c index 3e5b50211cdec..aa5a4c0022fcc 100644 --- a/paddle/fluid/pybind/eval_frame.c +++ b/paddle/fluid/pybind/eval_frame.c @@ -366,6 +366,9 @@ static PyObject *_custom_eval_frame(PyThreadState *tstate, PyObject *result = PyObject_CallObject(callback, args); Py_DECREF(args); if (result == NULL) { +#if PY_VERSION_HEX >= 0x030C0000 + Internal_PyEvalFrameClearAndPop(tstate, frame); +#endif return NULL; } code = PyObject_GetAttrString(result, "code"); diff --git a/paddle/fluid/pybind/eval_frame_tools.cc b/paddle/fluid/pybind/eval_frame_tools.cc index da78ce66373e8..f0209f90610ee 100644 --- a/paddle/fluid/pybind/eval_frame_tools.cc +++ b/paddle/fluid/pybind/eval_frame_tools.cc @@ -34,12 +34,12 @@ class TreeNode { private: int is_prefix; - TreeNode* children[256]; + TreeNode* children[256]; // NOLINT }; void TreeNode::clear() { - for (int i = 0; i < 256; i++) { - if (children[i] != nullptr) delete children[i]; + for (auto& i : children) { + if (i != nullptr) delete i; } } @@ -200,8 +200,8 @@ void CodeStatus::add_with_graph_code(PyCodeObject* code) { } void CodeStatus::clear() { - for (auto iter = code_map.begin(); iter != code_map.end(); iter++) { - delete iter->second; + for (auto& iter : code_map) { + delete iter.second; } code_map.clear(); } diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index c540fe0687d88..b70efdbabbebc 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -651,10 +651,6 @@ void BindImperative(py::module *m_ptr) { *(imperative::AmpOperators::Instance().GetMutableAllowOps()), *(imperative::AmpOperators::Instance().GetMutableBlockOps())); }); - py::class_(m, "ProgramDescTracer", "") - .def("create_program_desc", - &imperative::jit::ProgramDescTracer::CreateProgramDesc) - .def("reset", &imperative::jit::ProgramDescTracer::Reset); py::enum_(m, "AmpLevel", py::arithmetic()) .value("O0", paddle::imperative::AmpLevel::O0) @@ -679,9 +675,6 @@ void 
BindImperative(py::module *m_ptr) { py::class_>( m, "Tracer", R"DOC()DOC") .def(py::init([]() { return std::make_unique(); })) - .def_property("_enable_program_desc_tracing", - &imperative::Tracer::IsProgramDescTracingEnabled, - &imperative::Tracer::SetEnableProgramDescTracing) .def_property("_use_promote", &imperative::Tracer::GetUsePromote, &imperative::Tracer::SetUsePromote) @@ -745,9 +738,6 @@ void BindImperative(py::module *m_ptr) { "but got Unknown Type!")); } }) - .def("_get_program_desc_tracer", - &imperative::Tracer::GetProgramDescTracer, - py::return_value_policy::reference) .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName, py::arg("key") = "dygraph_tmp") @@ -1357,8 +1347,9 @@ void BindImperative(py::module *m_ptr) { auto *index_data = index_tensor.data(); auto *buffer_data = buffer_tensor->mutable_data(buffer_tensor->place()); - const int &slice_size = src_tensor.numel() / src_tensor.dims()[0]; - const int ©_bytes = slice_size * sizeof(float); + const int &slice_size = + static_cast(src_tensor.numel()) / src_tensor.dims()[0]; + const int ©_bytes = static_cast(slice_size) * sizeof(float); int64_t c = 0; for (int64_t i = 0; i < index_tensor.numel(); i++) { std::memcpy(buffer_data + c * slice_size, diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc index 268806509031e..457bc649f98d1 100644 --- a/paddle/fluid/pybind/inference_api.cc +++ b/paddle/fluid/pybind/inference_api.cc @@ -803,7 +803,7 @@ void BindAnalysisConfig(py::module *m) { &AnalysisConfig::EnableXpu, py::arg("l3_size") = 16 * 1024 * 1024, py::arg("l3_locked") = false, - py::arg("conv_autotune") = true, + py::arg("conv_autotune") = false, py::arg("conv_autotune_file") = "", py::arg("transformer_encoder_precision") = "int16", py::arg("transformer_encoder_adaptive_seqlen") = false, @@ -869,6 +869,8 @@ void BindAnalysisConfig(py::module *m) { .def("enable_new_executor", &AnalysisConfig::EnableNewExecutor, py::arg("x") = true) + .def("enable_new_ir", &AnalysisConfig::EnableNewIR, py::arg("x") = true) + .def("new_ir_enabled", &AnalysisConfig::new_ir_enabled) .def("enable_profile", &AnalysisConfig::EnableProfile) .def("disable_glog_info", &AnalysisConfig::DisableGlogInfo) .def("glog_info_disabled", &AnalysisConfig::glog_info_disabled) @@ -926,6 +928,7 @@ void BindAnalysisConfig(py::module *m) { .def("enable_tuned_tensorrt_dynamic_shape", &AnalysisConfig::EnableTunedTensorRtDynamicShape, py::arg("shape_range_info_path") = "", + py::arg("allow_build_at_runtime") = true) .def("tuned_tensorrt_dynamic_shape", &AnalysisConfig::tuned_tensorrt_dynamic_shape) @@ -934,6 +937,10 @@ void BindAnalysisConfig(py::module *m) { .def("exp_disable_tensorrt_ops", &AnalysisConfig::Exp_DisableTensorRtOPs) .def("exp_disable_tensorrt_subgraph", &AnalysisConfig::Exp_DisableTensorRtSubgraph) + .def("exp_specify_tensorrt_subgraph_precision", + &AnalysisConfig::Exp_SpecifyTensorRTSubgraphPrecision) + .def("exp_disable_tensorrt_dynamic_shape_ops", + &AnalysisConfig::Exp_DisableTensorRTDynamicShapeOPs) .def("enable_tensorrt_dla", &AnalysisConfig::EnableTensorRtDLA, py::arg("dla_core") = 0) @@ -974,7 +981,8 @@ void BindAnalysisConfig(py::module *m) { .def("lite_engine_enabled", &AnalysisConfig::lite_engine_enabled) .def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug, - py::arg("x") = true) + py::arg("x") = true, + py::arg("passes") = std::vector()) .def("enable_mkldnn", &AnalysisConfig::EnableMKLDNN) .def("disable_mkldnn", &AnalysisConfig::DisableMKLDNN) .def("mkldnn_enabled", 
&AnalysisConfig::mkldnn_enabled) @@ -1029,6 +1037,13 @@ void BindAnalysisConfig(py::module *m) { return dynamic_cast(self.pass_builder()); }, py::return_value_policy::reference) + .def("enable_custom_passes", + &AnalysisConfig::EnableCustomPasses, + py::arg("passes") = std::vector(), + py::arg("custom_pass_only") = false) + .def("set_optimization_level", + &AnalysisConfig::SetOptimizationLevel, + py::arg("opt_level") = 2) .def("nnadapter", &AnalysisConfig::NNAdapter) .def("set_dist_config", &AnalysisConfig::SetDistConfig) .def("dist_config", &AnalysisConfig::dist_config); @@ -1210,8 +1225,8 @@ void BindPaddleInferPredictor(py::module *m) { .def("try_shrink_memory", &paddle_infer::Predictor::TryShrinkMemory) .def("clear_intermediate_tensor", &paddle_infer::Predictor::ClearIntermediateTensor) - .def("register_output_hook", - &paddle_infer::Predictor::RegisterOutputHook); + .def("register_output_hook", &paddle_infer::Predictor::RegisterOutputHook) + .def("register_input_hook", &paddle_infer::Predictor::RegisterInputHook); } void BindZeroCopyTensor(py::module *m) { diff --git a/paddle/fluid/pybind/manual_static_op_function.h b/paddle/fluid/pybind/manual_static_op_function.h index ced41e6905e5c..7767c4a4569b3 100644 --- a/paddle/fluid/pybind/manual_static_op_function.h +++ b/paddle/fluid/pybind/manual_static_op_function.h @@ -24,6 +24,7 @@ #include "paddle/fluid/pir/dialect/operator/utils/utils.h" #include "paddle/fluid/pybind/eager_utils.h" #include "paddle/fluid/pybind/exception.h" +#include "paddle/fluid/pybind/op_callstack_utils.h" #include "paddle/fluid/pybind/op_function_common.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/enforce.h" @@ -43,8 +44,10 @@ static PyObject *static_api_parameter(PyObject *self, PyObject *name_obj = PyTuple_GET_ITEM(args, 0); std::string name = CastPyArg2String(name_obj, "name", 0); // Call ir static api + CallStackRecorder callstack_recoder("parameter"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::parameter(name); - + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) { ThrowExceptionToPython(std::current_exception()); @@ -67,8 +70,10 @@ static PyObject *static_api_set_parameter(PyObject *self, PyObject *name_obj = PyTuple_GET_ITEM(args, 1); std::string name = CastPyArg2String(name_obj, "name", 1); // Call ir static api + CallStackRecorder callstack_recoder("set_parameter"); + callstack_recoder.Record(); paddle::dialect::set_parameter(parameter, name); - + callstack_recoder.AttachToOps(); Py_RETURN_NONE; } catch (...) { ThrowExceptionToPython(std::current_exception()); @@ -91,8 +96,10 @@ static PyObject *static_api_set_persistable_value(PyObject *self, PyObject *name_obj = PyTuple_GET_ITEM(args, 1); std::string name = CastPyArg2String(name_obj, "name", 1); // Call ir static api + CallStackRecorder callstack_recoder("shadow_output"); + callstack_recoder.Record(); paddle::dialect::shadow_output(persist_value, name); - + callstack_recoder.AttachToOps(); Py_RETURN_NONE; } catch (...) 
{ ThrowExceptionToPython(std::current_exception()); @@ -119,7 +126,10 @@ PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) { !PyObject_CheckIRValue(value_obj)) { std::vector shape = CastPyArg2Longs(shape_obj, "full", 0); float value = CastPyArg2Float(value_obj, "full", 1); + CallStackRecorder callstack_recoder("full"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::full(shape, value, dtype, place); + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } else { pir::Value shape, value; @@ -146,8 +156,12 @@ PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) { phi::CPUPlace()); } + CallStackRecorder callstack_recoder("full_with_tensor"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::full_with_tensor(shape, value, dtype); + callstack_recoder.AttachToOps(); + return ToPyObject(static_api_out); } } catch (...) { @@ -169,7 +183,10 @@ static PyObject *static_api_create_array(PyObject *self, CastPyArg2DataTypeDirectly(dtype_obj, "create_array", 0); // Call ir static api + CallStackRecorder callstack_recoder("create_array"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::create_array(dtype); + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) { @@ -194,8 +211,10 @@ static PyObject *static_api_create_array_like(PyObject *self, float value = CastPyArg2Float(value_obj, "create_array_like", 1); // Call ir static api + CallStackRecorder callstack_recoder("create_array_like"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::create_array_like(input, value); - + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) { ThrowExceptionToPython(std::current_exception()); @@ -215,7 +234,10 @@ static PyObject *static_api_array_length(PyObject *self, auto x = CastPyArg2Value(x_obj, "array_length", 0); // Call ir static api + CallStackRecorder callstack_recoder("array_length"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::array_length(x); + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) { @@ -248,7 +270,10 @@ static PyObject *static_api_array_read(PyObject *self, } // Call ir static api + CallStackRecorder callstack_recoder("array_read"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::array_read(array, i); + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) { @@ -282,7 +307,10 @@ static PyObject *static_api_array_write_(PyObject *self, } // Call ir static api + CallStackRecorder callstack_recoder("array_write_"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::array_write_(array, x, i); + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) { @@ -321,7 +349,10 @@ static PyObject *static_api_array_to_tensor(PyObject *self, auto use_stack = CastPyArg2Boolean(use_stack_obj, "array_to_tensor", 2); // Call ir static api + CallStackRecorder callstack_recoder("array_to_tensor"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::array_to_tensor(x, axis, use_stack); + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) 
{ @@ -341,10 +372,10 @@ PyObject *static_api_add_n_array(PyObject *self, PyObject *inputs_obj = PyTuple_GET_ITEM(args, 0); auto inputs = CastPyArg2VectorOfValue(inputs_obj, "add_n", 0); - // Parse Attributes - - // Call ir static api + CallStackRecorder callstack_recoder("add_n_array"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::add_n_array(inputs); + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) { @@ -395,7 +426,10 @@ static PyObject *static_api_slice_array(PyObject *self, } // Call ir static api + CallStackRecorder callstack_recoder("slice_array"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::slice_array(input, starts, ends); + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) { @@ -430,9 +464,11 @@ static PyObject *static_api_slice_array_dense(PyObject *self, starts = paddle::dialect::full_int_array( starts_tmp, phi::DataType::INT64, phi::CPUPlace()); } - // Call ir static api + CallStackRecorder callstack_recoder("slice_array_dense"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::slice_array_dense(input, starts); + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) { @@ -500,13 +536,17 @@ static PyObject *static_api_run_custom_op(PyObject *self, VLOG(7) << "Add un-initialized tensor " "because the optional input is None"; if (paddle::framework::detail::IsDuplicableVar(input)) { - vec_input_shapes.emplace_back(); - vec_input_dtypes.emplace_back(); + std::vector> vec_input_shape; + std::vector vec_input_dtype; + vec_input_shapes.emplace_back(vec_input_shape); + vec_input_dtypes.emplace_back(vec_input_dtype); vec_input_name2id_map[inputs[i]] = vec_input_index; vec_input_index++; } else { - input_shapes.emplace_back(); - input_dtypes.emplace_back(); + std::vector input_shape; + DataType input_dtype = DataType::UNDEFINED; + input_shapes.emplace_back(input_shape); + input_dtypes.emplace_back(input_dtype); input_name2id_map[inputs[i]] = input_index; input_index++; } @@ -519,7 +559,7 @@ static PyObject *static_api_run_custom_op(PyObject *self, vec_input_name2id_map[inputs[i]] = vec_input_index; vec_input_index++; std::vector input_values = - std::move(CastPyArg2VectorOfValue(obj, op_type, i + 1)); // NOLINT + CastPyArg2VectorOfValue(obj, op_type, i + 1); for (auto &input_value : input_values) { paddle::dialect::DenseTensorType input_tensor = input_value.type().dyn_cast(); @@ -529,8 +569,10 @@ static PyObject *static_api_run_custom_op(PyObject *self, } vec_input_shapes.push_back(tmp_input_shapes); vec_input_dtypes.push_back(tmp_input_dtypes); - auto input_value = paddle::dialect::stack(input_values, /*axis*/ 0); - argument_inputs.push_back(input_value); + auto combine_op = paddle::dialect::ApiBuilder::Instance() + .GetBuilder() + ->Build(input_values); + argument_inputs.push_back(combine_op.out()); } else { input_name2id_map[inputs[i]] = input_index; input_index++; @@ -681,13 +723,20 @@ static PyObject *static_api_run_custom_op(PyObject *self, "`SetInplaceMap` in your output when registry custom operator.")); const auto &input = inplace_reverse_map.at(output); auto index = vec_input_name2id_map[input]; - auto &input_shapes = vec_input_shapes[index]; - output_name2value_num[output] = input_shapes.size(); - all_values_num += input_shapes.size(); + auto &vec_input_shape = vec_input_shapes[index]; + output_name2value_num[output] = vec_input_shape.size(); } else { - output_name2value_num[output] = 1; - all_values_num++; + if 
(inplace_reverse_map.find(output) != inplace_reverse_map.end()) { + const auto &input = inplace_reverse_map.at(output); + auto index = input_name2id_map[input]; + // input_shapes[index] is dim of tensor, if the dim doesn't have + // element, it must be a optional tensor that is None in custom operator + output_name2value_num[output] = input_shapes[index].size() == 0 ? 0 : 1; + } else { + output_name2value_num[output]++; + } } + all_values_num += output_name2value_num[output]; } PADDLE_ENFORCE_EQ( @@ -715,8 +764,14 @@ static PyObject *static_api_run_custom_op(PyObject *self, size_t value_index = 0; for (size_t i = 0; i < outputs.size(); ++i) { const auto &output = outputs.at(i); + auto value_num = output_name2value_num[output]; + if (value_num == 0) { + // Optional value condition + pir::Type out_type; + argument_outputs.push_back(out_type); + continue; + } if (paddle::framework::detail::IsDuplicableVar(output)) { - auto value_num = output_name2value_num[output]; std::vector out_types; for (size_t j = 0; j < value_num; ++j) { auto ddims = phi::make_ddim(output_shapes[value_index]); @@ -754,7 +809,8 @@ static PyObject *static_api_run_custom_op(PyObject *self, argument.AddOutputs(argument_outputs.begin(), argument_outputs.end()); ::pir::PassStopGradientsDefaultly(argument); - + CallStackRecorder callstack_recoder("run_custom_op"); + callstack_recoder.Record(); std::vector op_results; pir::Operation *op = paddle::dialect::ApiBuilder::Instance().GetBuilder()->Build( @@ -762,17 +818,19 @@ static PyObject *static_api_run_custom_op(PyObject *self, for (size_t i = 0; i < outputs.size(); ++i) { const auto &output = outputs.at(i); if (paddle::framework::detail::IsDuplicableVar(output)) { - auto split_op = paddle::dialect::ApiBuilder::Instance() - .GetBuilder() - ->Build(op->result(i)); - auto split_outputs = split_op.outputs(); - op_results.insert( - op_results.end(), split_outputs.begin(), split_outputs.end()); + if (op->result(i).type().dyn_cast()) { + auto split_op = paddle::dialect::ApiBuilder::Instance() + .GetBuilder() + ->Build(op->result(i)); + auto split_outputs = split_op.outputs(); + op_results.insert( + op_results.end(), split_outputs.begin(), split_outputs.end()); + } } else { op_results.push_back(op->result(i)); } } - + callstack_recoder.AttachToOps(); return ToPyObject(op_results); } @@ -811,10 +869,13 @@ static PyObject *static_api_fused_gemm_epilogue(PyObject *self, PyObject *activation_obj = PyTuple_GET_ITEM(args, 5); std::string activation = CastPyArg2String(activation_obj, "fused_gemm_epilogue", 5); - // Call ir static api + CallStackRecorder callstack_recoder("fused_gemm_epilogue"); + callstack_recoder.Record(); auto out = paddle::dialect::fused_gemm_epilogue( x, y, bias, trans_x, trans_y, activation); + callstack_recoder.AttachToOps(); + return ToPyObject(out); } catch (...) { ThrowExceptionToPython(std::current_exception()); @@ -836,8 +897,10 @@ static PyObject *static_api_array_pop(PyObject *self, auto index = CastPyArg2Int(index_obj, "array_pop", 1); // Call ir static api + CallStackRecorder callstack_recoder("array_pop"); + callstack_recoder.Record(); auto static_api_out = paddle::dialect::array_pop(input, index); - + callstack_recoder.AttachToOps(); return ToPyObject(static_api_out); } catch (...) 
{ ThrowExceptionToPython(std::current_exception()); diff --git a/paddle/fluid/pybind/op_callstack_utils.cc b/paddle/fluid/pybind/op_callstack_utils.cc new file mode 100644 index 0000000000000..1e8e2c1630cd9 --- /dev/null +++ b/paddle/fluid/pybind/op_callstack_utils.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/pybind/op_callstack_utils.h" + +pir::Attribute CallStackRecorder::GetOpCallstackInfo() { + PyObject* traceback_str = PyUnicode_FromString("traceback"); + PyObject* traceback_module = PyImport_Import(traceback_str); + + if (NULL == traceback_module) { + Py_DECREF(traceback_str); + Py_DECREF(traceback_module); + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "Failed to import traceback module while getting callstack information " + "for %s.", + api_name_)); + } + PyObject* tb = PyObject_GetAttrString(traceback_module, "extract_stack"); + PyObject* stack = PyObject_CallObject(tb, NULL); + if (NULL == stack) { + Py_DECREF(tb); + Py_DECREF(traceback_str); + Py_DECREF(traceback_module); + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "Failed to get callstack object while getting callstack information " + "for " + "%s.", + api_name_)); + } + Py_ssize_t stack_size = PyList_Size(stack); + std::vector op_callstack_infos; + for (Py_ssize_t i = 0; i < stack_size; ++i) { + PyObject* frame_summary = PyList_GetItem(stack, i); + PyObject* filename = PyObject_GetAttrString(frame_summary, "filename"); + PyObject* lineno = PyObject_GetAttrString(frame_summary, "lineno"); + PyObject* name = PyObject_GetAttrString(frame_summary, "name"); + PyObject* line = PyObject_GetAttrString(frame_summary, "line"); + PyObject* callstack_info = PyUnicode_FromFormat( + " File \"%S\", line %S, in %S", filename, lineno, name); + PyObject* callstack_source_line = PyUnicode_FromFormat(" %S", line); + op_callstack_infos.push_back( + pir::StrAttribute::get(pir::IrContext::Instance(), + std::string(PyUnicode_AsUTF8(callstack_info)))); + op_callstack_infos.push_back(pir::StrAttribute::get( + pir::IrContext::Instance(), + std::string(PyUnicode_AsUTF8(callstack_source_line)))); + Py_DECREF(callstack_info); + Py_DECREF(callstack_source_line); + Py_DECREF(filename); + Py_DECREF(lineno); + Py_DECREF(name); + Py_DECREF(line); + } + Py_DECREF(tb); + Py_DECREF(traceback_str); + Py_DECREF(traceback_module); + return pir::ArrayAttribute::get(pir::IrContext::Instance(), + op_callstack_infos); +} + +void CallStackRecorder::Record() { + auto before_insertion_point = + paddle::dialect::ApiBuilder::Instance().GetCurrentInsertionPoint(); + before_insertion_iterator_ = (--before_insertion_point.second); + before_insertion_block_ = before_insertion_point.first; +} + +void CallStackRecorder::AttachToOps() { + 
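+  // Everything between the iterator captured by Record() and the builder's
+  // current insertion point was inserted by the traced static API call;
+  // stamp each of those ops with the Python call stack collected by
+  // GetOpCallstackInfo(), stored under OpCreationCallstackAttrName().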
before_insertion_iterator_++; + pir::Attribute callstack_info_attr = GetOpCallstackInfo(); + pir::InsertionPoint after_insertion_point = + paddle::dialect::ApiBuilder::Instance().GetCurrentInsertionPoint(); + PADDLE_ENFORCE_EQ(before_insertion_block_, + after_insertion_point.first, + paddle::platform::errors::PreconditionNotMet( + "The block obtained before and after calling the " + "static API %s is inconsistent.", + api_name_)); + auto after_insertion_iterator = after_insertion_point.second; + for (auto block_iterator = before_insertion_iterator_; + block_iterator != after_insertion_iterator; + block_iterator++) { + block_iterator->set_attribute(paddle::framework::OpProtoAndCheckerMaker:: + OpCreationCallstackAttrName(), + callstack_info_attr); + } +} diff --git a/paddle/fluid/pybind/op_callstack_utils.h b/paddle/fluid/pybind/op_callstack_utils.h new file mode 100644 index 0000000000000..a380fd37619b6 --- /dev/null +++ b/paddle/fluid/pybind/op_callstack_utils.h @@ -0,0 +1,31 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/pir/include/core/block.h" +#include "paddle/pir/include/core/builtin_attribute.h" + +class CallStackRecorder { + public: + explicit CallStackRecorder(const std::string& api_name) + : api_name_(api_name), before_insertion_block_(nullptr) {} + pir::Attribute GetOpCallstackInfo(); + void Record(); + void AttachToOps(); + + private: + const std::string& api_name_; + pir::Block::Iterator before_insertion_iterator_; + pir::Block* before_insertion_block_; +}; diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 5d7977ce5c442..f8f1424ded243 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -64,6 +64,7 @@ class OpAttrTypeMap { }; extern PyTypeObject* g_vartype_pytype; +extern PyTypeObject* g_data_type_pytype; extern PyTypeObject* g_blockdesc_pytype; extern PyTypeObject* p_tensor_type; @@ -72,6 +73,7 @@ bool PyObject_CheckBool(PyObject** obj) { return PyBool_Check(*obj); } bool PyObject_CheckLongOrToLong(PyObject** obj) { if ((PyLong_Check(*obj) && !PyBool_Check(*obj)) || PyObject_TypeCheck(*obj, g_vartype_pytype) || // NOLINT + PyObject_TypeCheck(*obj, g_data_type_pytype) || // NOLINT (PyObject_TypeCheck(*obj, p_tensor_type) && // NOLINT (((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT return true; diff --git a/paddle/fluid/pybind/parallel_executor.cc b/paddle/fluid/pybind/parallel_executor.cc index 9060e158c9ed9..d19eb9c5910ef 100644 --- a/paddle/fluid/pybind/parallel_executor.cc +++ b/paddle/fluid/pybind/parallel_executor.cc @@ -125,7 +125,7 @@ limitations under the License. 
*/ #include "paddle/fluid/pybind/pybind.h" // NOLINT #include "paddle/fluid/pybind/reader_py.h" #include "paddle/fluid/pybind/tensor_py.h" -#include "paddle/fluid/string/to_string.h" +#include "paddle/utils/string/to_string.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/operators/nccl/nccl_gpu_common.h" @@ -931,7 +931,7 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT .def_property( "memory_optimize", [](const BuildStrategy &self) -> py::object { - if (self.memory_optimize_) { + if (self.memory_optimize_) { // NOLINT return py::cast(self.memory_optimize_.get()); } else { return py::cast(nullptr); diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc index 54fa9bf54f057..80ffa9ad19b90 100644 --- a/paddle/fluid/pybind/pir.cc +++ b/paddle/fluid/pybind/pir.cc @@ -23,11 +23,16 @@ #include #include +#include "paddle/common/flags.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" #include "paddle/fluid/ir_adaptor/translator/utils.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_dialect.h" +#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h" +#include "paddle/fluid/pir/dialect/distributed/transforms/mix_to_dist_pass.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h" @@ -39,34 +44,18 @@ #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_parser.h" #include "paddle/fluid/pir/dialect/operator/utils/utils.h" -#include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/embedding_eltwise_layernorm_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.h" -#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" -#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h" -#include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" -#include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h" -#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/multihead_matmul_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/silu_fuse_pass.h" -#include "paddle/fluid/pir/transforms/fusion/transpose_flatten_concat_fuse_pass.h" -#include "paddle/fluid/pir/transforms/identity_op_clean_pass.h" -#include "paddle/fluid/pir/transforms/inplace_pass.h" -#include "paddle/fluid/pir/transforms/map_op_to_another_pass.h" -#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h" +#include "paddle/fluid/pir/transforms/passes.h" #include "paddle/fluid/pir/transforms/shape_optimization_pass.h" #include "paddle/fluid/pybind/control_flow_api.h" #include "paddle/fluid/pybind/eager_utils.h" #include 
"paddle/fluid/pybind/pybind_variant_caster.h" +#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h" #include "paddle/phi/core/enforce.h" #include "paddle/pir/include/core/attribute.h" #include "paddle/pir/include/core/block.h" #include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/builtin_op.h" +#include "paddle/pir/include/core/ir_mapping.h" #include "paddle/pir/include/core/parser/ir_parser.h" #include "paddle/pir/include/core/program.h" #include "paddle/pir/include/core/type.h" @@ -78,8 +67,6 @@ #include "paddle/pir/include/pass/pass.h" #include "paddle/pir/include/pass/pass_manager.h" #include "paddle/pir/include/pass/pass_registry.h" - -#include "paddle/common/flags.h" #include "pybind11/stl.h" #ifdef PADDLE_WITH_CINN @@ -88,23 +75,26 @@ #include "paddle/cinn/hlir/framework/pir_compiler.h" #endif -#ifdef PADDLE_WITH_DNNL -#include "paddle/fluid/pir/transforms/onednn/batch_norm_act_fuse_pass.h" -#endif - namespace py = pybind11; using paddle::dialect::ApiBuilder; using paddle::dialect::DenseTensorArrayType; using paddle::dialect::DenseTensorType; +using paddle::dialect::DistDenseTensorType; using paddle::dialect::IfOp; using paddle::dialect::PyLayerOp; using paddle::dialect::SelectedRowsType; using paddle::dialect::WhileOp; +using paddle::dialect::OperationDistAttribute; +using paddle::dialect::TensorDistAttribute; + using pir::Attribute; using pir::Block; using pir::BlockArgument; using pir::BoolAttribute; +using pir::CloneOptions; +using pir::IrContext; +using pir::IrMapping; using pir::IrParser; using pir::Operation; using pir::OpOperand; @@ -116,31 +106,6 @@ using pir::Type; using pir::Value; using pybind11::return_value_policy; -USE_PIR_PASS(dead_code_elimination_pass); -USE_PIR_PASS(multihead_matmul_fuse_pass); -USE_PIR_PASS(transpose_flatten_concat_fuse_pass); -USE_PIR_PASS(fused_gemm_epilogue_pass); -USE_PIR_PASS(fused_dropout_add_pass); -USE_PIR_PASS(fused_weight_only_linear_pass); -USE_PIR_PASS(fused_linear_param_grad_add_pass); -USE_PIR_PASS(inplace_pass); -USE_PIR_PASS(replace_fetch_with_shadow_output_pass); -USE_PIR_PASS(identity_op_clean_pass); -USE_PIR_PASS(map_op_to_another_pass); -USE_PIR_PASS(matmul_scale_fuse_pass); -USE_PIR_PASS(fc_fuse_pass); -USE_PIR_PASS(silu_fuse_pass); -USE_PIR_PASS(fc_elementwise_layernorm_fuse_pass); -USE_PIR_PASS(conv2d_bn_fuse_pass); -USE_PIR_PASS(conv2d_add_fuse_pass); -USE_PIR_PASS(conv2d_add_act_fuse_pass); -USE_PIR_PASS(embedding_eltwise_layernorm_fuse_pass); -USE_PIR_PASS(fused_dot_product_attention_pass); - -#ifdef PADDLE_WITH_DNNL -USE_PIR_PASS(batch_norm_act_fuse_pass); -#endif - COMMON_DECLARE_bool(print_ir); COMMON_DECLARE_bool(pir_apply_shape_optimization_pass); @@ -206,6 +171,25 @@ std::string GetValueInfo(Value v) { return ss.str(); } +Value GetOutputValueByName(const Program &program, const std::string &name) { + auto &block = *program.block(); + pir::StrAttribute name_attr = + pir::StrAttribute::get(IrContext::Instance(), name); + Value value; + for (auto &op : block) { + if (op.isa()) { + if (op.attribute("output_name") == name_attr) { + if (value) { + PADDLE_THROW(common::errors::PreconditionNotMet( + "More than one shadow ouput named with %s found.", name)); + } + value = op.operand_source(0); + } + } + } + return value; +} + void BindProgram(py::module *m) { py::class_> program( *m, "Program", py::dynamic_attr(), R"DOC( @@ -317,6 +301,10 @@ void BindProgram(py::module *m) { [](std::shared_ptr self, int64_t random_seed) { SetProgramInt64Attr(self, "random_seed", 
random_seed); }) + .def("get_output_value_by_name", + [](Program &self, const std::string &name) { + return GetOutputValueByName(self, name); + }) .def("num_ops", [](Program &self) { return self.num_ops(); }); } @@ -456,6 +444,30 @@ void BindBlock(py::module *m) { }); } +void BindIrMapping(py::module *m) { + py::class_ ir_mapping(*m, "IrMapping"); + ir_mapping.def(py::init<>()) + .def("look_up", + [](IrMapping &self, Value from) { return self.Lookup(from); }) + .def("add", [](IrMapping &self, Value from, Value to) { + self.Add(from, to); + }); +} + +void BindCloneOptions(py::module *m) { + py::class_ clone_options(*m, "CloneOptions"); + clone_options.def( + "__init__", + [](CloneOptions &self, + bool clone_regions, + bool clone_operands, + bool clone_successors) { + new (&self) + CloneOptions(clone_regions, clone_operands, clone_successors); + }, + return_value_policy::reference); +} + void BindOperation(py::module *m) { py::class_ op(*m, "Operation", R"DOC( In IR, all the operation are represented by Operation, and Operation @@ -499,11 +511,22 @@ void BindOperation(py::module *m) { for (auto &pair : self.attributes()) { // SymbolAttribute is only used in PIR, no need to pass to Python if (pair.second.isa()) continue; - attrs_dict[pair.first.c_str()] = - paddle::dialect::GetAttributeData(pair.second); + if (pair.first == kAttrOpDistAttr) { + attrs_dict[pair.first.c_str()] = + pair.second.dyn_cast(); + } else { + attrs_dict[pair.first.c_str()] = + paddle::dialect::GetAttributeData(pair.second); + } } return attrs_dict; }) + .def("set_scheduling_priority", + [](Operation &self, int64_t priority) { + self.set_attribute("scheduling_priority", + pir::Int64Attribute::get( + pir::IrContext::Instance(), priority)); + }) .def("operands_source", [](Operation &self) -> py::list { py::list op_list; @@ -591,12 +614,74 @@ void BindOperation(py::module *m) { }) .def("as_while_op", [](Operation &self) { return PyWhileOp(self.dyn_cast()); }) - .def("__repr__", [](Operation &self) { - std::ostringstream print_stream; - print_stream << "Operation("; - self.Print(print_stream); - print_stream << ")"; - return print_stream.str(); + .def("__repr__", + + [](Operation &self) { + std::ostringstream print_stream; + print_stream << "Operation("; + self.Print(print_stream); + print_stream << ")"; + return print_stream.str(); + }) + .def( + "clone", + [](Operation &self, IrMapping &ir_mapping, CloneOptions options) { + auto op = self.Clone(ir_mapping, options); + return ApiBuilder::Instance().GetBuilder()->Insert(op); + }, + return_value_policy::reference) + .def("move_before", + [](Operation &self, Operation &other) { + self.MoveTo(other.GetParent(), Block::Iterator{other}); + }) + .def_property( + "callstack", + [](Operation &self) -> py::list { + py::list callstack_list; + pir::Attribute op_callstack = self.attribute( + paddle::framework::OpProtoAndCheckerMaker:: + OpCreationCallstackAttrName()); + PADDLE_ENFORCE(op_callstack.isa(), + phi::errors::PreconditionNotMet( + "The callstack of operation `%s` should be an " + "array attribute.", + self.name())); + auto op_callstack_array_attr = + op_callstack.dyn_cast(); + for (size_t i = 0; i < op_callstack_array_attr.size(); ++i) { + PADDLE_ENFORCE( + op_callstack_array_attr.at(i).isa(), + phi::errors::PreconditionNotMet( + "The callstack info of operation `%s` should be array of " + "string attribute.", + self.name())); + callstack_list.append(op_callstack_array_attr.at(i) + .dyn_cast() + .AsString()); + } + return callstack_list; + }, + [](Operation &self, + const 
std::vector &callstack) -> void { + std::vector op_callstack_infos; + for (auto str : callstack) { + op_callstack_infos.push_back( + pir::StrAttribute::get(pir::IrContext::Instance(), str)); + } + + self.set_attribute( + paddle::framework::OpProtoAndCheckerMaker:: + OpCreationCallstackAttrName(), + pir::ArrayAttribute::get(pir::IrContext::Instance(), + op_callstack_infos)); + }) + .def("dist_attr", [](Operation &self) { + if (self.HasAttribute(kAttrOpDistAttr)) { + return self.attribute(kAttrOpDistAttr); + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("dist_attr is only for dist op.")); + } }); py::class_ block_container( *m, "Operation_BlockContainer", R"DOC( @@ -631,10 +716,13 @@ phi::DataType GetValueDtype(Value value) { } else if (value.type().isa()) { return paddle::dialect::TransToPhiDataType( value.type().dyn_cast().dtype()); + } else if (value.type().isa()) { + return paddle::dialect::TransToPhiDataType( + value.type().dyn_cast().dtype()); } else { PADDLE_THROW(phi::errors::InvalidArgument( "Currently, we can only get phi::DataType from DenseTensorType and " - "SelectedRowsType.")); + "SelectedRowsType, DistDenseTensorType.")); } } @@ -646,9 +734,11 @@ const phi::DDim &GetValueDims(Value value) { return value.type().dyn_cast().dims(); } else if (value.type().isa()) { return value.type().dyn_cast().dims(); + } else if (value.type().isa()) { + return value.type().dyn_cast().global_ddim(); } else { PADDLE_THROW(phi::errors::InvalidArgument( - "Currently, we can only get shape for dense " + "Currently, we can only get shape for dense and distdense" "tensor.")); } } @@ -685,6 +775,40 @@ pir::Value apply(Value self, py::object func) { return out; } +#define DEF_VALUE_BOOL_PROPERTY(name) \ + def_property( \ + name, \ + [](Value self) { \ + auto bool_data = self.attribute(name); \ + return !bool_data || bool_data.data(); \ + }, \ + [](Value self, bool bool_data) { \ + self.set_attribute( \ + name, BoolAttribute::get(pir::IrContext::Instance(), bool_data)); \ + }) + +#define DEF_VALUE_POINTER_PROPERTY(name) \ + def_property( \ + name, \ + [](Value self) -> py::object { \ + auto prop_ptr = self.property(name); \ + if (!prop_ptr) { \ + return py::cast(Py_None); \ + } \ + auto py_data = reinterpret_cast(prop_ptr); \ + py::object obj = py::object(py::handle(py_data), true); \ + return obj; \ + }, \ + [](Value self, py::object obj) { \ + pir::PropertiesDeleter deleter = [](void *python_obj) { \ + Py_DECREF(python_obj); \ + }; \ + PyObject *pointer_data = obj.release().ptr(); \ + pir::Property value_property(reinterpret_cast(pointer_data), \ + deleter); \ + self.set_property(name, value_property); \ + }) + void BindValue(py::module *m) { py::class_ value(*m, "Value", @@ -696,8 +820,7 @@ void BindValue(py::module *m) { The constructor of Value should not be invoked directly. Value can be automatically constructed when build network. 
- )DOC", - pybind11::dynamic_attr()); + )DOC"); g_ir_value_pytype = reinterpret_cast(value.ptr()); value.def(py::init<>()) .def_property_readonly( @@ -749,6 +872,20 @@ void BindValue(py::module *m) { PADDLE_THROW(phi::errors::InvalidArgument( "can't set shape when building static graph")); }) + .def_property( + "_local_shape", + [](Value self) { + if (!self.type().isa()) { + PADDLE_THROW(phi::errors::InvalidArgument( + "_local_shape is only for distdense tensor.")); + } + return phi::vectorize( + self.type().dyn_cast().local_ddim()); + }, + [](Value self, const std::vector &shape) { + PADDLE_THROW(phi::errors::InvalidArgument( + "can't set _local_shape when building static graph")); + }) .def_property( "dtype", [](Value self) { return GetValueDtype(self); }, @@ -764,30 +901,15 @@ void BindValue(py::module *m) { return true; } }) - .def_property( - "stop_gradient", - [](Value self) { - auto stop_gradient = - self.attribute(kAttrStopGradients); - return !stop_gradient || stop_gradient.data(); - }, - [](Value self, bool stop_gradient) { - self.set_attribute( - kAttrStopGradients, - BoolAttribute::get(pir::IrContext::Instance(), stop_gradient)); - }) - .def_property( - "persistable", - [](Value self) { - auto persistable = - self.attribute(kAttrIsPersistable); - return !persistable || persistable.data(); - }, - [](Value self, bool persistable) { - self.set_attribute( - kAttrIsPersistable, - BoolAttribute::get(pir::IrContext::Instance(), persistable)); - }) + .DEF_VALUE_BOOL_PROPERTY("stop_gradient") + .DEF_VALUE_BOOL_PROPERTY("trainable") + .DEF_VALUE_BOOL_PROPERTY("persistable") + .DEF_VALUE_BOOL_PROPERTY("need_clip") + .DEF_VALUE_BOOL_PROPERTY("is_distributed") + .DEF_VALUE_BOOL_PROPERTY("is_parameter") + .DEF_VALUE_POINTER_PROPERTY("optimize_attr") + .DEF_VALUE_POINTER_PROPERTY("regularizer") + .DEF_VALUE_POINTER_PROPERTY("do_model_average") .def("all_used_ops", [](Value &self) -> py::list { py::list op_list; @@ -808,8 +930,24 @@ void BindValue(py::module *m) { [](Value self) { return self.type().isa(); }) .def("is_dense_tensor_array_type", [](Value self) { return self.type().isa(); }) + .def("is_dist_dense_tensor_type", + [](Value self) { return self.type().isa(); }) + .def("value_assign", [](Value &self, Value value) { self = value; }) .def("replace_all_uses_with", [](Value self, Value value) { self.ReplaceAllUsesWith(value); }) + .def("replace_grad_users_with", + [](Value self, + Value value, + std::unordered_set &grad_ops) { + for (auto it = self.use_begin(); it != self.use_end();) { + auto use_op = it.owner(); + if (grad_ops.find(use_op) != grad_ops.end()) { + (it++)->set_source(value); + } else { + it++; + } + } + }) .def("set_type", [](Value self, Type type) { self.set_type(type); }) .def("first_use", &Value::first_use, return_value_policy::reference) .def("has_one_use", &Value::HasOneUse) @@ -829,7 +967,14 @@ void BindValue(py::module *m) { BoolAttribute::get(pir::IrContext::Instance(), true)); return out; }) - .def("__repr__", &Value2String); + .def("__repr__", &Value2String) + .def("dist_attr", [](Value &self) { + if (!self.type().isa()) { + PADDLE_THROW(phi::errors::InvalidArgument( + "dist_attr is only for distdense tensor.")); + } + return self.type().dyn_cast().tensor_dist_attr(); + }); } void BindOpOperand(py::module *m) { @@ -927,6 +1072,131 @@ void range_block_do(const Block *block, std::vector range, F fn) { } } +template +bool ExistsInMapValues(const std::map &m, V value) { + for (const auto &[k, v] : m) { + if (v == value) { + return true; + } + } + return false; +} + 
+std::map GetOpInplaceInfo(const pir::Operation *op) { + std::map inplace_info; + if (!op->HasTrait()) { + return inplace_info; + } + pir::IrContext *ctx = pir::IrContext::Instance(); + std::string op_name = op->name(); + if (op->attributes().count("op_name")) { + op_name = + op->attributes().at("op_name").dyn_cast().AsString(); + } + + pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); + paddle::dialect::OpYamlInfoParser yaml_parser( + op_info.GetInterfaceImpl() + ->get_op_info_(op_name), + paddle::dialect::IsLegacyOp(op_name)); + + for (size_t i = 0; i < op->num_results(); ++i) { + std::string value_name = yaml_parser.OutputNames()[i]; + if (yaml_parser.HasInplace(value_name)) { + const std::string &inplace_name = yaml_parser.InplaceName(value_name); + inplace_info[i] = yaml_parser.InputName2Id().at(inplace_name); + } + if (yaml_parser.HasView(value_name)) { + const std::string &view_name = yaml_parser.ViewName(value_name); + inplace_info[i] = yaml_parser.InputName2Id().at(view_name); + } + } + + return inplace_info; +} + +std::vector> GetOpInplaceChains(const Block *block) { + std::vector> inplace_chains; + std::map value_to_inplace_chain_index; + + for (auto &op : *block) { + pir::Walk(&op, [&](Operation *inner_op) { + auto op_inplace_info = GetOpInplaceInfo(inner_op); + for (auto &[out_idx, in_idx] : op_inplace_info) { + auto target_value = inner_op->results()[out_idx]; + auto source_value = inner_op->operands()[in_idx].source(); + VLOG(8) << "Inplace Mapping: " << Value2String(source_value) << " -> " + << Value2String(target_value); + + if (value_to_inplace_chain_index.count(source_value) == 0 && + value_to_inplace_chain_index.count(target_value) == 0) { + size_t chain_insertion_idx = inplace_chains.size(); + inplace_chains.push_back({source_value, target_value}); + value_to_inplace_chain_index.insert( + {source_value, chain_insertion_idx}); + value_to_inplace_chain_index.insert( + {target_value, chain_insertion_idx}); + } else { + PADDLE_ENFORCE_NE( + value_to_inplace_chain_index.count(source_value), + 0, + phi::errors::Unavailable("source value should be in the chain")); + PADDLE_ENFORCE_EQ(value_to_inplace_chain_index.count(target_value), + 0, + phi::errors::Unavailable( + "target value should not be in the chain")); + size_t chain_insertion_idx = + value_to_inplace_chain_index[source_value]; + inplace_chains[chain_insertion_idx].push_back(target_value); + value_to_inplace_chain_index.insert( + {target_value, chain_insertion_idx}); + } + } + }); + } + return inplace_chains; +} + +std::optional FindInplaceSource( + const std::vector> inplace_chains, + pir::Value value) { + if (value.impl() == nullptr) { + return std::nullopt; + } + for (auto &chain : inplace_chains) { + for (auto &v : chain) { + if (v == value) { + return chain[0]; + } + } + } + return std::nullopt; +} + +std::map ReplaceValueWithInplaceSource( + const std::vector> &source_domain, + std::vector *target_values, + const std::vector> inplace_chains) { + std::map replacements; + for (auto &target_value : *target_values) { + auto inplace_source = FindInplaceSource(inplace_chains, target_value); + if (!inplace_source.has_value()) { + continue; + } + for (auto &source_values : source_domain) { + if (std::find(source_values.begin(), + source_values.end(), + inplace_source.value()) != source_values.end()) { + VLOG(4) << "Replace " << Value2String(target_value) << " with " + << Value2String(inplace_source.value()); + replacements.insert({target_value, inplace_source.value()}); + target_value = inplace_source.value(); + 
} + } + } + return replacements; +} + std::pair, std::unordered_set> AnalysisMiddleVariable(const Program &program, const std::vector &forward_inputs, @@ -950,11 +1220,14 @@ AnalysisMiddleVariable(const Program &program, program.block(), forward_range, [&middle_values, &backward_inputs, &x_or_param](Operation *op) { - for (auto &t : op->results()) { - auto v = Value(t.Value::impl()); - if (backward_inputs.count(v) && !x_or_param.count(v)) - middle_values.push_back(v); - } + pir::Walk(op, [&](Operation *inner_op) { + for (auto &t : inner_op->results()) { + auto v = Value(t.Value::impl()); + if (backward_inputs.count(v) && !x_or_param.count(v)) { + middle_values.push_back(v); + } + } + }); }); return std::make_pair(middle_values, backward_inputs); } @@ -1107,10 +1380,26 @@ SplitedResult SplitForwardBackward( pir::IrContext *ctx = pir::IrContext::Instance(); auto forward_program = std::make_shared(ctx); auto backward_program = std::make_shared(ctx); + std::vector forward_outputs_mutable = forward_outputs; std::vector middle_values; std::unordered_set backward_inputs; + const auto &inplace_chains = GetOpInplaceChains(program.block()); std::tie(middle_values, backward_inputs) = AnalysisMiddleVariable( program, forward_in_out_values, forward_range, backward_range); + + // Replace inplace value with source value. + // NOTE(SigureMo): Why not process inplace value for forward_inputs in + // forward? + // Because all forward_inputs uses data op, after lower to kernel + // pass, the data op will following a non-inplace op shadow_feed, so we don't + // need to process inplace for forward_inputs in forward. + // Same reason for whole backward program, because all backward inputs are + // created by block kwargs, it also add a shadow_feed op after lower to kernel + // pass. + auto replacement_for_forward_middles = ReplaceValueWithInplaceSource( + {forward_params}, &middle_values, inplace_chains); + auto replacement_for_forward_outputs = ReplaceValueWithInplaceSource( + {forward_params}, &forward_outputs_mutable, inplace_chains); pir::Block &backward_block = *backward_program->block(); bool has_backward = (backward_range[1] > backward_range[0]); @@ -1135,8 +1424,13 @@ SplitedResult SplitForwardBackward( auto create_kwarg_fn = [&backward_block, &backward_inputs, &backward_value_map, + &replacement_for_forward_middles, + &replacement_for_forward_outputs, &counter](const pir::Value &v) { - if (v && backward_inputs.count(v)) { + if (v && !backward_value_map.count(v) && + (backward_inputs.count(v) || + ExistsInMapValues(replacement_for_forward_middles, v) || + ExistsInMapValues(replacement_for_forward_outputs, v))) { backward_value_map[v] = backward_block.AddKwarg( "input_" + std::to_string(counter++), v.type()); } @@ -1145,10 +1439,19 @@ SplitedResult SplitForwardBackward( auto create_output_fn_forward = [&ctx, &forward_value_map, &counter, - &forward_program](const pir::Value &v) { + &forward_program, + &forward_inputs, + &forward_params](const pir::Value &v) { if (v.impl() == nullptr) { return; } + // Skip the value that already in forward_inputs or forward_params. + if (std::find(forward_inputs.begin(), forward_inputs.end(), v) != + forward_inputs.end() || + std::find(forward_params.begin(), forward_params.end(), v) != + forward_params.end()) { + return; + } // NOTE(Aurelius84): we should skip insert ShadowOutputOp repeatedly by // calling SplitForwardBackward multi-times. 
std::string shadow_output_name = @@ -1202,14 +1505,14 @@ SplitedResult SplitForwardBackward( counter += 1; }; - // counter = 0; if (has_backward) { VLOG(4) << "start create backward inputs, creating keyword argument."; VLOG(4) << "Create keyword argument for backward program: fo, start with input_" << counter; - std::for_each( - forward_outputs.begin(), forward_outputs.end(), create_kwarg_fn); + std::for_each(forward_outputs_mutable.begin(), + forward_outputs_mutable.end(), + create_kwarg_fn); VLOG(4) << "Create keyword argument for backward program: fx, start with input_" << counter; @@ -1232,14 +1535,27 @@ SplitedResult SplitForwardBackward( create_kwarg_fn); VLOG(4) << "Create keyword argument for backward program end. input_" << counter; + + // Update the value map with inplace source value. + VLOG(4) << "start update inplace names"; + VLOG(4) << "replacement_for_forward_middles size is: " + << replacement_for_forward_middles.size(); + for (auto &[target, source] : replacement_for_forward_middles) { + backward_value_map[target] = backward_value_map.at(source); + } + VLOG(4) << "replacement_for_forward_outputs size is: " + << replacement_for_forward_outputs.size(); + for (auto &[target, source] : replacement_for_forward_outputs) { + backward_value_map[target] = backward_value_map.at(source); + } } - // counter = 0; VLOG(4) << "start create forward outputs, inserting set_parameter ops."; std::for_each( middle_values.begin(), middle_values.end(), create_output_fn_forward); - std::for_each( - forward_outputs.begin(), forward_outputs.end(), create_output_fn_forward); + std::for_each(forward_outputs_mutable.begin(), + forward_outputs_mutable.end(), + create_output_fn_forward); // Step2. copy backward ops . VLOG(4) << "start copy backward ops"; @@ -1250,7 +1566,6 @@ SplitedResult SplitForwardBackward( auto *cloned_op = op->Clone(backward_mapper, clone_options); backward_program->block()->push_back(cloned_op); }); - // counter = 0; VLOG(4) << "start create backward outputs, inserting set_parameter ops."; if (has_backward) { std::for_each(forward_inputs_grads.begin(), @@ -1275,20 +1590,20 @@ SplitedResult SplitForwardBackward( // construct all attributes we needed. 
- mapping_value(middle_values, forward_value_map, fm); // write 'fm' - mapping_value(middle_values, backward_value_map, bm); // write 'bm' - mapping_value(forward_inputs, forward_value_map, fx); // write 'fx' - mapping_value(forward_inputs, backward_value_map, bx); // write 'bx' - mapping_value(forward_params, forward_value_map, fp); // write 'fp' - mapping_value(forward_params, backward_value_map, bp); // write 'bp' - mapping_value(forward_outputs, forward_value_map, fo); // write 'fo' + mapping_value(middle_values, forward_value_map, fm); // write 'fm' + mapping_value(middle_values, backward_value_map, bm); // write 'bm' + mapping_value(forward_inputs, forward_value_map, fx); // write 'fx' + mapping_value(forward_inputs, backward_value_map, bx); // write 'bx' + mapping_value(forward_params, forward_value_map, fp); // write 'fp' + mapping_value(forward_params, backward_value_map, bp); // write 'bp' + mapping_value(forward_outputs_mutable, forward_value_map, fo); // write 'fo' mapping_value( forward_inputs_grads, backward_value_map, bx_g); // write 'bx_g' mapping_value( forward_params_grads, backward_value_map, bp_g); // write 'bp_g' mapping_value( - forward_outputs_grads, backward_value_map, bo_g); // write 'bo_g' - mapping_value(forward_outputs, backward_value_map, bo); // write 'bo' + forward_outputs_grads, backward_value_map, bo_g); // write 'bo_g' + mapping_value(forward_outputs_mutable, backward_value_map, bo); // write 'bo' mapping_value(GetNoNeedBufferValue(program.block(), backward_range), forward_value_map, no_need_buffer_values); // write 'no_need_buffers' @@ -1326,39 +1641,32 @@ pir::Type CreateSelectedRowsTypeByDenseTensor(pir::Type dense_tensor_type) { } } -void ResetShadowOutputName(pir::Operation *op, const std::string &name) { - pir::IrContext *ctx = pir::IrContext::Instance(); - if (op->isa()) { - op->set_attribute("output_name", pir::StrAttribute::get(ctx, name)); +pir::Type CreateDistDenseTensorTypeByDenseTensor( + const pir::Type &gdense_tensor_type, + const std::vector &lshape, + const phi::distributed::ProcessMesh &mesh, + const std::vector &dims_mapping) { + if (gdense_tensor_type.isa()) { + DenseTensorType type = gdense_tensor_type.dyn_cast(); + paddle::flat_hash_map partial_status; + paddle::dialect::TensorDistAttribute tensor_dist_attr = + paddle::dialect::TensorDistAttribute::get( + pir::IrContext::Instance(), mesh, dims_mapping, partial_status); + return DistDenseTensorType::get(pir::IrContext::Instance(), + type, + tensor_dist_attr, + phi::make_ddim(lshape)); + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Currently, input is not a dense tensor type are not supported.")); } } -std::map GetOpInplaceInfo(const pir::Operation *op) { - std::map inplace_info; - if (!op->HasTrait()) { - return inplace_info; - } +void ResetShadowOutputName(pir::Operation *op, const std::string &name) { pir::IrContext *ctx = pir::IrContext::Instance(); - std::string op_name = op->name(); - if (op->attributes().count("op_name")) { - op_name = - op->attributes().at("op_name").dyn_cast().AsString(); - } - - pir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); - paddle::dialect::OpYamlInfoParser yaml_parser( - op_info.GetInterfaceImpl() - ->get_op_info_(op_name), - paddle::dialect::IsLegacyOp(op_name)); - - for (size_t i = 0; i < op->num_results(); ++i) { - std::string value_name = yaml_parser.OutputNames()[i]; - if (yaml_parser.HasInplace(value_name)) { - const std::string &inplace_name = yaml_parser.InplaceName(value_name); - inplace_info[i] = 
yaml_parser.InputName2Id().at(inplace_name); - } + if (op->isa()) { + op->set_attribute("output_name", pir::StrAttribute::get(ctx, name)); } - return inplace_info; } void BindUtils(pybind11::module *m) { @@ -1388,13 +1696,19 @@ void BindUtils(pybind11::module *m) { pir::IrContext::Instance() ->GetOrRegisterDialect(); }); + m->def("register_dist_dialect", []() { + pir::IrContext::Instance() + ->GetOrRegisterDialect(); + }); m->def("create_selected_rows_type_by_dense_tensor", CreateSelectedRowsTypeByDenseTensor); + m->def("create_dist_dense_tensor_type_by_dense_tensor", + CreateDistDenseTensorTypeByDenseTensor); m->def( "translate_to_pir", [](const ::paddle::framework::ProgramDesc &legacy_program) { std::shared_ptr ret = - std::move(paddle::TranslateLegacyProgramToProgram(legacy_program)); + paddle::TranslateLegacyProgramToProgram(legacy_program); return ret; }, R"DOC( @@ -1438,10 +1752,10 @@ void BindUtils(pybind11::module *m) { >>> print(pir_program) { - (%0) = "pd_op.data" () {dtype:(pd_op.DataType)float32,is_persistable:[false],name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4,4],stop_gradient:[false]} : () -> pd_op.tensor<4x4xf32> - (%1) = "pd_op.matmul" (%0, %0) {is_persistable:[false],stop_gradient:[false],transpose_x:false,transpose_y:false} : (pd_op.tensor<4x4xf32>, pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> - (%2) = "pd_op.add" (%1, %1) {is_persistable:[false],stop_gradient:[false]} : (pd_op.tensor<4x4xf32>, pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> - (%3) = "pd_op.tanh" (%2) {is_persistable:[false],stop_gradient:[false]} : (pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> + (%0) = "pd_op.data" () {dtype:(pd_op.DataType)float32,is_persistable:[false],name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4,4],stop_gradient:[false]} : () -> builtin.tensor<4x4xf32> + (%1) = "pd_op.matmul" (%0, %0) {is_persistable:[false],stop_gradient:[false],transpose_x:false,transpose_y:false} : (builtin.tensor<4x4xf32>, builtin.tensor<4x4xf32>) -> builtin.tensor<4x4xf32> + (%2) = "pd_op.add" (%1, %1) {is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<4x4xf32>, builtin.tensor<4x4xf32>) -> builtin.tensor<4x4xf32> + (%3) = "pd_op.tanh" (%2) {is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<4x4xf32>) -> builtin.tensor<4x4xf32> } @@ -1513,45 +1827,29 @@ void BindUtils(pybind11::module *m) { >>> print(pir_program) { - (%0) = "pd_op.data" () {dtype:(pd_op.DataType)float32,is_persistable:[false],name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4,4],stop_gradient:[false]} : () -> pd_op.tensor<4x4xf32> - (%1) = "pd_op.matmul" (%0, %0) {is_persistable:[false],stop_gradient:[false],transpose_x:false,transpose_y:false} : (pd_op.tensor<4x4xf32>, pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> - (%2) = "pd_op.add" (%1, %1) {is_persistable:[false],stop_gradient:[false]} : (pd_op.tensor<4x4xf32>, pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> - (%3) = "pd_op.tanh" (%2) {is_persistable:[false],stop_gradient:[false]} : (pd_op.tensor<4x4xf32>) -> pd_op.tensor<4x4xf32> + (%0) = "pd_op.data" () {dtype:(pd_op.DataType)float32,is_persistable:[false],name:"x",place:(pd_op.Place)Place(undefined:0),shape:(pd_op.IntArray)[4,4],stop_gradient:[false]} : () -> builtin.tensor<4x4xf32> + (%1) = "pd_op.matmul" (%0, %0) {is_persistable:[false],stop_gradient:[false],transpose_x:false,transpose_y:false} : (builtin.tensor<4x4xf32>, builtin.tensor<4x4xf32>) -> builtin.tensor<4x4xf32> + (%2) = "pd_op.add" (%1, %1) 
{is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<4x4xf32>, builtin.tensor<4x4xf32>) -> builtin.tensor<4x4xf32> + (%3) = "pd_op.tanh" (%2) {is_persistable:[false],stop_gradient:[false]} : (builtin.tensor<4x4xf32>) -> builtin.tensor<4x4xf32> } >>> print(mappings) - {'matmul_v2_0.tmp_0': [Value(define_op_name=pd_op.matmul, index=0, dtype=pd_op.tensor<4x4xf32>)], 'x': [Value(define_op_name=pd_op.data, index=0, dtype=pd_op.tensor<4x4xf32>)], 'tanh_0.tmp_0': [Value(define_op_name=pd_op.tanh, index=0, dtype=pd_op.tensor<4x4xf32>)], 'elementwise_add_0': [Value(define_op_name=pd_op.add, index=0, dtype=pd_op.tensor<4x4xf32>)]} + {'matmul_v2_0.tmp_0': [Value(define_op_name=pd_op.matmul, index=0, dtype=builtin.tensor<4x4xf32>)], 'x': [Value(define_op_name=pd_op.data, index=0, dtype=builtin.tensor<4x4xf32>)], 'tanh_0.tmp_0': [Value(define_op_name=pd_op.tanh, index=0, dtype=builtin.tensor<4x4xf32>)], 'elementwise_add_0': [Value(define_op_name=pd_op.add, index=0, dtype=builtin.tensor<4x4xf32>)]} )DOC"); - m->def("clear_pir_compiler_manager", []() { + m->def("clear_cinn_compilation_cache", + []() { #ifdef PADDLE_WITH_CINN - pybind11::gil_scoped_release release; - VLOG(4) << "clear PirCompilerManager and free PirCompiler resources."; - cinn::hlir::framework::PirCompilerManager::Instance().clear(); + pybind11::gil_scoped_release release; + VLOG(4) << "clear CINN CompilationCache and free BackendResource."; + cinn::hlir::framework::CompilationCache::Instance().Clear(); #endif - }); + }), + m->def("apply_mix2dist_pass", paddle::dialect::MixToDistPass); } namespace { -bool HasDynamicShape(const pir::Program &program) { - for (const auto &op : *program.block()) { - if (op.isa()) { - continue; - } - for (uint32_t i = 0; i < op.num_results(); ++i) { - if (op.result(i) && op.result(i).type()) { - auto shape_type = - op.result(i).type().dyn_cast(); - if (shape_type && shape_type.IsDynamicShape()) { - return true; - } - } - } - } - return false; -} - void ApplyCinnPass(Program &program) { // NOLINT #ifdef PADDLE_WITH_CINN cinn::dialect::ir::ApplyCinnPass(&program, [] { @@ -1579,7 +1877,8 @@ void InferSymbolicShapePass( pir::Program &program) { // NOLINT pir::IrContext *ctx = pir::IrContext::Instance(); ctx->GetOrRegisterDialect(); - if (HasDynamicShape(program) && FLAGS_pir_apply_shape_optimization_pass) { + if (pir::shape::HasDynamicShape(program) && + FLAGS_pir_apply_shape_optimization_pass) { pass_manager->AddPass(pir::CreateShapeOptimizationPass()); } } @@ -1617,8 +1916,7 @@ void BindPassManager(pybind11::module *m) { py::arg("opt_level") = 2) .def("add_pass", [](PassManager &self, const std::string &pass_name) { - self.AddPass( - std::move(pir::PassRegistry::Instance().Get(pass_name))); + self.AddPass(pir::PassRegistry::Instance().Get(pass_name)); }) .def("passes", [](PassManager &self) { @@ -1632,15 +1930,19 @@ void BindPassManager(pybind11::module *m) { .def("empty", &PassManager::empty) .def("clear", &PassManager::clear) .def("enable_ir_printing", - [](PassManager &self) { self.EnableIRPrinting(); }); + [](PassManager &self) { self.EnableIRPrinting(); }) + .def("enable_print_statistics", + [](PassManager &self) { self.EnablePrintStatistics(); }); } void BindPir(pybind11::module *module) { auto ir_module = module->def_submodule("pir"); BindProgram(&ir_module); BindBlock(&ir_module); - BindOperation(&ir_module); BindValue(&ir_module); + BindIrMapping(&ir_module); + BindCloneOptions(&ir_module); + BindOperation(&ir_module); BindOpOperand(&ir_module); BindType(&ir_module); 
BindAttribute(&ir_module); diff --git a/paddle/fluid/pybind/place.cc b/paddle/fluid/pybind/place.cc index e9c98f0d8b31b..e6c25413988b8 100644 --- a/paddle/fluid/pybind/place.cc +++ b/paddle/fluid/pybind/place.cc @@ -125,7 +125,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/pybind.h" // NOLINT #include "paddle/fluid/pybind/reader_py.h" #include "paddle/fluid/pybind/tensor_py.h" -#include "paddle/fluid/string/to_string.h" +#include "paddle/utils/string/to_string.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/operators/nccl/nccl_gpu_common.h" diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index f1d53f3f88750..5470f4d7ec4f2 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -78,7 +78,8 @@ limitations under the License. */ #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/float16.h" #include "paddle/fluid/prim/utils/utils.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/fluid/memory/allocation/auto_growth_best_fit_allocator_v2.h" #include "paddle/fluid/memory/allocation/cuda_ipc_allocator.h" #endif #include "paddle/common/macros.h" @@ -134,6 +135,10 @@ limitations under the License. */ #include "paddle/phi/core/lod_utils.h" #include "paddle/utils/none.h" +#ifdef PADDLE_WITH_DISTRIBUTE +#include "paddle/fluid/pybind/dist_api.h" +#endif + #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/pybind/nccl_wrapper_py.h" #endif @@ -145,7 +150,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/reader_py.h" #include "paddle/fluid/pybind/tensor.h" #include "paddle/fluid/pybind/tensor_py.h" -#include "paddle/fluid/string/to_string.h" +#include "paddle/utils/string/to_string.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/operators/nccl/nccl_gpu_common.h" @@ -223,6 +228,9 @@ PYBIND11_MAKE_OPAQUE(paddle::framework::FetchType); DECLARE_FILE_SYMBOLS(init_phi); DECLARE_FILE_SYMBOLS(kernel_dialect); +#ifdef PADDLE_WITH_DISTRIBUTE +DECLARE_FILE_SYMBOLS(dist_dialect); +#endif DECLARE_FILE_SYMBOLS(buffered_allocator); DECLARE_FILE_SYMBOLS(best_fit_allocator); DECLARE_FILE_SYMBOLS(aligned_allocator); @@ -971,12 +979,12 @@ PYBIND11_MODULE(libpaddle, m) { #endif m.def("is_cuda_graph_capturing", &platform::IsCUDAGraphCapturing); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) py::class_(m, "CUDAGraph") .def_static("begin_capture", [](platform::CUDAPlace place, int mode) { platform::BeginCUDAGraphCapture( - place, static_cast(mode)); + place, static_cast(mode)); }) .def_static("end_capture", &platform::EndCUDAGraphCapture) .def_static("gen_new_memory_pool_id", @@ -1240,7 +1248,7 @@ All parameter, weight, gradient are variables in Paddle. py::return_value_policy::reference) .def("get_bytes", [](Variable &self) { - if (self.IsType()) { + if (self.IsType()) { // NOLINT return py::bytes(*(self.GetMutable())); } else { return py::bytes( @@ -1801,7 +1809,7 @@ All parameter, weight, gradient are variables in Paddle. 
device_types = phi::DeviceManager::GetAllDeviceTypes(); #else VLOG(1) << string::Sprintf( - "Cannot use get_all_device_type because you have installed" + "Cannot use get_all_device_type because you have installed " "CPU/GPU version PaddlePaddle.\n" "If you want to use get_all_device_type, please try to install" "CustomDevice version " @@ -1815,8 +1823,8 @@ All parameter, weight, gradient are variables in Paddle. device_types = phi::DeviceManager::GetAllCustomDeviceTypes(); #else VLOG(1) << string::Sprintf( - "Cannot use get_all_custom_device_type because you have installed" - "CPU/GPU version PaddlePaddle.\n" + "Cannot use get_all_custom_device_type because you have " + "installed CPU/GPU version PaddlePaddle.\n" "If you want to use get_all_custom_device_type, please try to " "install CustomDevice version " "PaddlePaddle by: pip install paddlepaddle\n"); @@ -1829,7 +1837,7 @@ All parameter, weight, gradient are variables in Paddle. devices = phi::DeviceManager::GetAllDeviceList(); #else VLOG(1) << string::Sprintf( - "Cannot use get_available_device because you have installed" + "Cannot use get_available_device because you have installed " "CPU/GPU version PaddlePaddle.\n" "If you want to use get_available_device, please try to install" "CustomDevice version " @@ -1844,8 +1852,7 @@ All parameter, weight, gradient are variables in Paddle. #else VLOG(1) << string::Sprintf( "Cannot use get_available_custom_device because you have " - "installed" - "CPU/GPU version PaddlePaddle.\n" + "installed CPU/GPU version PaddlePaddle.\n" "If you want to use get_available_custom_device, please try to " "install" "CustomDevice version " @@ -1863,8 +1870,7 @@ All parameter, weight, gradient are variables in Paddle. #else VLOG(1) << string::Sprintf( "Cannot use get_custom_device_count because you have " - "installed" - "CPU/GPU version PaddlePaddle.\n" + "installed CPU/GPU version PaddlePaddle.\n" "If you want to use get_custom_device_count, please try to " "install" "CustomDevice version " @@ -2154,6 +2160,12 @@ All parameter, weight, gradient are variables in Paddle. m.def("_cuda_synchronize", [](const platform::CUDAPlace &place) { platform::DeviceContextPool::Instance().Get(place)->Wait(); }); + m.def("_set_warmup", [](bool warmup) { +#if defined(PADDLE_WITH_CUDA) + paddle::memory::allocation::AutoGrowthBestFitAllocatorV2State::GetInstance() + .SetWarmup(warmup); +#endif + }); m.def("_test_enforce_gpu_success", []() { #if defined(PADDLE_WITH_CUDA) PADDLE_ENFORCE_GPU_SUCCESS(cudaErrorInsufficientDriver); @@ -2229,7 +2241,7 @@ All parameter, weight, gradient are variables in Paddle. const std::string &var_name, size_t index) -> py::object { auto &var = framework::GetFetchVariable(scope, var_name, index); - if (data_is_lod_tensor(var)) { + if (data_is_lod_tensor(var)) { // NOLINT return py::cast(PADDLE_GET(phi::DenseTensor, var)); } else { return py::cast(PADDLE_GET(LoDTensorArray, var)); @@ -3046,6 +3058,9 @@ All parameter, weight, gradient are variables in Paddle. 
BindPir(&m); BindVjp(&m); BindDecomp(&m); +#ifdef PADDLE_WITH_DISTRIBUTE + BindDistApi(&m); +#endif } } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/reader_py.cc b/paddle/fluid/pybind/reader_py.cc index 6489d815df18b..d3fb355fe4d88 100644 --- a/paddle/fluid/pybind/reader_py.cc +++ b/paddle/fluid/pybind/reader_py.cc @@ -258,8 +258,8 @@ class MultiDeviceFeedReader { kException = 2 // Exception raises when reading }; - Status WaitFutures(std::exception_ptr *excep) { - *excep = nullptr; + Status WaitFutures(std::exception_ptr *e) { + *e = nullptr; size_t success_num = 0; for (size_t i = 0; i < futures_.size(); ++i) { auto each_status = futures_[i].get(); @@ -270,7 +270,7 @@ class MultiDeviceFeedReader { platform::errors::NotFound("exceptions_[%d] is NULL, but the " "result status is Status::kException", i)); - *excep = exceptions_[i]; + *e = exceptions_[i]; exceptions_[i] = nullptr; } } else { @@ -278,7 +278,7 @@ class MultiDeviceFeedReader { } } - if (UNLIKELY(*excep)) { + if (UNLIKELY(*e)) { return Status::kException; } @@ -308,16 +308,16 @@ class MultiDeviceFeedReader { } void CheckNextStatus() { - std::exception_ptr excep; - Status status = WaitFutures(&excep); + std::exception_ptr e; + Status status = WaitFutures(&e); - if (UNLIKELY(excep)) { + if (UNLIKELY(e)) { PADDLE_ENFORCE_EQ(status, Status::kException, platform::errors::NotFound( "The exception raised is not NULL, but " "the result status is not Status::kException")); - std::rethrow_exception(excep); + std::rethrow_exception(e); } if (UNLIKELY(status == Status::kEOF)) { diff --git a/paddle/fluid/pybind/tensor.cc b/paddle/fluid/pybind/tensor.cc index ab81ddd6d3908..bf3d025b228cc 100644 --- a/paddle/fluid/pybind/tensor.cc +++ b/paddle/fluid/pybind/tensor.cc @@ -125,7 +125,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/pybind.h" // NOLINT #include "paddle/fluid/pybind/reader_py.h" #include "paddle/fluid/pybind/tensor_py.h" -#include "paddle/fluid/string/to_string.h" +#include "paddle/utils/string/to_string.h" #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/operators/nccl/nccl_gpu_common.h" @@ -859,7 +859,7 @@ void BindTensor(pybind11::module &m) { // NOLINT )DOC") #endif .def("_share_filename", - [](phi::DenseTensor &self) { + [](phi::DenseTensor &self, bool use_file_descriptor) { if (!self.IsInitialized() || self.numel() == 0) throw std::runtime_error( "Tensor not initialized or numel is 0. 
could not pass to " @@ -886,6 +886,10 @@ void BindTensor(pybind11::module &m) { // NOLINT int flags = memory::allocation::MAPPED_SHAREDMEM | memory::allocation::MAPPED_EXCLUSIVE; + if (use_file_descriptor) { + flags = flags | memory::allocation::MAPPED_KEEPFD | + memory::allocation::MAPPED_UNLINK; + } std::string handle = memory::allocation::GetIPCName(); int find_id = -1; if (FLAGS_use_shm_cache) { @@ -894,9 +898,10 @@ void BindTensor(pybind11::module &m) { // NOLINT if (find_id != -1) { handle = memory::allocation::MemoryMapAllocationPool::Instance().GetById(find_id).file_name_; // NOLINT } + int shared_fd = -1; auto shared_holder = memory::allocation::AllocateRefcountedMemoryMapAllocation( - handle, flags, data_size, find_id); + handle, shared_fd, flags, data_size, find_id); // copy data & reset holder if (platform::is_cuda_pinned_place(holder->place())) { @@ -914,8 +919,10 @@ void BindTensor(pybind11::module &m) { // NOLINT int type_idx = static_cast(self.type()); return py::make_tuple(mmap_allocation->ipc_name(), + mmap_allocation->shared_fd(), mmap_allocation->size(), type_idx, - common::vectorize(self.dims()), self.lod()); + common::vectorize(self.dims()), self.lod(), + use_file_descriptor); }, R"DOC( Serialize CPU lod tensor in shared memory to tuple. @@ -935,30 +942,37 @@ void BindTensor(pybind11::module &m) { // NOLINT )DOC") .def("_new_shared_filename", [](py::tuple t) { // __setstate__ - if (t.size() != 5) + if (t.size() != 7) throw std::runtime_error("Invalid Tensor meta info state!"); phi::DenseTensor tensor; // 2. Rebuild Allocation const std::string &ipc_name = t[0].cast(); - size_t size = t[1].cast(); + const int shared_fd = t[1].cast(); + const bool use_file_descriptor = t[6].cast(); + + size_t size = t[2].cast(); int flags = memory::allocation::MAPPED_SHAREDMEM | memory::allocation::MAPPED_NOCREATE; + if (use_file_descriptor) { + flags = flags | memory::allocation::MAPPED_KEEPFD | + memory::allocation::MAPPED_UNLINK; + } int find_id = -1; if (FLAGS_use_shm_cache) { find_id = memory::allocation::MemoryMapAllocationPool::Instance().FindFromCache(flags, size, ipc_name, /*check_refcount*/ false); // NOLINT } auto shared_holder = memory::allocation::AllocateRefcountedMemoryMapAllocation( - ipc_name, flags, size, find_id); + ipc_name, shared_fd, flags, size, find_id); // 3. Rebuild Tensor tensor.ResetHolderWithType( shared_holder, - static_cast(t[2].cast())); - tensor.Resize(common::make_ddim(t[3].cast>())); - tensor.set_lod(t[4].cast()); + static_cast(t[3].cast())); + tensor.Resize(common::make_ddim(t[4].cast>())); + tensor.set_lod(t[5].cast()); return tensor; }, @@ -966,7 +980,7 @@ void BindTensor(pybind11::module &m) { // NOLINT Deserialize CPU lod tensor from shared memory. Params: - tuple: contrains ipc file name, data size, data type, + tuple: contains ipc file name, data size, data type, tensor dims and lod information. 
Examples: @@ -1073,12 +1087,19 @@ void BindTensor(pybind11::module &m) { // NOLINT self.unsafe_mutable_value()->ShareDataNoCheckWith(src.value()); return self; }) - .def("_share_data_with", [](DistTensor &self, const DistTensor &src) { - self.unsafe_set_dims(src.dims()); - self.unsafe_set_dist_attr(src.dist_attr()); - self.unsafe_mutable_value()->ShareDataWith(src.value()); - return self; - }); + .def("_share_data_with", + [](DistTensor &self, const DistTensor &src) { + self.unsafe_set_dims(src.dims()); + self.unsafe_set_dist_attr(src.dist_attr()); + if (!IsCurRankInMesh(self.process_mesh()) && + !IsCurRankInMesh(src.dist_attr().process_mesh())) { + self.unsafe_mutable_value()->ShareDataNoCheckWith(src.value()); + } else { + self.unsafe_mutable_value()->ShareDataWith(src.value()); + } + return self; + }) + .def("_clear", &DistTensor::clear); #endif py::class_(m, "SelectedRows") diff --git a/paddle/fluid/pybind/tensor_py.h b/paddle/fluid/pybind/tensor_py.h index 4e3cf9b35d78d..ba3a466fba219 100644 --- a/paddle/fluid/pybind/tensor_py.h +++ b/paddle/fluid/pybind/tensor_py.h @@ -970,14 +970,12 @@ inline py::array TensorToPyArray(const phi::DenseTensor &tensor, std::vector py_dims(rank); std::vector py_strides(rank); - size_t numel = 1; auto tensor_stride = tensor.strides(); for (int i = tensor_dims.size() - 1; i >= 0; --i) { py_dims[i] = static_cast(tensor_dims[i]); py_strides[i] = sizeof_dtype * tensor_stride[i]; - numel *= py_dims[i]; } const void *tensor_buf_ptr = tensor.data(); diff --git a/paddle/fluid/string/split.h b/paddle/fluid/string/split.h deleted file mode 100644 index d2a6f67ca75c1..0000000000000 --- a/paddle/fluid/string/split.h +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include - -#include "paddle/utils/string/split.h" diff --git a/paddle/fluid/string/to_string.h b/paddle/fluid/string/to_string.h deleted file mode 100644 index 72d9c0379fd3a..0000000000000 --- a/paddle/fluid/string/to_string.h +++ /dev/null @@ -1,16 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once -#include "paddle/utils/string/to_string.h" diff --git a/paddle/fluid/sub_graph/sub_graph_checker.cc b/paddle/fluid/sub_graph/sub_graph_checker.cc index 0151684a8161d..42cd6bd001f0d 100644 --- a/paddle/fluid/sub_graph/sub_graph_checker.cc +++ b/paddle/fluid/sub_graph/sub_graph_checker.cc @@ -23,7 +23,7 @@ #include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h" #include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.h" -#include "paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.h" +#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.h" #include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" #include "paddle/fluid/framework/new_executor/interpretercore.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" diff --git a/paddle/phi/README.md b/paddle/phi/README.md index 8151e2c078c09..07c8b0a925846 100644 --- a/paddle/phi/README.md +++ b/paddle/phi/README.md @@ -206,7 +206,7 @@ template void ScaleKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale, DenseTensor* out); ``` @@ -354,7 +354,7 @@ Tensor mean(const Tensor& x); Tensor scale(const Tensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale); ``` diff --git a/paddle/phi/api/CMakeLists.txt b/paddle/phi/api/CMakeLists.txt index 1827dfbeb7f64..b06c40cf41a6e 100644 --- a/paddle/phi/api/CMakeLists.txt +++ b/paddle/phi/api/CMakeLists.txt @@ -1,2 +1,9 @@ add_subdirectory(profiler) add_subdirectory(lib) +if(WIN32) + file(GLOB YAML_FILE "${CMAKE_CURRENT_SOURCE_DIR}/yaml/*.yaml") + set_property( + DIRECTORY + APPEND + PROPERTY CMAKE_CONFIGURE_DEPENDS ${YAML_FILE}) +endif() diff --git a/paddle/phi/api/all.h b/paddle/phi/api/all.h index 93c97605f9f3f..aaafec306401a 100644 --- a/paddle/phi/api/all.h +++ b/paddle/phi/api/all.h @@ -38,8 +38,3 @@ limitations under the License. 
*/ #include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/ext/tensor_compat.h" - -// common headers -#include "paddle/common/ddim.h" -#include "paddle/common/exception.h" -#include "paddle/common/layout.h" diff --git a/paddle/phi/api/include/tensor.h b/paddle/phi/api/include/tensor.h index 636a4198640cd..a4ce550f9858c 100644 --- a/paddle/phi/api/include/tensor.h +++ b/paddle/phi/api/include/tensor.h @@ -142,14 +142,16 @@ class PADDLE_API Tensor final { explicit Tensor(const std::string& name) : name_(name) {} /** - * @brief Construct a new Tensor object by a TensorBase pointer and - * autograd_meta + * @brief Construct a new Tensor object by a TensorBase pointer, autograd meta + * and name * * @param tensor_impl * @param autograd_meta + * @param name */ Tensor(std::shared_ptr tensor_impl, - std::shared_ptr autograd_meta); + std::shared_ptr autograd_meta, + const std::string& name); /* Part 2: Dimension, DataType and DataLayout methods */ @@ -713,7 +715,7 @@ class PADDLE_API Tensor final { Tensor maximum(const Tensor& y) const; Tensor minimum(const Tensor& y) const; Tensor scale(const Scalar& scale = 1.0, - float bias = 0.0, + const Scalar& bias = 0.0, bool bias_after_scale = true) const; Tensor sum(const IntArray& axis = {}, DataType dtype = DataType::UNDEFINED, diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc index 87e6f9af43075..ef5cfc90727ff 100644 --- a/paddle/phi/api/lib/api_gen_utils.cc +++ b/paddle/phi/api/lib/api_gen_utils.cc @@ -24,6 +24,7 @@ PHI_DECLARE_bool(use_stride_kernel); #include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" #include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h" +#include "paddle/phi/core/kernel_factory.h" namespace paddle { namespace experimental { @@ -416,6 +417,32 @@ void TransStride(phi::DeviceContext* dev_ctx, delete from; return; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto* custom_ctx = dynamic_cast(dev_ctx); + if (custom_ctx) { + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(to->place()), + phi::DataLayout::ALL_LAYOUT, + to->dtype()}; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector&, + const std::vector&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature, + false, + *custom_ctx, + *from, + common::vectorize(to->dims()), + common::vectorize(to->strides()), + to->offset(), + to); + delete from; + return; + } #endif } } @@ -466,6 +493,31 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx, })); return; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto* custom_ctx = dynamic_cast(dev_ctx); + if (custom_ctx) { + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(to->place()), + phi::DataLayout::ALL_LAYOUT, + to->dtype()}; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector&, + const std::vector&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature, + false, + *custom_ctx, + *from, + common::vectorize(to->dims()), + common::vectorize(to->strides()), + to->offset(), + to); + return; + } #endif } } @@ -520,6 +572,33 @@ void TransStride(phi::DeviceContext* dev_ctx, delete from[i]; continue; } +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto* custom_ctx = dynamic_cast(dev_ctx); + if (custom_ctx) { + const phi::KernelKey& kernel_key 
= { + phi::TransToPhiBackend(to[i]->place()), + phi::DataLayout::ALL_LAYOUT, + to[i]->dtype()}; + using kernel_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector&, + const std::vector&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + kernel_key, + kernel_signature, + false, + *custom_ctx, + *from[i], + common::vectorize(to[i]->dims()), + common::vectorize(to[i]->strides()), + to[i]->offset(), + to[i]); + delete from[i]; + return; + } #endif } } diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 80bb9f4447573..01eb529a11b2c 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -255,6 +255,27 @@ phi::DenseTensor Trans2Contiguous(const phi::DenseTensor& tensor) { } else if (tensor.place().GetType() == phi::AllocationType::XPU) { auto* dev_ctx = static_cast(pool.Get(tensor.place())); return TensorContiguous(*dev_ctx, tensor); +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + } else if (tensor.place().GetType() == phi::AllocationType::CUSTOM) { + auto* dev_ctx = static_cast(pool.Get(tensor.place())); + phi::DenseTensor dense_out; + phi::MetaTensor meta_input(tensor); + phi::MetaTensor meta_out(&dense_out); + UnchangedInferMeta(meta_input, &meta_out); + const phi::KernelKey& kernel_key = {phi::TransToPhiBackend(tensor.place()), + phi::DataLayout::ALL_LAYOUT, + tensor.dtype()}; + using kernel_signature = void (*)( + const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*); + PD_VISIT_KERNEL("contiguous", + kernel_key, + kernel_signature, + false, + *dev_ctx, + tensor, + &dense_out); + return dense_out; #endif } else { PADDLE_THROW(phi::errors::Unimplemented( @@ -283,7 +304,7 @@ std::vector CheckAndTrans2NewContiguousTensor( const std::vector& tensor) { std::vector out; for (auto& t : tensor) { - out.emplace_back(std::move(CheckAndTrans2NewContiguousTensor(t))); + out.emplace_back(CheckAndTrans2NewContiguousTensor(t)); } return out; } @@ -578,8 +599,7 @@ std::shared_ptr PrepareDataForDenseTensorInSparse( return std::static_pointer_cast(tensor_in); } - return std::make_shared( - std::move(Trans2Contiguous(dense_tensor))); + return std::make_shared(Trans2Contiguous(dense_tensor)); } PADDLE_THROW(phi::errors::InvalidArgument( "The impl() of input tensor is nullptr, it doesn't support for " diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index 0a37a1e763e9f..8924981d7060a 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -610,8 +610,8 @@ extern "C" { #ifndef _WIN32 // C-API to get global OpMetaInfoMap. 
-paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() { - return paddle::OpMetaInfoMap::Instance(); +paddle::OpMetaInfoMap* PD_GetOpMetaInfoMap() { + return &paddle::OpMetaInfoMap::Instance(); } #endif diff --git a/paddle/phi/api/lib/tensor.cc b/paddle/phi/api/lib/tensor.cc index 2ab68b2e846f2..54c949e688c79 100644 --- a/paddle/phi/api/lib/tensor.cc +++ b/paddle/phi/api/lib/tensor.cc @@ -53,8 +53,11 @@ Tensor::Tensor(std::shared_ptr tensor_impl) } Tensor::Tensor(std::shared_ptr tensor_impl, - std::shared_ptr autograd_meta) - : impl_(std::move(tensor_impl)), autograd_meta_(std::move(autograd_meta)) { + std::shared_ptr autograd_meta, + const std::string &name) + : impl_(std::move(tensor_impl)), + autograd_meta_(std::move(autograd_meta)), + name_(name) { PADDLE_ENFORCE_NOT_NULL( impl_, phi::errors::InvalidArgument("TensorImpl with nullptr is not supported")); diff --git a/paddle/phi/api/profiler/device_tracer.cc b/paddle/phi/api/profiler/device_tracer.cc index f15d6bbb88457..e1c009fa9cad0 100644 --- a/paddle/phi/api/profiler/device_tracer.cc +++ b/paddle/phi/api/profiler/device_tracer.cc @@ -25,7 +25,7 @@ limitations under the License. */ #include "paddle/common/flags.h" #include "paddle/phi/core/enforce.h" -PD_DECLARE_bool(enable_host_event_recorder_hook); +PHI_DECLARE_bool(enable_host_event_recorder_hook); namespace phi { @@ -571,10 +571,10 @@ class DeviceTracerImpl : public DeviceTracer { Event *e = c->second; Event *parent = e->parent(); while (parent) { - parent->AddCudaElapsedTime(r.start_ns, r.end_ns); + parent->AddCudaElapsedTime(r.start_ns, r.end_ns); // NOLINT parent = parent->parent(); } - e->AddCudaElapsedTime(r.start_ns, r.end_ns); + e->AddCudaElapsedTime(r.start_ns, r.end_ns); // NOLINT } } for (const auto &r : mem_records_) { @@ -583,10 +583,10 @@ class DeviceTracerImpl : public DeviceTracer { Event *e = c->second; Event *parent = e->parent(); while (parent) { - parent->AddCudaElapsedTime(r.start_ns, r.end_ns); + parent->AddCudaElapsedTime(r.start_ns, r.end_ns); // NOLINT parent = parent->parent(); } - e->AddCudaElapsedTime(r.start_ns, r.end_ns); + e->AddCudaElapsedTime(r.start_ns, r.end_ns); // NOLINT } } #endif diff --git a/paddle/phi/api/profiler/profiler.cc b/paddle/phi/api/profiler/profiler.cc index 6dc419658d3c2..e9c49741a5e6b 100644 --- a/paddle/phi/api/profiler/profiler.cc +++ b/paddle/phi/api/profiler/profiler.cc @@ -77,7 +77,7 @@ double Event::CpuElapsedMs(const Event &e) const { double Event::CudaElapsedMs(const Event &e) const { #ifdef PADDLE_WITH_CUPTI - return gpu_ns_ / 1000000.0; + return static_cast(gpu_ns_) / 1000000.0; #else LOG_FIRST_N(WARNING, 1) << "CUDA CUPTI is not enabled"; return 0; diff --git a/paddle/phi/api/profiler/profiler.h b/paddle/phi/api/profiler/profiler.h index 8b789def59def..dfc304126f1c3 100644 --- a/paddle/phi/api/profiler/profiler.h +++ b/paddle/phi/api/profiler/profiler.h @@ -28,7 +28,7 @@ limitations under the License. 
*/ #include "paddle/phi/api/profiler/event_tracing.h" #include "paddle/phi/api/profiler/supplement_tracing.h" -COMMON_DECLARE_bool(enable_host_event_recorder_hook); +PHI_DECLARE_bool(enable_host_event_recorder_hook); namespace phi { diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 772db08fd1a2e..603b65c8b4c53 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -433,9 +433,9 @@ param : [x, x] kernel : func : cos_double_grad - optional: grad_out backward : cos_triple_grad inplace : (grad_x_grad -> grad_out_grad) + composite : cos_double_grad(x, grad_out, grad_x_grad, x_grad, grad_out_grad) - backward_op : cos_grad forward : cos (Tensor x) -> Tensor(out) @@ -859,6 +859,17 @@ func : flash_attn_unpadded_grad data_type: q +- backward_op : flash_attn_with_sparse_mask_grad + forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0) + output : Tensor(q_grad), Tensor(k_grad), Tensor(v_grad) + infer_meta : + func : FlashAttnGradInferMeta + param : [q, k, v] + kernel : + func : flash_attn_with_sparse_mask_grad + data_type: q + - backward_op : flatten_grad forward : flatten(Tensor x, int start_axis = 1, int stop_axis = 1) -> Tensor(out), Tensor(xshape) args : (Tensor xshape, Tensor out_grad) @@ -1647,8 +1658,8 @@ func : mv_grad - backward_op : nanmedian_grad - forward : nanmedian (Tensor x, IntArray axis, bool keepdim) -> Tensor(out), Tensor(medians) - args : (Tensor x, Tensor medians, Tensor out_grad, IntArray axis, bool keepdim) + forward : nanmedian (Tensor x, IntArray axis, bool keepdim, str mode) -> Tensor(out), Tensor(medians) + args : (Tensor x, Tensor medians, Tensor out_grad, IntArray axis, bool keepdim, str mode) output : Tensor(x_grad) infer_meta : func : NanmedianGradInferMeta @@ -1772,6 +1783,7 @@ data_type : x backward : pow_triple_grad inplace : (grad_x_grad -> x_grad) + composite: pow_double_grad(x, grad_out, grad_x_grad, y, x_grad, grad_out_grad) - backward_op : pow_grad forward : pow(Tensor x, Scalar y=1.0f) -> Tensor(out) @@ -1786,6 +1798,7 @@ data_type : out_grad backward: pow_double_grad inplace : (out_grad -> x_grad) + composite: pow_grad(x, out_grad, y, x_grad) - backward_op : pow_triple_grad forward : pow_double_grad(Tensor x, Tensor grad_out, Tensor grad_grad_x, Scalar y) -> Tensor(grad_x), Tensor(grad_grad_out) @@ -2001,7 +2014,7 @@ inplace : (out_grad -> x_grad) - backward_op : scale_grad - forward : scale (Tensor x, Scalar scale, float bias, bool bias_after_scale) -> Tensor(out) + forward : scale (Tensor x, Scalar scale, Scalar bias, bool bias_after_scale) -> Tensor(out) args : (Tensor out_grad, Scalar scale=1.0) output : Tensor(x_grad) invoke : scale(out_grad, scale, 0.0f, true) @@ -2166,9 +2179,9 @@ param : [x, x] kernel : func : sin_double_grad - optional: grad_out backward : sin_triple_grad inplace : (grad_x_grad -> grad_out_grad) + composite : sin_double_grad(x, grad_out, grad_x_grad, x_grad, grad_out_grad) - backward_op : sin_grad forward : sin (Tensor x) -> Tensor(out) @@ -2362,6 +2375,12 @@ 
inplace : (out_grad -> x_grad) backward: squeeze_double_grad +- backward_op : stack_double_grad + forward : stack_grad (Tensor[] x, Tensor grad_out, int axis=0) -> Tensor[](grad_x) + args : (Tensor[] grad_x_grad, int axis = 0) + output : Tensor(grad_out_grad) + invoke : stack(grad_x_grad, axis) + - backward_op : stack_grad forward : stack (Tensor[] x, int axis) -> Tensor(out) args : (Tensor[] x, Tensor out_grad, int axis) @@ -2376,6 +2395,7 @@ data_type : out_grad no_need_buffer : x composite : stack_grad(x, out_grad, axis, x_grad) + backward: stack_double_grad - backward_op : stanh_grad forward : stanh(Tensor x, float scale_a, float scale_b) -> Tensor(out) @@ -2405,6 +2425,7 @@ infer_meta: func: SwiGLUGradInferMeta param: [x, y] + spmd_rule: SwiGLUGradInferSpmd kernel: func: swiglu_grad optional: y diff --git a/paddle/phi/api/yaml/fused_backward.yaml b/paddle/phi/api/yaml/fused_backward.yaml index 5c92b1a2a692f..36c3c0dde5191 100644 --- a/paddle/phi/api/yaml/fused_backward.yaml +++ b/paddle/phi/api/yaml/fused_backward.yaml @@ -6,7 +6,7 @@ - backward_op : fused_bias_dropout_residual_layer_norm_grad forward: fused_bias_dropout_residual_layer_norm (Tensor x, Tensor residual, Tensor bias, Tensor ln_scale, Tensor ln_bias, float dropout_rate, bool is_test, bool dropout_fix_seed, int dropout_seed, str dropout_implementation, float ln_epsilon) -> Tensor(y), Tensor(bias_dropout_residual_out), Tensor(dropout_mask_out), Tensor(ln_mean), Tensor(ln_variance) - args : (Tensor y_grad, Tensor x, Tensor residual, Tensor bias, Tensor ln_scale, Tensor ln_bias, Tensor ln_mean, Tensor ln_variance, Tensor bias_dropout_residual_out, Tensor dropout_mask_out, float dropout_rate = 0.5f, bool is_test = false, bool dropout_fix_seed = true, int dropout_seed = true, str dropout_implementation = "downgrade_in_infer", float ln_epsilon = 1e-5) + args : (Tensor x, Tensor residual, Tensor bias, Tensor ln_scale, Tensor ln_bias, Tensor ln_mean, Tensor ln_variance, Tensor bias_dropout_residual_out, Tensor dropout_mask_out, Tensor y_grad, float dropout_rate = 0.5f, bool is_test = false, bool dropout_fix_seed = true, int dropout_seed = true, str dropout_implementation = "downgrade_in_infer", float ln_epsilon = 1e-5) output : Tensor(x_grad), Tensor(residual_grad), Tensor(bias_grad), Tensor(ln_scale_grad), Tensor(ln_bias_grad) optional : bias, ln_scale, ln_bias, bias_grad, ln_scale_grad, ln_bias_grad infer_meta : @@ -14,6 +14,7 @@ kernel : func : fused_bias_dropout_residual_layer_norm_grad data_type : y_grad + support_dygraph_mode : true - backward_op : fused_dot_product_attention_grad forward : fused_dot_product_attention (Tensor q, Tensor k, Tensor v, Tensor mask, float scaling_factor, float dropout_probability, bool is_training, bool is_causal_masking) -> Tensor(out), Tensor(softmax_out), Tensor(rng_state) diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 2ca0a32be59f5..ff6969194f6d6 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -163,6 +163,7 @@ data_type : x backward : fused_bias_dropout_residual_layer_norm_grad intermediate : bias_dropout_residual_out, dropout_mask_out, ln_mean, ln_variance + support_dygraph_mode : true - op : fused_bias_residual_layernorm args : (Tensor x, Tensor bias, Tensor residual, Tensor norm_weight, Tensor norm_bias, float epsilon, float residual_alpha, int begin_norm_axis, float quant_scale, int quant_round_type, float quant_max_bound, float quant_min_bound) @@ -399,7 +400,7 @@ backward : max_pool2d_v2_grad - op : 
multi_encoder_xpu - args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] smooth_scale_weight, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel, float[] softmax_max_value, str[] quant_types) + args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] smooth_scale_weight, Tensor[] roformer_embedding, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel, int max_pos_len, float[] softmax_max_value, str[] quant_types) output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16) infer_meta : func : MultiEncoderXPUInferMeta @@ -437,6 +438,15 @@ func : quantize_xpu data_type : x +- op : roformer_relative_embedding_xpu + args : (Tensor x, Tensor sin_emb, Tensor cos_emb, int max_pos_len) + output : Tensor(out) + infer_meta : + func : RoformerRelativePosXPUInferMeta + kernel : + func : roformer_relative_embedding_xpu + data_type : x + - op : self_dp_attention args : (Tensor x, float alpha = 1.0f, int head_number = 1) output : Tensor(out) diff --git a/paddle/phi/api/yaml/generator/api_gen.py b/paddle/phi/api/yaml/generator/api_gen.py index 3e144fa27d986..59eedd4a83de4 100644 --- a/paddle/phi/api/yaml/generator/api_gen.py +++ b/paddle/phi/api/yaml/generator/api_gen.py @@ -340,9 +340,7 @@ def gene_output( ) else: raise ValueError( - "{} : Output error: only support Tensor type when use view in yaml. But get {}".format( - self.api, out_dtype_list[i] - ) + f"{self.api} : Output error: only support Tensor type when use view in yaml. But get {out_dtype_list[i]}" ) else: raise ValueError( diff --git a/paddle/phi/api/yaml/generator/dist_api_gen.py b/paddle/phi/api/yaml/generator/dist_api_gen.py index d0b82f3be9f70..ad153639c4d56 100644 --- a/paddle/phi/api/yaml/generator/dist_api_gen.py +++ b/paddle/phi/api/yaml/generator/dist_api_gen.py @@ -483,53 +483,56 @@ // API `{}` does not need to set DistAttr for output.""" # TODO(GhostScreaming): Support aliquant condition. -# Specialized Code, for example, reshape needs to calculate local_shape -RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE = """ +# Operators like `reshape`, `expand_as` need to calculate local_shape +# for their local `DenseTensor`, as the given shape in their attribute +# is global_shape for `DistTensor`. +CALCULATE_LOCAL_SHAPE_TEMPLATE = """ // The dist_input_x is a dist tensor, the dims() func return the global dims. 
auto x_shape = dist_input_x->dims(); auto x_numel = dist_input_x->numel(); bool visit_negative = false; - std::vector local_shape; - for (size_t i = 0; i < shape.GetData().size(); i++) { + auto global_shape = {shape}; + std::vector<{dtype}> local_shape; + for (size_t i = 0; i < global_shape.size(); i++) {{ auto& out_dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, spmd_info.second[0]); - if (out_dist_attr.dims_mapping()[i] >= 0) { - int64_t shape_i = shape.GetData()[i]; - if (shape_i == 0) { + if (out_dist_attr.dims_mapping()[i] >= 0) {{ + {dtype} shape_i = global_shape[i]; + if (shape_i == 0) {{ shape_i = x_shape[i]; - } else if (shape_i == -1) { + }} else if (shape_i == -1) {{ PADDLE_ENFORCE(not visit_negative, phi::errors::InvalidArgument( - "Reshape can only have one -1 in the shape.")); + "{op_name} can only have one -1 in the {shape_name}.")); visit_negative = true; int64_t non_negative_product = 1; - for (size_t j = 0; j < shape.GetData().size(); j++) { - if (i == j) { + for (size_t j = 0; j < global_shape.size(); j++) {{ + if (i == j) {{ continue; - } - int64_t tmp_j = shape.GetData()[j]; - if (tmp_j == 0) { + }} + int64_t tmp_j = global_shape[j]; + if (tmp_j == 0) {{ tmp_j = x_shape[j]; - } + }} non_negative_product *= tmp_j; - } + }} PADDLE_ENFORCE(x_numel % non_negative_product == 0, phi::errors::InvalidArgument("Cannot infer real shape for -1.")); shape_i = x_numel / non_negative_product; - } + }} int64_t dim = out_dist_attr.dims_mapping()[i]; int64_t mesh_dim = out_dist_attr.process_mesh().shape()[dim]; // TODO: Support aliquant condition. PADDLE_ENFORCE(shape_i % mesh_dim == 0, phi::errors::InvalidArgument( - "Reshape only support local shape dim is divisible " + "{op_name} only support local shape dim is divisible " "by the mesh dim, however local_shape[%lld] is %lld " "and shard mesh dims is %lld.", i, shape_i, mesh_dim)); local_shape.push_back(shape_i / mesh_dim); - } else { - local_shape.push_back(shape.GetData()[i]); - } - } + }} else {{ + local_shape.push_back({shape}[i]); + }} + }} """ # BaseAPI members: @@ -590,7 +593,11 @@ def parse_infer_meta(self, infer_meta_config): infer_meta['param'] = None if 'spmd_rule' not in infer_meta_config: infer_meta['spmd_rule'] = None - + # Operators like `reshape`, `expand_as` need to calculate local_shape + # for their local `DenseTensor`, as the given shape in their attribute + # is global_shape for `DistTensor`. + if 'local_shape' not in infer_meta_config: + infer_meta['local_shape'] = None return infer_meta def need_to_generate_code_for_inplace_impl(self, i): @@ -613,17 +620,6 @@ def need_to_generate_code_for_inplace_or_view_impl(self, i): i ) or self.need_to_generate_code_for_view_impl(i) - # # view output is also inlace, such case still needs - # # to create an empty DenseTensor for inplace output in pp - # def need_to_set_inplace_output_for_pp_impl(self, i): - # return (not self.need_to_generate_code_for_view_impl(i)) and self.is_inplace_output(i) - - def is_reshape_kernel(self): - return ( - "reshape" in self.kernel['func'][0] - and 'grad' not in self.kernel['func'][0] - ) - def is_inplace_output(self, i): return self.outputs['names'][i] in self.inplace_map @@ -1548,8 +1544,8 @@ def generate_infer_meta_code(self) -> str: f"{self.api} : Param of infer_meta error : {self.inputs['input_info'][param]} type is not supported." 
) elif param in attr_names: - # TODO(GhostScreaming): reshape kernel need specialized process - if self.is_reshape_kernel() and param == "shape": + # TODO(GhostScreaming): kernel like reshape need calculate local_shape + if self.infer_meta['local_shape'] is not None: input_args_code = input_args_code + "local_shape" + ", " else: input_args_code = input_args_code + param + ", " @@ -1582,9 +1578,24 @@ def generate_infer_meta_code(self) -> str: output_args_code = output_args_code[:-2] infer_meta_code = "" - # TODO(GhostScreaming): reshape kernel need specialized process - if self.is_reshape_kernel(): - infer_meta_code = RESHAPE_CALCULATE_LOCAL_SHAPE_TEMPLATE + # TODO(GhostScreaming): kernel like reshape need calculate local_shape + if self.infer_meta['local_shape'] is not None: + shape_name = self.infer_meta['local_shape'] + assert ( + shape_name in self.attrs['names'] + ), f"Auto Parallel will calculate local_shape {shape_name} for" + "operator {self.kernel['func'][0]}, but {shape_name} is not" + "found in its attributes." + shape_type = self.attrs['attr_info'][shape_name][0] + + infer_meta_code = CALCULATE_LOCAL_SHAPE_TEMPLATE.format( + shape=f"{shape_name}.GetData()" + if shape_type == "IntArray" + else f"{shape_name}", + dtype="int64_t" if shape_type == "IntArray" else "int", + op_name=self.kernel['func'][0], + shape_name=shape_name, + ) infer_meta_code = infer_meta_code + INFER_META_TEMPLATE.format( infer_meta_func_code, input_args_code, output_args_code ) @@ -1637,8 +1648,8 @@ def generate_kernel_call_code(self) -> str: elif arg in attr_names: if 'IntArray' in self.attrs['attr_info'][arg][0]: kernel_args_type_list.append('const phi::IntArray&') - # TODO(GhostScreaming): reshape kernel need specialized process - if self.is_reshape_kernel() and arg == "shape": + # TODO(GhostScreaming): kernel like reshape need calculate local_shape + if self.infer_meta['local_shape'] is not None: arg = 'phi::IntArray(local_shape)' else: arg = 'phi::IntArray(' + arg + ')' diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index e5529aa6c5efa..8478e3caec98c 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -175,15 +175,15 @@ - backward_op : divide_double_grad forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y) - args : (Tensor y, Tensor out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1) + args : (Tensor y, Tensor out, Tensor grad_out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1) output : Tensor(y_grad), Tensor(out_grad), Tensor(grad_out_grad) infer_meta : func : GeneralTernaryGradInferMeta - param : [y, grad_x, grad_x] + param : [y, out, out] kernel : func : divide_double_grad data_type : out - optional : grad_x_grad, grad_y_grad + optional : grad_x, grad_x_grad, grad_y_grad inplace : (grad_x_grad -> grad_out_grad) - backward_op : divide_grad @@ -381,6 +381,7 @@ kernel : func : maximum_grad composite : maximum_grad(x, y, out_grad, x_grad, y_grad) + backward : maximum_double_grad - backward_op : mean_double_grad forward: mean_grad (Tensor x, Tensor grad_out, IntArray axis={}, bool keepdim=false, bool reduce_all = false) -> Tensor(grad_x) @@ -421,6 +422,7 @@ kernel : func : minimum_grad composite : minimum_grad(x, y, out_grad, axis, x_grad, y_grad) + backward : minimum_double_grad - backward_op : mish_grad forward : mish (Tensor x, float lambda) -> Tensor(out) @@ -876,6 +878,19 @@ func : 
fused_gemm_epilogue_grad optional : reserve_space +- backward_op: maximum_double_grad + forward: maximum_grad(Tensor x, Tensor y, Tensor grad_out) -> Tensor(grad_x), Tensor(grad_y) + args: (Tensor x, Tensor y, Tensor grad_x_grad, Tensor grad_y_grad) + output: Tensor(grad_out_grad) + composite: maximum_double_grad(x, y, grad_x_grad, grad_y_grad, grad_out_grad) + +- backward_op: minimum_double_grad + forward: minimum_grad(Tensor x, Tensor y, Tensor grad_out) -> Tensor(grad_x), Tensor(grad_y) + args: (Tensor x, Tensor y, Tensor grad_x_grad, Tensor grad_y_grad) + output: Tensor(grad_out_grad) + composite: minimum_double_grad(x, y, grad_x_grad, grad_y_grad, grad_out_grad) + optional : grad_x_grad, grad_y_grad + - backward_op: unpool_grad forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out) args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 9b1d862180903..142814e1cc01e 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -277,6 +277,16 @@ data_type : x backward : conv2d_transpose_grad +- op : conv2d_transpose_bias + args : (Tensor x, Tensor filter, Tensor bias, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") + output : Tensor(out) + infer_meta : + func : Conv2dTransposeInferMeta + param: [x, filter, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format] + kernel : + func : conv2d_transpose_bias + data_type : x + - op : copy_to args : (Tensor x, Place place, bool blocking) output : Tensor(out) @@ -592,6 +602,16 @@ backward: fused_gemm_epilogue_grad optional: reserve_space +- op : fused_multi_transformer + args : (Tensor x, Tensor[] ln_scales, Tensor[] ln_biases, Tensor[] qkv_weights, Tensor[] qkv_biases, Tensor[] cache_kvs, Tensor[] pre_caches, Tensor rotary_tensor, Tensor time_step, Tensor seq_lengths, Tensor src_mask, Tensor[] out_linear_weights, Tensor[] out_linear_biases, Tensor[] ffn_ln_scales, Tensor[] ffn_ln_biases, Tensor[] ffn1_weights, Tensor[] ffn1_biases, Tensor[] ffn2_weights, Tensor[] ffn2_biases, bool pre_layer_norm = true, float epsilon = 1e-5, float dropout_rate = .5f, int rotary_emb_dims = 0, bool is_test = false, str dropout_implementation = "downgrade_in_infer", str act_method = "gelu", bool trans_qkvw =true, int ring_id = -1) + optional : qkv_biases, cache_kvs, pre_caches, rotary_tensor, time_step, seq_lengths, src_mask, out_linear_biases, ffn1_biases, ffn2_biases, cache_kv_outs + output : Tensor[](cache_kv_outs){out_linear_weights.size()}, Tensor(out) + infer_meta : + func : FusedMultiTransformerInferMeta + kernel : + func : fused_multi_transformer + data_type : x + - op : fused_softmax_mask args : (Tensor x, Tensor mask) output : Tensor(out) @@ -985,6 +1005,7 @@ infer_meta : func : ReshapeWithXShapeInferMeta spmd_rule : ReshapeInferSpmdDynamic + local_shape: shape kernel : func : reshape inplace : (x -> out) @@ -1078,6 +1099,7 @@ kernel : func : split backward : split_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : split_with_num args : (Tensor x, int num, Scalar(int) axis) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml 
index 74263a1dd522d..0dbc54962da98 100755 --- a/paddle/phi/api/yaml/op_compat.yaml +++ b/paddle/phi/api/yaml/op_compat.yaml @@ -329,6 +329,19 @@ outputs : {auc : AUC, stat_pos_out : StatPosOut, stat_neg_out : StatNegOut} +- op : barrier + inputs : + {x : X} + outputs : + out : Out + +- op : batch_fc + backward : batch_fc_grad + inputs : + {input : Input, w : W, bias : Bias} + outputs : + out : Out + - op : batch_norm backward : batch_norm_grad, batch_norm_double_grad(batch_norm_grad_grad) inputs: @@ -471,6 +484,12 @@ outputs : {softmax : Softmax, loss : Loss} +- op : c_split + inputs : + x : X + outputs : + out : Out + - op : cast inputs : x : X @@ -617,6 +636,20 @@ str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f, int workspace_size_MB = phi::backends::gpu::GetDefaultConvWorkspaceSizeLimitMB()] +- op : conv2d_transpose_bias + inputs : + {x : Input, filter : Filter, bias : Bias} + outputs : + out : Output + int_array : + output_size : + data_type : int + support_tensor : true + extra : + attrs : [bool is_test = false, bool use_cudnn = false, bool use_mkldnn = true, bool force_fp32_output = false, + str mkldnn_data_type = "float32", bool fuse_relu = false, + str fuse_activation = "", float fuse_alpha = 0.0f, float fuse_beta = 0.0f] + - op : conv3d backward : conv3d_grad, conv3d_double_grad (conv3d_grad_grad) inputs : @@ -823,6 +856,10 @@ out : Out - op : distributed_push_sparse + inputs : + {ids : Ids, shows : Shows, clicks: Clicks} + outputs : + output : Outputs extra : attrs : ['int[] slots = {}'] @@ -1230,6 +1267,15 @@ data_type : float support_tensor : true +- op : fused_adam_(fused_adam) + inputs : + {params : Params, grads : Grads, learning_rate : LearningRate, moments1 : Moments1, + moments2 : Moments2, beta1_pows : Beta1Pows, beta2_pows : Beta2Pows, master_params : MasterParams, + skip_update : SkipUpdate} + outputs : + {params_out : ParamsOut, moments1_out : Moments1Out, moments2_out : Moments2Out, + beta1_pows_out : Beta1PowsOut, beta2_pows_out : Beta2PowsOut, master_params_out : MasterParamsOut} + - op : fused_attention backward: fused_attention_grad inputs: @@ -1445,6 +1491,10 @@ {x_grad : DX, y_grad : DY, bias_grad : DBias} - op : fused_transpose + inputs: + {x : X} + outputs : + {out : Out} extra : attrs : [str data_format = "AnyLayout"] @@ -1467,6 +1517,26 @@ attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", float Scale_data = 1.0f, float Shift_data = 0.0f, 'float[] Scale_weights = {1.0f}'] - op : fusion_lstm + inputs : + x : X + h0 : H0 + weight_x : WeightX + weight_h : WeightH + bias : Bias + c0 : C0 + outputs : + out : Out + hidden : Hidden + cell : Cell + xx : XX + batched_input : BatchedInput + batched_hidden : BatchedHidden + batched_cell : BatchedCell + reordered_h0 : ReorderedH0 + reordered_c0 : ReorderedC0 + checked_cell : CheckedCell + attrs : + {scale_data : Scale_data, shift_data : Shift_data, scale_weights : Scale_weights} extra : attrs : [bool use_mkldnn = true, str mkldnn_data_type = "float32"] @@ -1560,6 +1630,12 @@ attrs : {pre_nms_top_n : pre_nms_topN, post_nms_top_n : post_nms_topN} +- op : global_scatter + inputs : + {x : X} + outputs : + out : Out + - op : grad_add inputs : {x : X, y : Y} @@ -2421,8 +2497,31 @@ extra : attrs : [bool use_mkldnn = false] +- op : partial_allgather + inputs : + x : X + outputs : + out : Out + +- op : partial_concat + backward : partial_concat_grad + inputs : + x : X + outputs : + out : Out + extra : + attrs : [bool use_mkldnn = false] + +- op : partial_recv + outputs : + out : Out + 
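The op_compat.yaml entries in this hunk (barrier, batch_fc, c_split, partial_allgather, partial_concat, partial_recv, and so on) carry no kernel logic; each one only records how the PHI-style lower-case argument names map onto the legacy Fluid operator's capitalized names so that old programs and converters keep resolving them. As a rough illustration of what a single entry encodes, using partial_concat from this hunk (the C++ container is purely illustrative, not a Paddle data structure):

    #include <map>
    #include <string>

    // partial_concat's entry, `inputs: {x : X}` / `outputs: {out : Out}`,
    // amounts to a name mapping like this:
    const std::map<std::string, std::string> partial_concat_arg_map = {
        {"x", "X"},      // PHI input  -> legacy Fluid input
        {"out", "Out"},  // PHI output -> legacy Fluid output
    };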
- op : partial_sum backward : partial_sum_grad + inputs : + x : X + outputs : + out : Out extra : attrs : [bool use_mkldnn = false] @@ -2542,6 +2641,12 @@ outputs : out : Out +- op : push_dense + inputs : + ids : Ids + attrs : + {table_id : TableId, scale_data_norm : ScaleDataNorm, input_names: InputNames} + - op : push_sparse_v2 inputs : { x : Ids, W : w} @@ -2795,6 +2900,9 @@ scale : data_type : float tensor_name : ScaleTensor + bias : + data_type : float + support_tensor : false extra : attrs : [bool use_mkldnn = false] @@ -3117,7 +3225,7 @@ outputs : [xshape] - op : stack - backward : stack_grad + backward : stack_grad, stack_double_grad inputs : x : X outputs : @@ -3489,6 +3597,12 @@ outputs : out: Out +- op: c_allreduce_avg + inputs : + x : X + outputs : + out: Out + - op: c_allreduce_max inputs : x : X @@ -3525,12 +3639,30 @@ outputs : out: Out +- op: c_reduce_avg + inputs : + x : X + outputs : + out: Out + +- op: c_reduce_max + inputs : + x : X + outputs : + out: Out + - op: c_reduce_min inputs : x : X outputs : out: Out +- op: c_reduce_prod + inputs : + x : X + outputs : + out: Out + - op: c_reduce_sum inputs : x : X @@ -3543,6 +3675,12 @@ outputs : out: Out +- op: c_scatter + inputs : + x : X + outputs : + out: Out + - op: c_sync_calc_stream inputs : x : X @@ -3575,6 +3713,12 @@ multi_level_rois_num: MultiLevelRoIsNum restore_index: RestoreIndex +- op: distributed_fused_lamb_init + inputs: + {param: Param, grad: Grad} + outputs: + {fp32_fused_param: FP32FusedParam, fp32_fused_grad: FP32FusedGrad, fp16_fused_param: FP16FusedParam, fp16_fused_grad: FP16FusedGrad, moment1: Moment1, moment2: Moment2, beta1_pow: Beta1Pow, beta2_pow: Beta2Pow, fused_param_offsets: FusedParamOffsets, fp32_shard_fused_param_offsets: FP32ShardFusedParamOffsets, fp16_shard_fused_param_offsets: FP16ShardFusedParamOffsets, param_info: ParamInfo, param_order: ParamOrder, param_out: ParamOut, master_param_out: MasterParamOut, grad_out: GradOut, global_scale: GlobalScale, step: Step} + - op: distributed_lookup_table inputs: {ids: Ids, w: W} @@ -3610,6 +3754,33 @@ outputs : {out : Out, intermediate_out : IntermediateOut} +- op: fused_matmul + inputs : + {x: X, y: Y, residual_data: ResidualData} + outputs : + {out : Out} + attrs : + {scale_x : Scale_x, scale_y : Scale_y, scale_out : Scale_out, scale_in_eltwise : Scale_in_eltwise, fused_reshape_x : fused_reshape_X, fused_transpose_x : fused_transpose_X, fused_reshape_y : fused_reshape_Y, fused_transpose_y : fused_transpose_Y, fused_reshape_out : fused_reshape_Out, fused_transpose_out : fused_transpose_Out} + +- op: fused_softmax_mask + backward : fused_softmax_mask_grad + inputs : + {x: X, mask: Mask} + outputs : + {out : Out} + +- op: fused_softplus + inputs : + {x: X} + outputs : + {out : Out} + +- op: fused_token_prune + inputs : + {attn: Attn, x: X, mask: Mask, new_mask: NewMask} + outputs : + {slimmed_x : SlimmedX, cls_inds : CLSInds} + - op: fusion_squared_mat_sub inputs : x : X @@ -3638,6 +3809,10 @@ outputs : {param_out: ParamOut, velocity_out: VelocityOut, master_param_out: MasterParamOut} +- op: limit_by_capacity + outputs : + out : Out + - op: lod_array_length inputs : {x: X} @@ -3685,6 +3860,12 @@ outputs: {cost : Cost, sample_logits : SampleLogits, sample_labels : SampleLabels} +- op: nop + inputs : + x : X + outputs : + out : Out + - op: number_count inputs : {numbers: numbers} @@ -3695,6 +3876,27 @@ inputs : x : X +- op: prune_gate_by_capacity + inputs: + {gate_idx: GateIdx, expert_count: ExpertCount} + outputs: + new_gate_idx: NewGateIdx + +- op: 
random_routing + inputs: + {prob : Prob, topk_value : TopK_Value, topk_idx : TopK_Idx} + outputs: + out : Out + +- op: rank_attention + backward: rank_attention_grad + inputs: + {x : X, rank_offset : RankOffset, rank_param : RankParam} + outputs: + {input_help : InputHelp, out : Out, ins_rank: InsRank} + attrs: + {max_rank : MaxRank, max_size : MaxSize} + - op: read_from_array inputs: array : X diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index cf3986cae89e0..918cbb980d00f 100755 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -152,6 +152,7 @@ output : Tensor(out) infer_meta : func : ArgMinMaxInferMeta + spmd_rule : ArgMaxInferSpmdDynamic kernel : func : argmax data_type : x @@ -207,7 +208,6 @@ func : as_strided backward : as_strided_grad no_need_buffer : input - interfaces : paddle::dialect::InferSymbolicShapeInterface - op : asgd_ args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor d, Tensor y, Tensor n, Tensor master_param, bool multi_precision=false) @@ -327,6 +327,7 @@ backward : bicubic_interp_grad data_transform : skip_transform : out_size, size_tensor, scale_tensor + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : bilinear args : (Tensor x, Tensor y, Tensor weight, Tensor bias) @@ -350,6 +351,7 @@ backward : bilinear_interp_grad data_transform : skip_transform : out_size, size_tensor, scale_tensor + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : bincount args: (Tensor x, Tensor weights, Scalar(int) minlength = 0) @@ -602,6 +604,7 @@ func : conv2d data_type : input backward : conv2d_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : conv3d args : (Tensor input, Tensor filter, int[] strides={1, 1, 1}, int[] paddings={0, 0, 0}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1, 1}, str data_format="NCDHW") @@ -612,6 +615,7 @@ func : conv3d data_type : input backward : conv3d_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : conv3d_transpose args : (Tensor x, Tensor filter, int[] strides={1, 1, 1}, int[] paddings={0, 0, 0}, int[] output_padding={}, int[] output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1, 1}, str data_format="NCHW") @@ -803,6 +807,7 @@ func : digamma inplace: (x -> out) backward : digamma_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : dirichlet args: (Tensor alpha) @@ -940,13 +945,13 @@ func : expand data_type : x backward : expand_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface - op : expand_as args : (Tensor x, Tensor y, int[] target_shape = {}) output : Tensor(out) infer_meta : func : ExpandAsInferMeta + local_shape: target_shape kernel : func : expand_as data_type : x @@ -1037,6 +1042,7 @@ func : flash_attn data_type : q backward : flash_attn_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : flash_attn_unpadded args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") @@ -1051,6 +1057,18 @@ intermediate : softmax_lse, seed_offset backward : flash_attn_unpadded_grad +- op : flash_attn_with_sparse_mask + args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool 
return_softmax = false, bool is_test = false, str rng_name = "") + output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) + optional : fixed_seed_offset + infer_meta : + func : FlashAttnInferMeta + param : [q, k, v] + kernel : + func : flash_attn_with_sparse_mask + data_type : q + backward : flash_attn_with_sparse_mask_grad + - op : flatten args : (Tensor x, int start_axis = 1, int stop_axis = 1) output : Tensor(out), Tensor(xshape) @@ -1642,6 +1660,7 @@ backward : linear_interp_grad data_transform : skip_transform : out_size, size_tensor, scale_tensor + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : llm_int8_linear args : (Tensor x, Tensor weight, Tensor bias, Tensor weight_scale, float threshold=6.0) @@ -2032,13 +2051,12 @@ backward : mv_grad - op : nanmedian - args : (Tensor x, IntArray axis = {}, bool keepdim = true) + args : (Tensor x, IntArray axis = {}, bool keepdim = true, str mode="avg") output : Tensor(out), Tensor(medians) infer_meta : func : NanmedianInferMeta kernel : func : nanmedian - intermediate : medians backward : nanmedian_grad - op : nearest_interp @@ -2053,6 +2071,7 @@ backward : nearest_interp_grad data_transform : skip_transform : out_size, size_tensor, scale_tensor + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : nextafter args : (Tensor x, Tensor y) @@ -2416,7 +2435,7 @@ interfaces : paddle::dialect::InferSymbolicShapeInterface - op : scale - args : (Tensor x, Scalar scale=1.0, float bias=0.0, bool bias_after_scale=true) + args : (Tensor x, Scalar scale=1.0, Scalar bias=0.0, bool bias_after_scale=true) output : Tensor(out) infer_meta : func : UnchangedInferMeta @@ -2764,6 +2783,7 @@ output : Tensor(out) infer_meta: func: SwiGLUInferMeta + spmd_rule: SwiGLUInferSpmd kernel: func : swiglu optional : y @@ -2897,6 +2917,7 @@ backward : trilinear_interp_grad data_transform : skip_transform : out_size, size_tensor, scale_tensor + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : trunc args : (Tensor input) @@ -2907,6 +2928,7 @@ func : trunc inplace: (input -> out) backward : trunc_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : unbind args : (Tensor input, int axis = 0) diff --git a/paddle/phi/api/yaml/sparse_ops.yaml b/paddle/phi/api/yaml/sparse_ops.yaml index fdebffcc4f06c..56e952623a150 100644 --- a/paddle/phi/api/yaml/sparse_ops.yaml +++ b/paddle/phi/api/yaml/sparse_ops.yaml @@ -102,8 +102,7 @@ args : (Tensor x, DataType index_dtype=DataType::UNDEFINED, DataType value_dtype=DataType::UNDEFINED) output : Tensor(out) infer_meta : - func : CastInferMeta - param: [x, value_dtype] + func : sparse::CastInferMeta kernel : func : cast_coo{sparse_coo -> sparse_coo}, cast_csr{sparse_csr -> sparse_csr} diff --git a/paddle/phi/api/yaml/static_ops.yaml b/paddle/phi/api/yaml/static_ops.yaml index 6ff2bfe427122..de355233456d7 100755 --- a/paddle/phi/api/yaml/static_ops.yaml +++ b/paddle/phi/api/yaml/static_ops.yaml @@ -123,6 +123,17 @@ optional : bias backward : conv2d_transpose_grad +- op : conv2d_transpose_bias + args : (Tensor x, Tensor filter, Tensor bias, int[] strides={1, 1}, int[] paddings={0, 0}, int[] output_padding={}, IntArray output_size={}, str padding_algorithm="EXPLICIT", int groups=1, int[] dilations={1, 1}, str data_format="NCHW") + output : Tensor(out) + infer_meta : + func : Conv2dTransposeInferMeta + param : [x, filter, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format] + kernel : + func : conv2d_transpose_bias + 
param : [x, filter, bias, strides, paddings, output_padding, output_size, padding_algorithm, groups, dilations, data_format] + data_type : x + - op : decode_jpeg args : (Tensor x, str mode = "unchanged") output : Tensor(out) diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt index 50da99217b153..80d5f14e627a3 100644 --- a/paddle/phi/backends/CMakeLists.txt +++ b/paddle/phi/backends/CMakeLists.txt @@ -14,7 +14,7 @@ if(WITH_GPU OR WITH_ROCM) list(APPEND BACKENDS_SRCS gpu/cuda/cuda_info.cc gpu/cuda/cuda_graph.cc) endif() if(WITH_ROCM) - list(APPEND BACKENDS_SRCS gpu/rocm/rocm_info.cc) + list(APPEND BACKENDS_SRCS gpu/rocm/rocm_info.cc gpu/rocm/hip_graph.cc) endif() endif() diff --git a/paddle/phi/backends/c_comm_lib.h b/paddle/phi/backends/c_comm_lib.h index 3405b2f33bb58..b21ad1b7fedfe 100644 --- a/paddle/phi/backends/c_comm_lib.h +++ b/paddle/phi/backends/c_comm_lib.h @@ -29,17 +29,6 @@ typedef void* CCLComm; typedef std::vector CCLRootId; enum CCLReduceOp { SUM = 0, AVG, MAX, MIN, PRODUCT }; -enum CCLDataType { - CCL_DATA_TYPE_FP64 = 0, - CCL_DATA_TYPE_FP32, - CCL_DATA_TYPE_FP16, - CCL_DATA_TYPE_BF16, - CCL_DATA_TYPE_INT64, - CCL_DATA_TYPE_INT32, - CCL_DATA_TYPE_INT16, - CCL_DATA_TYPE_INT8, - CCL_DATA_TYPE_UINT8 -}; inline CCLReduceOp ToXCCLReduceOp(int reduce_type) { phi::ccl::CCLReduceOp red_type = phi::ccl::CCLReduceOp::SUM; @@ -67,51 +56,6 @@ inline CCLReduceOp ToXCCLReduceOp(int reduce_type) { return red_type; } -inline CCLDataType ToCCLDataType(phi::DataType type) { - if (type == phi::DataType::FLOAT64) { - return CCL_DATA_TYPE_FP64; - } else if (type == phi::DataType::FLOAT32) { - return CCL_DATA_TYPE_FP32; - } else if (type == phi::DataType::FLOAT16) { - return CCL_DATA_TYPE_FP16; - } else if (type == phi::DataType::BFLOAT16) { - return CCL_DATA_TYPE_BF16; - } else if (type == phi::DataType::INT64) { - return CCL_DATA_TYPE_INT64; - } else if (type == phi::DataType::INT32) { - return CCL_DATA_TYPE_INT32; - } else if (type == phi::DataType::INT8) { - return CCL_DATA_TYPE_INT8; - } else if (type == phi::DataType::UINT8) { - return CCL_DATA_TYPE_UINT8; - } else { - PADDLE_THROW( - phi::errors::Unimplemented("This datatype %s in CCL is not supported.", - phi::DataTypeToString(type))); - } -} - -inline phi::DataType ToPhiDataType(CCLDataType type) { - if (type == CCLDataType::CCL_DATA_TYPE_FP64) { - return phi::DataType::FLOAT64; - } else if (type == CCLDataType::CCL_DATA_TYPE_FP32) { - return phi::DataType::FLOAT32; - } else if (type == CCLDataType::CCL_DATA_TYPE_FP16) { - return phi::DataType::FLOAT16; - } else if (type == CCLDataType::CCL_DATA_TYPE_BF16) { - return phi::DataType::BFLOAT16; - } else if (type == CCLDataType::CCL_DATA_TYPE_INT64) { - return phi::DataType::INT64; - } else if (type == CCLDataType::CCL_DATA_TYPE_INT32) { - return phi::DataType::INT32; - } else if (type == CCLDataType::CCL_DATA_TYPE_INT8) { - return phi::DataType::INT8; - } else { - PADDLE_THROW( - phi::errors::Unimplemented("This datatype in CCL is not supported.")); - } -} - inline std::string SerializeXCCLUniqueId(const phi::ccl::CCLRootId& ccl_id) { const uint8_t* bytes = ccl_id.data(); std::ostringstream oss; diff --git a/paddle/phi/backends/custom/custom_device.cc b/paddle/phi/backends/custom/custom_device.cc index 4e2108cbbd9e4..624aabeffaba7 100644 --- a/paddle/phi/backends/custom/custom_device.cc +++ b/paddle/phi/backends/custom/custom_device.cc @@ -534,8 +534,8 @@ class CustomDevice : public DeviceInterface { if (pimpl_->device_extra_padding_size) { 
PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( pimpl_->device_extra_padding_size(device, &padding_size)); - VLOG(10) << Type() << " extra padding size " << (padding_size >> 20) - << "M"; + VLOG(10) << Type() << " extra padding size:" << padding_size; + return padding_size; } else { return DeviceInterface::GetExtraPaddingSize(dev_id); } @@ -569,29 +569,6 @@ class CustomDevice : public DeviceInterface { return version; } - C_DataType ToXCCLDataType(ccl::CCLDataType data_type) { -#define return_result(in, ret) \ - case ccl::CCLDataType::in: \ - return C_DataType::ret - switch (data_type) { - return_result(CCL_DATA_TYPE_FP64, FLOAT64); - return_result(CCL_DATA_TYPE_FP32, FLOAT32); - return_result(CCL_DATA_TYPE_FP16, FLOAT16); - return_result(CCL_DATA_TYPE_BF16, BFLOAT16); - return_result(CCL_DATA_TYPE_INT64, INT64); - return_result(CCL_DATA_TYPE_INT32, INT32); - return_result(CCL_DATA_TYPE_INT16, INT16); - return_result(CCL_DATA_TYPE_INT8, INT8); - return_result(CCL_DATA_TYPE_UINT8, UINT8); - default: { - PADDLE_THROW(phi::errors::Unavailable( - "DataType is not supported on %s.", Type())); - return C_DataType::UNDEFINED; - } - } -#undef return_result - } - C_CCLReduceOp ToXCCLReduceOp(ccl::CCLReduceOp reduce_op) { #define return_result(in, ret) \ case ccl::CCLReduceOp::in: \ @@ -615,13 +592,21 @@ class CustomDevice : public DeviceInterface { case in: \ return C_DataType::ret switch (data_type) { - return_result(phi::DataType::FLOAT64, FLOAT64); - return_result(phi::DataType::FLOAT32, FLOAT32); - return_result(phi::DataType::FLOAT16, FLOAT16); - return_result(phi::DataType::INT64, INT64); - return_result(phi::DataType::INT32, INT32); - return_result(phi::DataType::INT16, INT16); + return_result(phi::DataType::BOOL, BOOL); + return_result(phi::DataType::UINT8, UINT8); + return_result(phi::DataType::UINT16, UINT16); + return_result(phi::DataType::UINT32, UINT32); + return_result(phi::DataType::UINT64, UINT64); return_result(phi::DataType::INT8, INT8); + return_result(phi::DataType::INT16, INT16); + return_result(phi::DataType::INT32, INT32); + return_result(phi::DataType::INT64, INT64); + return_result(phi::DataType::FLOAT16, FLOAT16); + return_result(phi::DataType::FLOAT32, FLOAT32); + return_result(phi::DataType::FLOAT64, FLOAT64); + return_result(phi::DataType::BFLOAT16, BFLOAT16); + return_result(phi::DataType::COMPLEX64, COMPLEX64); + return_result(phi::DataType::COMPLEX128, COMPLEX128); default: { PADDLE_THROW(phi::errors::Unavailable( "DataType is not supported on %s.", Type())); @@ -666,10 +651,16 @@ class CustomDevice : public DeviceInterface { pimpl_->xccl_destroy_comm(reinterpret_cast(comm))); } + void CCLCommName(ccl::CCLComm comm, char* comm_name) { + CHECK_PTR(pimpl_->xccl_get_comm_name); + PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_get_comm_name( + reinterpret_cast(comm), comm_name)); + } + void CCLAllReduce(void* send_buf, void* recv_buf, size_t count, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp op, const ccl::CCLComm& comm, const stream::Stream& stream) override { @@ -678,7 +669,7 @@ class CustomDevice : public DeviceInterface { send_buf, recv_buf, count, - ToXCCLDataType(data_type), + ToCDatatType(data_type), ToXCCLReduceOp(op), reinterpret_cast(comm), reinterpret_cast(stream.raw_stream()))); @@ -686,7 +677,7 @@ class CustomDevice : public DeviceInterface { void CCLBroadcast(void* buf, size_t count, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t root, const ccl::CCLComm& comm, const stream::Stream& stream) override { @@ -694,7 +685,7 
@@ class CustomDevice : public DeviceInterface { PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_broadcast( buf, count, - ToXCCLDataType(data_type), + ToCDatatType(data_type), root, reinterpret_cast(comm), reinterpret_cast(stream.raw_stream()))); @@ -703,7 +694,7 @@ class CustomDevice : public DeviceInterface { void CCLReduce(void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, size_t root_id, const ccl::CCLComm& comm, @@ -713,7 +704,7 @@ class CustomDevice : public DeviceInterface { pimpl_->xccl_reduce(in_data, out_data, num, - ToXCCLDataType(data_type), + ToCDatatType(data_type), ToXCCLReduceOp(reduce_op), root_id, reinterpret_cast(comm), @@ -723,7 +714,7 @@ class CustomDevice : public DeviceInterface { void CCLAllGather(void* send_buf, void* recv_buf, size_t count, - ccl::CCLDataType data_type, + phi::DataType data_type, const ccl::CCLComm& comm, const stream::Stream& stream) override { CHECK_PTR(pimpl_->xccl_all_gather); @@ -731,7 +722,7 @@ class CustomDevice : public DeviceInterface { send_buf, recv_buf, count, - ToXCCLDataType(data_type), + ToCDatatType(data_type), reinterpret_cast(comm), reinterpret_cast(stream.raw_stream()))); } @@ -739,7 +730,7 @@ class CustomDevice : public DeviceInterface { void CCLReduceScatter(void* send_buf, void* recv_buf, size_t count, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, const ccl::CCLComm& comm, const stream::Stream& stream) override { @@ -748,7 +739,7 @@ class CustomDevice : public DeviceInterface { send_buf, recv_buf, count, - ToXCCLDataType(data_type), + ToCDatatType(data_type), ToXCCLReduceOp(reduce_op), reinterpret_cast(comm), reinterpret_cast(stream.raw_stream()))); @@ -768,7 +759,7 @@ class CustomDevice : public DeviceInterface { void CCLSend(void* send_buf, size_t count, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t dest_rank, const ccl::CCLComm& comm, const stream::Stream& stream) override { @@ -776,7 +767,7 @@ class CustomDevice : public DeviceInterface { PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( pimpl_->xccl_send(send_buf, count, - ToXCCLDataType(data_type), + ToCDatatType(data_type), dest_rank, reinterpret_cast(comm), reinterpret_cast(stream.raw_stream()))); @@ -784,7 +775,7 @@ class CustomDevice : public DeviceInterface { void CCLRecv(void* recv_buf, size_t count, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t src_rank, const ccl::CCLComm& comm, const stream::Stream& stream) override { @@ -792,7 +783,7 @@ class CustomDevice : public DeviceInterface { PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( pimpl_->xccl_recv(recv_buf, count, - ToXCCLDataType(data_type), + ToCDatatType(data_type), src_rank, reinterpret_cast(comm), reinterpret_cast(stream.raw_stream()))); @@ -800,10 +791,10 @@ class CustomDevice : public DeviceInterface { void CCLAllToAll(const void** send_buf, const size_t* send_count, - const ccl::CCLDataType* send_dtype, + const phi::DataType* send_dtype, void** recv_buf, const size_t* recv_count, - const ccl::CCLDataType* recv_dtype, + const phi::DataType* recv_dtype, size_t rank, size_t nranks, const ccl::CCLComm& comm, @@ -811,8 +802,8 @@ class CustomDevice : public DeviceInterface { if (pimpl_->xccl_all_to_all) { std::vector c_send_dtype, c_recv_dtype; for (size_t i = 0; i < nranks; ++i) { - c_send_dtype.push_back(ToXCCLDataType(send_dtype[i])); - c_recv_dtype.push_back(ToXCCLDataType(recv_dtype[i])); + c_send_dtype.push_back(ToCDatatType(send_dtype[i])); + 
c_recv_dtype.push_back(ToCDatatType(recv_dtype[i])); } PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_all_to_all( send_buf, @@ -832,7 +823,7 @@ class CustomDevice : public DeviceInterface { PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( pimpl_->xccl_recv(recv_buf[i], recv_count[i], - ToXCCLDataType(recv_dtype[i]), + ToCDatatType(recv_dtype[i]), i, reinterpret_cast(comm), reinterpret_cast(stream.raw_stream()))); @@ -842,7 +833,7 @@ class CustomDevice : public DeviceInterface { PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS(pimpl_->xccl_send( const_cast(send_buf[i]), send_count[i], - ToXCCLDataType(send_dtype[i]), + ToCDatatType(send_dtype[i]), i, reinterpret_cast(comm), reinterpret_cast(stream.raw_stream()))); @@ -851,14 +842,13 @@ class CustomDevice : public DeviceInterface { MemoryCopyD2D(rank, recv_buf[rank], send_buf[rank], - send_count[rank] * - phi::SizeOf(phi::ccl::ToPhiDataType(send_dtype[rank])), + send_count[rank] * phi::SizeOf(send_dtype[rank]), &stream); for (size_t i = rank + 1; i < nranks; ++i) { PADDLE_ENFORCE_CUSTOM_DEVICE_SUCCESS( pimpl_->xccl_recv(recv_buf[i], recv_count[i], - ToXCCLDataType(recv_dtype[i]), + ToCDatatType(recv_dtype[i]), i, reinterpret_cast(comm), reinterpret_cast(stream.raw_stream()))); @@ -1106,7 +1096,7 @@ void LoadCustomRuntimeLib(const std::string& dso_lib_path, void* dso_handle) { } LoadCustomRuntimeLib( runtime_params, std::move(device_interface), dso_lib_path, dso_handle); - LOG(INFO) << "Successed in loading custom runtime in lib: " << dso_lib_path; + LOG(INFO) << "Succeed in loading custom runtime in lib: " << dso_lib_path; } #undef INTERFACE_UNIMPLEMENT diff --git a/paddle/phi/backends/device_base.cc b/paddle/phi/backends/device_base.cc index f27919bef05fe..e02fe9e340224 100644 --- a/paddle/phi/backends/device_base.cc +++ b/paddle/phi/backends/device_base.cc @@ -215,9 +215,9 @@ size_t DeviceInterface::AllocSize(size_t dev_id, bool realloc) { size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb : FLAGS_initial_gpu_memory_in_mb; size_t alloc_bytes = - (flag_mb > 0ul - ? flag_mb << 20 - : available_to_alloc * FLAGS_fraction_of_gpu_memory_to_use); + (flag_mb > 0ul ? 
flag_mb << 20 + : available_to_alloc * + FLAGS_fraction_of_gpu_memory_to_use); // NOLINT PADDLE_ENFORCE_GE(available_to_alloc, alloc_bytes, phi::errors::ResourceExhausted( @@ -267,6 +267,10 @@ size_t DeviceInterface::GetExtraPaddingSize(size_t dev_id) { return 0; } +void DeviceInterface::CCLCommName(ccl::CCLComm ccl_comm, char* comm_name) { + INTERFACE_UNIMPLEMENT; +} + void DeviceInterface::CCLDestroyComm(ccl::CCLComm ccl_comm) { INTERFACE_UNIMPLEMENT; } @@ -284,7 +288,7 @@ void DeviceInterface::CCLGetUniqueId(ccl::CCLRootId* root_id) { void DeviceInterface::CCLBroadcast(void* data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t root, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -294,7 +298,7 @@ void DeviceInterface::CCLBroadcast(void* data, void DeviceInterface::CCLAllReduce(void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -304,7 +308,7 @@ void DeviceInterface::CCLAllReduce(void* in_data, void DeviceInterface::CCLReduce(void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, size_t root_id, const ccl::CCLComm& ccl_comm, @@ -315,7 +319,7 @@ void DeviceInterface::CCLReduce(void* in_data, void DeviceInterface::CCLAllGather(void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { INTERFACE_UNIMPLEMENT; @@ -324,7 +328,7 @@ void DeviceInterface::CCLAllGather(void* in_data, void DeviceInterface::CCLReduceScatter(void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp op, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -337,7 +341,7 @@ void DeviceInterface::CCLGroupEnd() { INTERFACE_UNIMPLEMENT; } void DeviceInterface::CCLSend(void* sendbuf, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t dst_rank, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -346,7 +350,7 @@ void DeviceInterface::CCLSend(void* sendbuf, void DeviceInterface::CCLRecv(void* recvbuf, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t src_rank, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -355,10 +359,10 @@ void DeviceInterface::CCLRecv(void* recvbuf, void DeviceInterface::CCLAllToAll(const void** send_buf, const size_t* send_count, - const ccl::CCLDataType* send_dtype, + const phi::DataType* send_dtype, void** recv_buf, const size_t* recv_count, - const ccl::CCLDataType* recv_dtype, + const phi::DataType* recv_dtype, size_t rank, size_t nranks, const ccl::CCLComm& comm, diff --git a/paddle/phi/backends/device_base.h b/paddle/phi/backends/device_base.h index 855e77890348a..75e72c72887b9 100644 --- a/paddle/phi/backends/device_base.h +++ b/paddle/phi/backends/device_base.h @@ -169,6 +169,8 @@ class DeviceInterface { // Driver / Runtime virtual size_t GetExtraPaddingSize(size_t dev_id); // CCL + virtual void CCLCommName(ccl::CCLComm ccl_comm, char* comm_name); + virtual void CCLDestroyComm(ccl::CCLComm ccl_comm); virtual void CCLCommInitRank(size_t num_ranks, @@ -180,7 +182,7 @@ class DeviceInterface { // Driver / Runtime virtual void CCLBroadcast(void* data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t root, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); @@ 
-188,14 +190,14 @@ class DeviceInterface { // Driver / Runtime virtual void CCLAllReduce(void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); virtual void CCLReduce(void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, size_t root_id, const ccl::CCLComm& ccl_comm, @@ -203,13 +205,13 @@ class DeviceInterface { // Driver / Runtime virtual void CCLAllGather(void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); virtual void CCLReduceScatter(void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp op, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); @@ -217,23 +219,23 @@ class DeviceInterface { // Driver / Runtime virtual void CCLGroupEnd(); virtual void CCLSend(void* sendbuf, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t dst_rank, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); virtual void CCLRecv(void* recvbuf, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t src_rank, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); virtual void CCLAllToAll(const void** send_buf, const size_t* send_count, - const ccl::CCLDataType* send_dtype, + const phi::DataType* send_dtype, void** recv_buf, const size_t* recv_count, - const ccl::CCLDataType* recv_dtype, + const phi::DataType* recv_dtype, size_t rank, size_t nranks, const ccl::CCLComm& comm, diff --git a/paddle/phi/backends/device_code.cc b/paddle/phi/backends/device_code.cc index 670e0e3781598..e2016ff78b7c3 100644 --- a/paddle/phi/backends/device_code.cc +++ b/paddle/phi/backends/device_code.cc @@ -186,7 +186,8 @@ static std::string FindCUDAIncludePath() { } for (std::string suffix : {"/lib", "/lib64"}) { if (EndWith(FLAGS_cuda_dir, suffix)) { - cuda_include_path.erase(cuda_include_path.end() - suffix.length()); + cuda_include_path.erase(cuda_include_path.end() - + suffix.length()); // NOLINT break; } } diff --git a/paddle/phi/backends/device_ext.h b/paddle/phi/backends/device_ext.h index bd3f5f687f29b..a2d68bee1ac27 100644 --- a/paddle/phi/backends/device_ext.h +++ b/paddle/phi/backends/device_ext.h @@ -50,6 +50,7 @@ typedef enum { NCHW, NCDHW, NDHWC, + STRIDED, NUM_DATA_LAYOUTS, ALL_LAYOUT = ANY, } C_DataLayout; @@ -547,6 +548,13 @@ struct C_DeviceInterface { // ccl api // ////////////// + /** + * @brief Get comm name. 
+ * + * @param[char*] comm_name + */ + C_Status (*xccl_get_comm_name)(C_CCLComm comm, char* comm_name); + /** * @brief Get size of unique id * diff --git a/paddle/phi/backends/device_manager.cc b/paddle/phi/backends/device_manager.cc index e3ec68e7f9182..ae21fbb3e9f06 100644 --- a/paddle/phi/backends/device_manager.cc +++ b/paddle/phi/backends/device_manager.cc @@ -509,6 +509,13 @@ std::vector DeviceManager::GetSelectedDeviceList( return device_list_map[device_type]; } +void DeviceManager::CCLCommName(const std::string& device_type, + const ccl::CCLComm& ccl_comm, + char* comm_name) { + auto dev_impl = GetDeviceInterfaceWithType(device_type); + return dev_impl->CCLCommName(ccl_comm, comm_name); +} + void DeviceManager::CCLDestroyComm(const std::string& device_type, ccl::CCLComm ccl_comm) { auto dev_impl = GetDeviceInterfaceWithType(device_type); @@ -533,7 +540,7 @@ void DeviceManager::CCLGetUniqueId(const std::string& device_type, void DeviceManager::CCLBroadcast(const std::string& device_type, void* data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t root_id, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -545,7 +552,7 @@ void DeviceManager::CCLAllReduce(const std::string& device_type, void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -558,7 +565,7 @@ void DeviceManager::CCLReduce(const std::string& device_type, void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, size_t root_id, const ccl::CCLComm& ccl_comm, @@ -572,7 +579,7 @@ void DeviceManager::CCLAllGather(const std::string& device_type, void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { auto dev_impl = GetDeviceInterfaceWithType(device_type); @@ -583,7 +590,7 @@ void DeviceManager::CCLReduceScatter(const std::string& device_type, void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp op, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -605,7 +612,7 @@ void DeviceManager::CCLGroupEnd(const std::string& device_type) { void DeviceManager::CCLSend(const std::string& device_type, void* sendbuf, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t dst_rank, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -616,7 +623,7 @@ void DeviceManager::CCLSend(const std::string& device_type, void DeviceManager::CCLRecv(const std::string& device_type, void* recvbuf, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t src_rank, const ccl::CCLComm& ccl_comm, const stream::Stream& stream) { @@ -627,10 +634,10 @@ void DeviceManager::CCLRecv(const std::string& device_type, void DeviceManager::CCLAllToAll(const std::string& device_type, const void** send_buf, const size_t* send_count, - const ccl::CCLDataType* send_dtype, + const phi::DataType* send_dtype, void** recv_buf, const size_t* recv_count, - const ccl::CCLDataType* recv_dtype, + const phi::DataType* recv_dtype, size_t rank, size_t nranks, const ccl::CCLComm& comm, diff --git a/paddle/phi/backends/device_manager.h b/paddle/phi/backends/device_manager.h index 58a9e6ebe7ab8..5a42d2450ba97 100644 --- a/paddle/phi/backends/device_manager.h +++ b/paddle/phi/backends/device_manager.h @@ -23,9 
+23,9 @@ #include "paddle/phi/backends/c_comm_lib.h" #include "paddle/phi/backends/device_base.h" #include "paddle/phi/backends/device_ext.h" -#include "paddle/phi/backends/dynload/port.h" #include "paddle/phi/backends/event.h" #include "paddle/phi/backends/stream.h" +#include "paddle/phi/common/port.h" namespace phi { class Device final { @@ -178,6 +178,9 @@ class DeviceManager { const std::string& device_type); // CCL + static void CCLCommName(const std::string& device_type, + const ccl::CCLComm& ccl_comm, + char* comm_name); static void CCLDestroyComm(const std::string& device_type, ccl::CCLComm ccl_comm); static void CCLCommInitRank(const std::string& device_type, @@ -190,7 +193,7 @@ class DeviceManager { static void CCLBroadcast(const std::string& device_type, void* data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t root, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); @@ -198,7 +201,7 @@ class DeviceManager { void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); @@ -206,7 +209,7 @@ class DeviceManager { void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp reduce_op, size_t root_id, const ccl::CCLComm& ccl_comm, @@ -215,14 +218,14 @@ class DeviceManager { void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); static void CCLReduceScatter(const std::string& device_type, void* in_data, void* out_data, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, ccl::CCLReduceOp op, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); @@ -231,14 +234,14 @@ class DeviceManager { static void CCLSend(const std::string& device_type, void* sendbuf, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t dst_rank, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); static void CCLRecv(const std::string& device_type, void* recvbuf, size_t num, - ccl::CCLDataType data_type, + phi::DataType data_type, size_t src_rank, const ccl::CCLComm& ccl_comm, const stream::Stream& stream); @@ -246,10 +249,10 @@ class DeviceManager { static void CCLAllToAll(const std::string& device_type, const void** send_buf, const size_t* send_count, - const ccl::CCLDataType* send_dtype, + const phi::DataType* send_dtype, void** recv_buf, const size_t* recv_count, - const ccl::CCLDataType* recv_dtype, + const phi::DataType* recv_dtype, size_t rank, size_t nranks, const ccl::CCLComm& comm, diff --git a/paddle/phi/backends/dynload/CMakeLists.txt b/paddle/phi/backends/dynload/CMakeLists.txt index 9fd293574e247..1c444ebc1fa1e 100644 --- a/paddle/phi/backends/dynload/CMakeLists.txt +++ b/paddle/phi/backends/dynload/CMakeLists.txt @@ -1,5 +1,4 @@ -set(DYNLOAD_COMMON_SRCS dynamic_loader.cc port.cc warpctc.cc warprnnt.cc - lapack.cc) +set(DYNLOAD_COMMON_SRCS dynamic_loader.cc warpctc.cc warprnnt.cc lapack.cc) if(WITH_ASCEND_CL) list(REMOVE_ITEM DYNLOAD_COMMON_SRCS warprnnt.cc) endif() diff --git a/paddle/phi/backends/dynload/cublas.h b/paddle/phi/backends/dynload/cublas.h index 308ae2accef14..8053bbb6bd2ce 100644 --- a/paddle/phi/backends/dynload/cublas.h +++ b/paddle/phi/backends/dynload/cublas.h @@ -22,7 +22,7 @@ limitations under the License. 
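The hunks above replace ccl::CCLDataType with phi::DataType across DeviceInterface, DeviceManager and the custom-device adapter, drop the now-redundant ToCCLDataType/ToPhiDataType converters, and add a CCLCommName hook backed by the new xccl_get_comm_name plugin entry point. Callers therefore pass phi::DataType straight through and can size buffers with phi::SizeOf. A minimal caller-side sketch, assuming the buffers, communicator and stream come from the usual custom-device setup (includes for phi::SizeOf omitted):

    #include <string>
    #include "paddle/phi/backends/device_manager.h"

    // Sketch only: signatures follow the updated headers above; the surrounding
    // setup (buffers, comm, stream) is assumed to exist.
    void AllReduceFp32(const std::string& device_type,
                       void* send_buf,
                       void* recv_buf,
                       size_t numel,
                       const phi::ccl::CCLComm& comm,
                       const phi::stream::Stream& stream) {
      const phi::DataType dtype = phi::DataType::FLOAT32;
      const size_t nbytes = numel * phi::SizeOf(dtype);  // byte count straight from phi::DataType
      (void)nbytes;                                      // e.g. for sanity checks on buffer sizes
      phi::DeviceManager::CCLAllReduce(device_type,
                                       send_buf,
                                       recv_buf,
                                       numel,
                                       dtype,  // previously a ccl::CCLDataType value
                                       phi::ccl::CCLReduceOp::SUM,
                                       comm,
                                       stream);
    }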
*/ #include #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h index 90492ff4ba69d..5b05ee644f6c5 100644 --- a/paddle/phi/backends/dynload/cublasLt.h +++ b/paddle/phi/backends/dynload/cublasLt.h @@ -22,7 +22,7 @@ limitations under the License. */ #include #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/cuda_driver.h b/paddle/phi/backends/dynload/cuda_driver.h index ba771afe09023..657b577d0a82e 100644 --- a/paddle/phi/backends/dynload/cuda_driver.h +++ b/paddle/phi/backends/dynload/cuda_driver.h @@ -19,7 +19,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/cudnn.cc b/paddle/phi/backends/dynload/cudnn.cc index 924dd60d2c5e1..fb1c9cfa0af97 100644 --- a/paddle/phi/backends/dynload/cudnn.cc +++ b/paddle/phi/backends/dynload/cudnn.cc @@ -50,6 +50,18 @@ CUDNN_DNN_ROUTINE_EACH_R8(DEFINE_WRAP); CUDNN_DNN_ROUTINE_EACH_FRONTEND(DEFINE_WRAP); #endif +#ifdef CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9 +CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(DEFINE_WRAP); +#endif + +#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9 +CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(DEFINE_WRAP); +#endif + +#ifdef CUDNN_DNN_ROUTINE_EACH_R9 +CUDNN_DNN_ROUTINE_EACH_R9(DEFINE_WRAP); +#endif + bool HasCUDNN() { std::call_once(cudnn_dso_flag, []() { cudnn_dso_handle = GetCUDNNDsoHandle(); }); diff --git a/paddle/phi/backends/dynload/cudnn.h b/paddle/phi/backends/dynload/cudnn.h index 3292beb037110..7a7dce241ff0a 100644 --- a/paddle/phi/backends/dynload/cudnn.h +++ b/paddle/phi/backends/dynload/cudnn.h @@ -19,16 +19,16 @@ limitations under the License. 
*/ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { -extern std::once_flag cudnn_dso_flag; -extern void* cudnn_dso_handle; +TEST_API extern std::once_flag cudnn_dso_flag; +TEST_API extern void* cudnn_dso_handle; extern bool HasCUDNN(); -extern void EnforceCUDNNLoaded(const char* fn_name); +TEST_API extern void EnforceCUDNNLoaded(const char* fn_name); #define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \ struct DynLoad__##__name { \ template \ @@ -103,13 +103,6 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(cudnnSetDropoutDescriptor); \ __macro(cudnnRestoreDropoutDescriptor); \ __macro(cudnnCreateRNNDescriptor); \ - __macro(cudnnGetRNNParamsSize); \ - __macro(cudnnGetRNNWorkspaceSize); \ - __macro(cudnnGetRNNTrainingReserveSize); \ - __macro(cudnnRNNForwardTraining); \ - __macro(cudnnRNNBackwardData); \ - __macro(cudnnRNNBackwardWeights); \ - __macro(cudnnRNNForwardInference); \ __macro(cudnnDestroyDropoutDescriptor); \ __macro(cudnnDestroyRNNDescriptor); \ __macro(cudnnSetTensorNdDescriptorEx); \ @@ -124,8 +117,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); __macro(cudnnCreateActivationDescriptor); \ __macro(cudnnSetActivationDescriptor); \ __macro(cudnnGetActivationDescriptor); \ - __macro(cudnnDestroyActivationDescriptor); \ - __macro(cudnnSetRNNDescriptor_v6); + __macro(cudnnDestroyActivationDescriptor); CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000 @@ -159,12 +151,7 @@ CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(__macro) \ __macro(cudnnCreateRNNDataDescriptor); \ __macro(cudnnDestroyRNNDataDescriptor); \ - __macro(cudnnSetRNNDataDescriptor); \ - __macro(cudnnSetRNNPaddingMode); \ - __macro(cudnnRNNForwardTrainingEx); \ - __macro(cudnnRNNBackwardDataEx); \ - __macro(cudnnRNNBackwardWeightsEx); \ - __macro(cudnnRNNForwardInferenceEx); + __macro(cudnnSetRNNDataDescriptor); CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif @@ -207,6 +194,39 @@ CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_FRONTEND(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif +#if CUDNN_VERSION < 90000 +#define CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(__macro) \ + __macro(cudnnGetRNNParamsSize); \ + __macro(cudnnGetRNNWorkspaceSize); \ + __macro(cudnnGetRNNTrainingReserveSize); \ + __macro(cudnnSetRNNDescriptor_v6); \ + __macro(cudnnRNNForwardInference); \ + __macro(cudnnRNNForwardTraining); \ + __macro(cudnnRNNBackwardData); \ + __macro(cudnnRNNBackwardWeights); +CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) +#endif + +#if CUDNN_VERSION < 90000 && CUDNN_VERSION >= 7201 +#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(__macro) \ + __macro(cudnnSetRNNPaddingMode); \ + __macro(cudnnRNNForwardInferenceEx); \ + __macro(cudnnRNNForwardTrainingEx); \ + __macro(cudnnRNNBackwardDataEx); \ + __macro(cudnnRNNBackwardWeightsEx); +CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9( + DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) +#endif + +#if CUDNN_VERSION >= 90000 +#define CUDNN_DNN_ROUTINE_EACH_R9(__macro) \ + __macro(cudnnGetRNNWeightSpaceSize); \ + __macro(cudnnGetRNNTempSpaceSizes); \ + __macro(cudnnRNNForward); \ + __macro(cudnnRNNBackwardData_v8); \ + __macro(cudnnRNNBackwardWeights_v8); +CUDNN_DNN_ROUTINE_EACH_R9(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) +#endif } // namespace 
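The cudnn.h/cudnn.cc changes above regroup the RNN-related symbols by cuDNN version: the v6/v7-era entry points (cudnnRNNForwardTraining, cudnnSetRNNDescriptor_v6, the *Ex variants, ...) are now only declared for CUDNN_VERSION < 90000, while the new CUDNN_DNN_ROUTINE_EACH_R9 group exposes the v8-style API (cudnnRNNForward, cudnnRNNBackwardData_v8, ...) when building against cuDNN 9+. Call sites that still support older cuDNN therefore need a guard along these lines; the function names follow the macros above, the argument lists are omitted because they are not part of this patch:

    #if CUDNN_VERSION >= 90000
      // cuDNN 9 path: the legacy RNN API was removed, use the v8-style calls,
      // e.g. phi::dynload::cudnnGetRNNTempSpaceSizes(...) followed by
      // phi::dynload::cudnnRNNForward(...) / cudnnRNNBackwardData_v8(...).
    #else
      // pre-9 path: keep the legacy calls that this header still declares,
      // e.g. phi::dynload::cudnnGetRNNWorkspaceSize(...) and
      // phi::dynload::cudnnRNNForwardTraining(...).
    #endif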
dynload } // namespace phi diff --git a/paddle/phi/backends/dynload/cufft.h b/paddle/phi/backends/dynload/cufft.h index a27d7c3ab1eee..1547909d92e24 100644 --- a/paddle/phi/backends/dynload/cufft.h +++ b/paddle/phi/backends/dynload/cufft.h @@ -21,7 +21,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/cupti.h b/paddle/phi/backends/dynload/cupti.h index 22e21b78f4f2e..59e92955c930e 100644 --- a/paddle/phi/backends/dynload/cupti.h +++ b/paddle/phi/backends/dynload/cupti.h @@ -22,7 +22,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/curand.h b/paddle/phi/backends/dynload/curand.h index f3c4496dc4d39..6b6abf7825d2e 100644 --- a/paddle/phi/backends/dynload/curand.h +++ b/paddle/phi/backends/dynload/curand.h @@ -18,7 +18,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/cusolver.h b/paddle/phi/backends/dynload/cusolver.h index a86e85144fd7f..74c64085ea721 100644 --- a/paddle/phi/backends/dynload/cusolver.h +++ b/paddle/phi/backends/dynload/cusolver.h @@ -19,7 +19,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/cusparse.h b/paddle/phi/backends/dynload/cusparse.h index d75b236c07ab1..8ec3cf2792444 100644 --- a/paddle/phi/backends/dynload/cusparse.h +++ b/paddle/phi/backends/dynload/cusparse.h @@ -19,7 +19,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/cusparseLt.h b/paddle/phi/backends/dynload/cusparseLt.h index 8eecefab5e469..a45b0637d8569 100644 --- a/paddle/phi/backends/dynload/cusparseLt.h +++ b/paddle/phi/backends/dynload/cusparseLt.h @@ -19,7 +19,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/dynamic_loader.cc b/paddle/phi/backends/dynload/dynamic_loader.cc index efdac108bcc8e..0b056d6df972f 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.cc +++ b/paddle/phi/backends/dynload/dynamic_loader.cc @@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/phi/backends/dynload/dynamic_loader.h" +#include #include #include #include #include "paddle/phi/backends/dynload/cupti_lib_path.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" #include "paddle/phi/core/enforce.h" #if defined(_WIN32) @@ -182,6 +183,34 @@ static inline void* GetDsoHandleFromSpecificPath(const std::string& spec_path, return dso_handle; } +static inline std::string FindLibAbsolutePath(const std::string& directory, + const std::string& filename) { + DIR* dir; + struct dirent* ent; + + if ((dir = opendir(directory.c_str())) != nullptr) { + while ((ent = readdir(dir)) != nullptr) { + if (ent->d_type == DT_REG || ent->d_type == DT_LNK) { + if (filename == std::string(ent->d_name)) { + closedir(dir); + return join(directory, ent->d_name); + } + } else if (ent->d_type == DT_DIR) { + if (strcmp(ent->d_name, ".") != 0 && strcmp(ent->d_name, "..") != 0) { + std::string res = + FindLibAbsolutePath(join(directory, ent->d_name) + "/", filename); + if (!res.empty()) { + closedir(dir); + return res; + } + } + } + } + closedir(dir); + } + return ""; +} + static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path, int dynload_flags) { // default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH @@ -195,10 +224,19 @@ static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path, // bring System Integrity Projection (SIP), if dso_handle // is null, search from default package path in Mac OS. #if defined(__APPLE__) || defined(__OSX__) +#if defined(__arm__) || defined(__aarch64__) + if (nullptr == dso_handle) { + dso_handle = + dlopen(FindLibAbsolutePath("/opt/homebrew/Cellar/", dso_path).c_str(), + dynload_flags); + } +#else if (nullptr == dso_handle) { dso_handle = - dlopen(join("/usr/local/cuda/lib/", dso_path).c_str(), dynload_flags); + dlopen(FindLibAbsolutePath("/usr/local/cuda/lib/", dso_path).c_str(), + dynload_flags); } +#endif #endif return dso_handle; @@ -260,7 +298,7 @@ static inline void* GetDsoHandleFromSearchPath( " 2. 
Configure third-party dynamic library environment variables as " "follows:\n" " - Linux: set LD_LIBRARY_PATH by `export LD_LIBRARY_PATH=...`\n" - " - Windows: set PATH by `set PATH=XXX;%PATH%`\n" + " - Windows: set PATH by `set PATH=XXX;%%PATH%%`\n" " - Mac: set DYLD_LIBRARY_PATH by `export DYLD_LIBRARY_PATH=...` " "[Note: After Mac OS 10.11, using the DYLD_LIBRARY_PATH is " "impossible unless System Integrity Protection (SIP) is disabled.]"; @@ -289,9 +327,17 @@ void* GetCublasDsoHandle() { FLAGS_cuda_dir, win_cublas_lib, true, {cuda_lib_path}); #elif defined(__linux__) && defined(PADDLE_WITH_CUDA) if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so.11"); +#else + return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so"); +#endif } else if (CUDA_VERSION >= 12000 && CUDA_VERSION <= 12030) { +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so.12"); +#else + return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so"); +#endif } else { std::string warning_msg( "Your CUDA_VERSION is less than 11 or greater than 12, paddle " @@ -309,9 +355,17 @@ void* GetCublasLtDsoHandle() { // APIs available after CUDA 10.1 #if defined(__linux__) && defined(PADDLE_WITH_CUDA) if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so.11"); +#else + return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so"); +#endif } else if (CUDA_VERSION >= 12000 && CUDA_VERSION <= 12030) { +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so.12"); +#else + return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublasLt.so"); +#endif } else { std::string warning_msg( "Your CUDA_VERSION is less than 11 or greater than 12, paddle " @@ -353,8 +407,13 @@ void* GetCUDNNDsoHandle() { #elif defined(PADDLE_WITH_HIP) return GetDsoHandleFromSearchPath(FLAGS_miopen_dir, "libMIOpen.so", false); #else +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath( FLAGS_cudnn_dir, "libcudnn.so.8", false, {cuda_lib_path}); +#else + return GetDsoHandleFromSearchPath( + FLAGS_cudnn_dir, "libcudnn.so", false, {cuda_lib_path}); +#endif #endif } @@ -364,11 +423,22 @@ void* GetCUPTIDsoHandle() { FLAGS_cupti_dir, "libcupti.dylib", false, {cupti_lib_path}); #elif defined(__linux__) && defined(PADDLE_WITH_CUDA) if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libcupti.so.11.7", false, {cupti_lib_path}); + FLAGS_cupti_dir, "libcupti.so.11.8", false, {cupti_lib_path}); +#else + return GetDsoHandleFromSearchPath( + FLAGS_cupti_dir, "libcupti.so", false, {cupti_lib_path}); +#endif + } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 12030) { +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath( FLAGS_cupti_dir, "libcupti.so.12", false, {cupti_lib_path}); +#else + return GetDsoHandleFromSearchPath( + FLAGS_cupti_dir, "libcupti.so", false, {cupti_lib_path}); +#endif } else { std::string warning_msg( "Your CUDA_VERSION is less than 11 or greater than 12, paddle " @@ -377,7 +447,7 @@ void* GetCUPTIDsoHandle() { } #else return GetDsoHandleFromSearchPath( - FLAGS_cupti_dir, "libcupti.so.11.7", false, {cupti_lib_path}); + FLAGS_cupti_dir, "libcupti.so", false, {cupti_lib_path}); #endif } @@ -390,7 +460,12 @@ void* GetCurandDsoHandle() { #elif 
defined(PADDLE_WITH_HIP) return GetDsoHandleFromSearchPath(FLAGS_rocm_dir, "libhiprand.so"); #else +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath(FLAGS_curand_dir, "libcurand.so.10"); +#else + return GetDsoHandleFromSearchPath(FLAGS_curand_dir, "libcurand.so"); +#endif + #endif } @@ -422,7 +497,11 @@ void* GetCusolverDsoHandle() { return GetDsoHandleFromSearchPath( FLAGS_cuda_dir, win_cusolver_lib, true, {cuda_lib_path}); #else +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.so.11"); +#else + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.so"); +#endif #endif } @@ -434,9 +513,17 @@ void* GetCusparseDsoHandle() { FLAGS_cuda_dir, win_cusparse_lib, true, {cuda_lib_path}); #elif defined(__linux__) && defined(PADDLE_WITH_CUDA) if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libcusparse.so.11"); +#else + return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libcusparse.so"); +#endif } else if (CUDA_VERSION >= 12000 && CUDA_VERSION <= 12030) { +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libcusparse.so.12"); +#else + return GetDsoHandleFromSearchPath(FLAGS_cusparse_dir, "libcusparse.so"); +#endif } else { std::string warning_msg( "Your CUDA_VERSION is less than 11 or greater than 12, paddle " @@ -535,9 +622,15 @@ void* GetNCCLDsoHandle() { #elif defined(PADDLE_WITH_HIP) && defined(PADDLE_WITH_RCCL) return GetDsoHandleFromSearchPath( FLAGS_rccl_dir, "librccl.so", true, {}, warning_msg); +#else +#ifdef WITH_PIP_CUDA_LIBRARIES + return GetDsoHandleFromSearchPath( + FLAGS_nccl_dir, "libnccl.so;libnccl.so.2", true, {}, warning_msg); #else return GetDsoHandleFromSearchPath( - FLAGS_nccl_dir, "libnccl.so.2", true, {}, warning_msg); + FLAGS_nccl_dir, "libnccl.so", true, {}, warning_msg); +#endif + #endif } @@ -563,7 +656,11 @@ void* GetMKLMLDsoHandle() { void* GetLAPACKDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) +#if defined(__arm__) || defined(__aarch64__) + return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.dylib"); +#else return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.3.dylib"); +#endif #elif defined(_WIN32) return GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.dll"); #else @@ -592,8 +689,12 @@ void* GetCUFFTDsoHandle() { return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.dylib"); #elif defined(__linux__) && defined(PADDLE_WITH_CUDA) if (CUDA_VERSION >= 11000 && CUDA_VERSION < 12000) { +#ifdef WITH_PIP_CUDA_LIBRARIES return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so.10"); - } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 13000) { +#else + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so"); +#endif + } else if (CUDA_VERSION >= 12000 && CUDA_VERSION < 12030) { return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcufft.so.11"); } else { std::string warning_msg( @@ -639,6 +740,5 @@ void* GetXPTIDsoHandle() { return nullptr; #endif } - } // namespace dynload } // namespace phi diff --git a/paddle/phi/backends/dynload/dynamic_loader.h b/paddle/phi/backends/dynload/dynamic_loader.h index 6ddeb1386410f..b71a8fe976cbb 100644 --- a/paddle/phi/backends/dynload/dynamic_loader.h +++ b/paddle/phi/backends/dynload/dynamic_loader.h @@ -14,7 +14,7 @@ limitations under the License. 
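The dynamic_loader.cc hunks above make the requested library names conditional on the new WITH_PIP_CUDA_LIBRARIES build flag: when Paddle relies on the NVIDIA pip wheels it asks for the versioned soname (libcublas.so.11/.12, libcudnn.so.8, libcurand.so.10, ...), since those wheels generally do not ship an unversioned symlink, while a regular toolkit install keeps using the plain name. The same hunks also add a recursive Homebrew search path for Apple-silicon macOS via FindLibAbsolutePath. The recurring pattern, shown once (function and flag names are taken from the patch; the wrapper name is illustrative and the concrete version suffix depends on CUDA_VERSION):

    // Illustrative wrapper, not a function in the patch.
    void* GetSomeCudaLibraryHandle() {
    #ifdef WITH_PIP_CUDA_LIBRARIES
      // pip-provided NVIDIA wheels: only the versioned soname is present
      return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so.12");
    #else
      // system / toolkit install: the unversioned symlink exists
      return GetDsoHandleFromSearchPath(FLAGS_cublas_dir, "libcublas.so");
    #endif
    }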
*/ #pragma once #include - +#include "paddle/utils/test_macros.h" namespace phi { namespace dynload { @@ -26,7 +26,7 @@ namespace dynload { void* GetCublasDsoHandle(); void* GetCublasLtDsoHandle(); -void* GetCUDNNDsoHandle(); +TEST_API void* GetCUDNNDsoHandle(); void* GetCUPTIDsoHandle(); void* GetCurandDsoHandle(); void* GetNvjpegDsoHandle(); diff --git a/paddle/phi/backends/dynload/flashattn.h b/paddle/phi/backends/dynload/flashattn.h index e4728cf43405e..2c03329944371 100644 --- a/paddle/phi/backends/dynload/flashattn.h +++ b/paddle/phi/backends/dynload/flashattn.h @@ -18,7 +18,7 @@ limitations under the License. */ #include "flashattn/include/flash_attn.h" #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/hipfft.h b/paddle/phi/backends/dynload/hipfft.h index 4d45a26b8b981..45e5a2a473d2a 100644 --- a/paddle/phi/backends/dynload/hipfft.h +++ b/paddle/phi/backends/dynload/hipfft.h @@ -18,7 +18,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/hiprand.h b/paddle/phi/backends/dynload/hiprand.h index 3e9502dd94d91..038b01eb7de5f 100644 --- a/paddle/phi/backends/dynload/hiprand.h +++ b/paddle/phi/backends/dynload/hiprand.h @@ -18,7 +18,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/hiprtc.h b/paddle/phi/backends/dynload/hiprtc.h index 75dd88f87bd3a..06c869b178481 100644 --- a/paddle/phi/backends/dynload/hiprtc.h +++ b/paddle/phi/backends/dynload/hiprtc.h @@ -19,7 +19,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/lapack.h b/paddle/phi/backends/dynload/lapack.h index 74051821eaebb..eaea6783824ab 100644 --- a/paddle/phi/backends/dynload/lapack.h +++ b/paddle/phi/backends/dynload/lapack.h @@ -18,7 +18,7 @@ limitations under the License. */ #include #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" // Because lapack doesn't provide appropriate header file, // we should expose API statement yourself. diff --git a/paddle/phi/backends/dynload/miopen.h b/paddle/phi/backends/dynload/miopen.h index eeaf8028ec312..6ef19f60f9f05 100644 --- a/paddle/phi/backends/dynload/miopen.h +++ b/paddle/phi/backends/dynload/miopen.h @@ -20,7 +20,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" #define MIOPEN_VERSION \ (MIOPEN_VERSION_MAJOR * 1000 + MIOPEN_VERSION_MINOR * 10 + \ diff --git a/paddle/phi/backends/dynload/mklml.h b/paddle/phi/backends/dynload/mklml.h index 0f0c31f8064df..e5e8d104af044 100644 --- a/paddle/phi/backends/dynload/mklml.h +++ b/paddle/phi/backends/dynload/mklml.h @@ -19,7 +19,7 @@ limitations under the License. 
*/ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/mklrt.h b/paddle/phi/backends/dynload/mklrt.h index 0267fb69a5932..fe12e2c2fb084 100644 --- a/paddle/phi/backends/dynload/mklrt.h +++ b/paddle/phi/backends/dynload/mklrt.h @@ -19,7 +19,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/nccl.cc b/paddle/phi/backends/dynload/nccl.cc index 147066b43b031..fe322c2ad7be5 100644 --- a/paddle/phi/backends/dynload/nccl.cc +++ b/paddle/phi/backends/dynload/nccl.cc @@ -14,11 +14,20 @@ limitations under the License. */ #include "paddle/phi/backends/dynload/nccl.h" +ncclResult_t ncclCommInitRank2(ncclComm_t* newcomm, + int nranks, + ncclUniqueId commId, + int myrank, + int param) { + // fake impl for compilation + return ncclInvalidUsage; +} + namespace phi { namespace dynload { std::once_flag nccl_dso_flag; -void *nccl_dso_handle; +void* nccl_dso_handle; #define DEFINE_WRAP(__name) DynLoad__##__name __name diff --git a/paddle/phi/backends/dynload/nccl.h b/paddle/phi/backends/dynload/nccl.h index 91b6f5dcd58dc..c52a8c1824514 100644 --- a/paddle/phi/backends/dynload/nccl.h +++ b/paddle/phi/backends/dynload/nccl.h @@ -18,7 +18,19 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" + +#ifdef __cplusplus +extern "C" { +#endif +ncclResult_t ncclCommInitRank2(ncclComm_t* newcomm, + int nranks, + ncclUniqueId commId, + int myrank, + int param); +#ifdef __cplusplus +} +#endif namespace phi { namespace dynload { @@ -28,15 +40,21 @@ extern void* nccl_dso_handle; #define DECLARE_DYNAMIC_LOAD_NCCL_WRAP(__name) \ struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> decltype(__name(args...)) { \ + static auto GetNCCLFunc() { \ using nccl_func = decltype(&::__name); \ std::call_once(nccl_dso_flag, []() { \ nccl_dso_handle = phi::dynload::GetNCCLDsoHandle(); \ }); \ static void* p_##__name = dlsym(nccl_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ + return reinterpret_cast(p_##__name); \ + } \ + \ + template \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + return GetNCCLFunc()(args...); \ } \ + \ + static bool IsValid() { return GetNCCLFunc() != nullptr; } \ }; \ extern DynLoad__##__name __name @@ -44,6 +62,7 @@ extern void* nccl_dso_handle; __macro(ncclCommInitAll); \ __macro(ncclGetUniqueId); \ __macro(ncclCommInitRank); \ + __macro(ncclCommInitRank2); \ __macro(ncclCommAbort); \ __macro(ncclCommDestroy); \ __macro(ncclCommCount); \ diff --git a/paddle/phi/backends/dynload/nvjpeg.h b/paddle/phi/backends/dynload/nvjpeg.h index 6e71e6b582c05..c5309e7e1167f 100644 --- a/paddle/phi/backends/dynload/nvjpeg.h +++ b/paddle/phi/backends/dynload/nvjpeg.h @@ -16,7 +16,7 @@ limitations under the License. 
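The reworked DECLARE_DYNAMIC_LOAD_NCCL_WRAP above factors the dlopen/dlsym lookup into a cached GetNCCLFunc() and adds IsValid(), so callers can probe whether a symbol such as the new ncclCommInitRank2 is actually present before invoking it. A hand-expanded sketch of the same shape, for a made-up C function `int foo(int)` in an imaginary libfoo.so:

```cpp
// Hand-written sketch of what one expansion of the macro above amounts to.
// "libfoo.so" and foo() are placeholders, not real libraries or symbols.
#include <dlfcn.h>
#include <mutex>

extern "C" int foo(int);  // the signature the wrapper mimics

struct DynLoadFoo {
  using foo_t = decltype(&foo);

  static foo_t GetFunc() {
    static std::once_flag flag;
    static void* handle = nullptr;
    std::call_once(flag, [] { handle = dlopen("libfoo.so", RTLD_LAZY); });
    // dlsym is attempted once; the result is cached for every later call.
    static void* sym = (handle != nullptr) ? dlsym(handle, "foo") : nullptr;
    return reinterpret_cast<foo_t>(sym);
  }

  int operator()(int x) { return GetFunc()(x); }          // forward the call
  static bool IsValid() { return GetFunc() != nullptr; }  // probe before use
};
```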
*/ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/nvrtc.h b/paddle/phi/backends/dynload/nvrtc.h index 9244e9487b250..ecd6da4573f7c 100644 --- a/paddle/phi/backends/dynload/nvrtc.h +++ b/paddle/phi/backends/dynload/nvrtc.h @@ -19,7 +19,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/nvtx.h b/paddle/phi/backends/dynload/nvtx.h index e51bbf2154a17..1ccedde4d558e 100644 --- a/paddle/phi/backends/dynload/nvtx.h +++ b/paddle/phi/backends/dynload/nvtx.h @@ -19,7 +19,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/rccl.cc b/paddle/phi/backends/dynload/rccl.cc index 95e171842527b..ee347af62fb79 100644 --- a/paddle/phi/backends/dynload/rccl.cc +++ b/paddle/phi/backends/dynload/rccl.cc @@ -14,11 +14,20 @@ limitations under the License. */ #include "paddle/phi/backends/dynload/rccl.h" +ncclResult_t ncclCommInitRank2(ncclComm_t* newcomm, + int nranks, + ncclUniqueId commId, + int myrank, + int param) { + // fake impl for compilation + return ncclInvalidUsage; +} + namespace phi { namespace dynload { std::once_flag rccl_dso_flag; -void *rccl_dso_handle; +void* rccl_dso_handle; #define DEFINE_WRAP(__name) DynLoad__##__name __name diff --git a/paddle/phi/backends/dynload/rccl.h b/paddle/phi/backends/dynload/rccl.h index e1018a3f253fa..9d3a49bce9624 100644 --- a/paddle/phi/backends/dynload/rccl.h +++ b/paddle/phi/backends/dynload/rccl.h @@ -18,7 +18,19 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" + +#ifdef __cplusplus +extern "C" { +#endif +ncclResult_t ncclCommInitRank2(ncclComm_t* newcomm, + int nranks, + ncclUniqueId commId, + int myrank, + int param); +#ifdef __cplusplus +} +#endif namespace phi { namespace dynload { @@ -28,15 +40,21 @@ extern void* rccl_dso_handle; #define DECLARE_DYNAMIC_LOAD_RCCL_WRAP(__name) \ struct DynLoad__##__name { \ - template \ - auto operator()(Args... args) -> decltype(__name(args...)) { \ - using nccl_func = decltype(&::__name); \ + static auto GetRCCLFunc() { \ + using rccl_func = decltype(&::__name); \ std::call_once(rccl_dso_flag, []() { \ rccl_dso_handle = phi::dynload::GetNCCLDsoHandle(); \ }); \ static void* p_##__name = dlsym(rccl_dso_handle, #__name); \ - return reinterpret_cast(p_##__name)(args...); \ + return reinterpret_cast(p_##__name); \ + } \ + \ + template \ + auto operator()(Args... 
args) -> decltype(__name(args...)) { \ + return GetRCCLFunc()(args...); \ } \ + \ + static bool IsValid() { return GetRCCLFunc() != nullptr; } \ }; \ extern DynLoad__##__name __name @@ -44,6 +62,7 @@ extern void* rccl_dso_handle; __macro(ncclCommInitAll); \ __macro(ncclGetUniqueId); \ __macro(ncclCommInitRank); \ + __macro(ncclCommInitRank2); \ __macro(ncclCommAbort); \ __macro(ncclCommDestroy); \ __macro(ncclCommCount); \ diff --git a/paddle/phi/backends/dynload/rocblas.h b/paddle/phi/backends/dynload/rocblas.h index a9804b3d82a7d..19df156b086a0 100644 --- a/paddle/phi/backends/dynload/rocblas.h +++ b/paddle/phi/backends/dynload/rocblas.h @@ -21,7 +21,7 @@ limitations under the License. */ #include #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/rocm_driver.h b/paddle/phi/backends/dynload/rocm_driver.h index 4e456db44c904..2613836bf13d4 100644 --- a/paddle/phi/backends/dynload/rocm_driver.h +++ b/paddle/phi/backends/dynload/rocm_driver.h @@ -19,7 +19,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { @@ -51,13 +51,33 @@ extern bool HasCUDADriver(); __macro(hipModuleLoadData); \ __macro(hipModuleGetFunction); \ __macro(hipModuleUnload); \ - /*rocm3.5 not support the function*/ \ + /* DTK not support the function*/ \ /* __macro(hipOccupancyMaxActiveBlocksPerMultiprocessor);*/ \ __macro(hipModuleLaunchKernel); \ __macro(hipLaunchKernel); \ __macro(hipGetDevice); \ __macro(hipGetDeviceCount); \ - __macro(hipDevicePrimaryCtxGetState) + __macro(hipDevicePrimaryCtxGetState); \ + __macro(hipDeviceGetAttribute); \ + __macro(hipDeviceGet) + +#define ROCM_ROUTINE_EACH_VVM(__macro) \ + __macro(hipMemGetAllocationGranularity); \ + __macro(hipMemAddressReserve); \ + __macro(hipMemCreate); \ + __macro(hipMemMap); \ + __macro(hipMemSetAccess); \ + __macro(hipMemUnmap); \ + __macro(hipMemRelease); \ + __macro(hipMemAddressFree) + +#define ROCM_ROUTINE_EACH_GPU_GRAPH(__macro) \ + __macro(hipGraphNodeGetType); \ + __macro(hipGraphKernelNodeGetParams); \ + __macro(hipGraphExecKernelNodeSetParams) + +ROCM_ROUTINE_EACH_VVM(DECLARE_DYNAMIC_LOAD_ROCM_WRAP); +ROCM_ROUTINE_EACH_GPU_GRAPH(DECLARE_DYNAMIC_LOAD_ROCM_WRAP); ROCM_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_ROCM_WRAP); diff --git a/paddle/phi/backends/dynload/rocsparse.h b/paddle/phi/backends/dynload/rocsparse.h index 423bb8e1c5a88..5245c27b7e448 100644 --- a/paddle/phi/backends/dynload/rocsparse.h +++ b/paddle/phi/backends/dynload/rocsparse.h @@ -21,7 +21,7 @@ #include #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/dynload/warpctc.h b/paddle/phi/backends/dynload/warpctc.h index 4cbbca53e235f..bea933a7e3bf9 100644 --- a/paddle/phi/backends/dynload/warpctc.h +++ b/paddle/phi/backends/dynload/warpctc.h @@ -17,7 +17,7 @@ limitations under the License. 
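rocm_driver.h above gains two more routine lists, ROCM_ROUTINE_EACH_VVM for virtual memory management and ROCM_ROUTINE_EACH_GPU_GRAPH for graph APIs, each expanded with DECLARE_DYNAMIC_LOAD_ROCM_WRAP. This is the usual X-macro pattern; a reduced, self-contained sketch with placeholder names (not HIP entry points):

```cpp
// X-macro sketch: one list of names, expanded with whatever per-name macro
// the caller supplies. alpha/beta/gamma are placeholders.
#include <cstdio>

#define MY_ROUTINE_EACH(__macro) \
  __macro(alpha);                \
  __macro(beta);                 \
  __macro(gamma)

#define DECLARE_STUB(__name) \
  void __name() { std::printf("stub for " #__name "\n"); }

MY_ROUTINE_EACH(DECLARE_STUB);  // defines alpha(), beta() and gamma()

int main() {
  alpha();
  beta();
  gamma();
  return 0;
}
```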
*/ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" #include "warpctc/include/ctc.h" namespace phi { diff --git a/paddle/phi/backends/dynload/warprnnt.h b/paddle/phi/backends/dynload/warprnnt.h index 3c02b20ff717c..5a84efc491ed4 100644 --- a/paddle/phi/backends/dynload/warprnnt.h +++ b/paddle/phi/backends/dynload/warprnnt.h @@ -17,7 +17,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" #include "warprnnt/include/rnnt.h" namespace phi { diff --git a/paddle/phi/backends/dynload/xpti.h b/paddle/phi/backends/dynload/xpti.h index 25ba7d9b3e0d6..bf9e2c210dac8 100644 --- a/paddle/phi/backends/dynload/xpti.h +++ b/paddle/phi/backends/dynload/xpti.h @@ -20,7 +20,7 @@ limitations under the License. */ #include // NOLINT #include "paddle/phi/backends/dynload/dynamic_loader.h" -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { namespace dynload { diff --git a/paddle/phi/backends/event.cc b/paddle/phi/backends/event.cc index c08b4b269b2d2..6d14a9460f155 100644 --- a/paddle/phi/backends/event.cc +++ b/paddle/phi/backends/event.cc @@ -84,7 +84,7 @@ void Event::Destroy() { void Event::Record(const stream::Stream* stream) { if (device_) { - is_recorded_ = true; // synchronize the event during detroy + is_recorded_ = true; // synchronize the event during destroy stream->RecordEvent(this); } } diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph.cc b/paddle/phi/backends/gpu/cuda/cuda_graph.cc index 728451f9bde40..43ec0a0c89c08 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.cc @@ -301,8 +301,7 @@ void CUDAGraph::PrintToDotFiles(const std::string &dirname, #if CUDA_VERSION >= 11000 void CUDAGraphNodeLauncher::KernelNodeLaunch( - parameterSetter_t parameterSetter, - cudaKernelCallback_t cudakernelCallback) { + parameterSetter_t parameterSetter, gpuKernelCallback_t cudakernelCallback) { if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { unsigned int id = GenerateIdentifier(); auto cudaFunc = cudakernelCallback(id); @@ -333,7 +332,7 @@ CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) { PADDLE_ENFORCE_GPU_SUCCESS( dynload::cuGraphKernelNodeGetParams(cuNode, &cuParams)); - CUDAKernelParams kernel_params(cuParams.kernelParams); + gpuKernelParams kernel_params(cuParams.kernelParams); auto kernel = parameterSetters.find(static_cast(cuParams.func)); VLOG(10) << "[GetParameterSettersForExecGraph] cuParams.func = " @@ -350,7 +349,7 @@ CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) { auto setter = parameterSetter->second; hooks.emplace_back([setter, cuNode, cuParams]( cudaGraphExec_t exec_graph) { - CUDAKernelParams kernel_params(cuParams.kernelParams); + gpuKernelParams kernel_params(cuParams.kernelParams); setter(kernel_params); PADDLE_ENFORCE_GPU_SUCCESS(dynload::cuGraphExecKernelNodeSetParams( static_cast(exec_graph), cuNode, &cuParams)); @@ -369,7 +368,7 @@ CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) { void CUDAGraphNodeLauncher::KernelNodeLaunch( cudaFunction_t cudaFunc, parameterSetter_t parameterSetter, - cudaKernelCallback_t cudakernelCallback) { + gpuKernelCallback_t cudakernelCallback) { cudakernelCallback(0); } diff --git 
a/paddle/phi/backends/gpu/cuda/cuda_graph.h b/paddle/phi/backends/gpu/cuda/cuda_graph.h index db5e4fcbe2da6..dfc981850ca13 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph.h @@ -95,9 +95,9 @@ class CUDAGraphContextManager { std::set capturing_ctxs_; }; -class CUDAKernelParams { +class gpuKernelParams { public: - explicit CUDAKernelParams(void **params) : kernelParams(params) {} + explicit gpuKernelParams(void **params) : kernelParams(params) {} template T &As(size_t idx) const { @@ -132,20 +132,20 @@ class CUDAGraphNodeLauncher { // Sets the kernel's parameters BEFORE activating the CUDA graph. It enables // dynamic determination and setup of kernel arguments. // - // parameterSetter_t parameterSetter = [saved_state](CUDAKernelParams + // parameterSetter_t parameterSetter = [saved_state](gpuKernelParams // ¶m){ // // Code to compute and the parameter values from the saved_state // // ... // param.As(idx) = calculated_value; // }; - using parameterSetter_t = std::function; + using parameterSetter_t = std::function; // [CUDA Kernel Callback] // Acts as the launcher for the kernel. It accepts an `unsigned int` // identifier and uses it for the kernel launch. // The `cudaGetFuncBySymbol` method can be used to fetch the `cudaFunction_t` // reference of the kernel from the kernel pointer. - // cudaKernelCallback_t cudaKernelCallback = [=](unsigned int id) { + // gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { // // cudaFunction_t is REQUIRED to get here // cudaFunction_t cudaFunc; // PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, &kernel)); @@ -153,18 +153,18 @@ class CUDAGraphNodeLauncher { // kernel<<<>>>(id, ...); // Launching the kernel with id // return cudaFunc; // }; - using cudaKernelCallback_t = std::function; + using gpuKernelCallback_t = std::function; // [Kernel Launch] // With the callbacks defined and the CUDA function obtained, the kernel can // be launched using the `KernelNodeLaunch` method. 
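CUDAKernelParams is renamed to gpuKernelParams above so the same wrapper can serve CUDA and HIP kernel nodes. Its As<T>(idx) accessor simply reinterprets one slot of the node's void** argument array; a stand-alone illustration follows, with the float argument invented for the example:

```cpp
// Stand-alone illustration of gpuKernelParams::As<T>: a graph kernel node
// stores its arguments as an array of pointers to the argument values, and
// As<T>(idx) reinterprets slot idx as a mutable T.
#include <cstddef>
#include <cstdio>

class gpuKernelParams {
 public:
  explicit gpuKernelParams(void** params) : kernelParams(params) {}

  template <typename T>
  T& As(std::size_t idx) const {
    return *reinterpret_cast<T*>(kernelParams[idx]);
  }

 private:
  void** kernelParams;
};

int main() {
  unsigned int id = 0;  // the identifier every graph-launched kernel takes
  float scale = 1.5f;   // a second, hypothetical kernel argument
  void* args[] = {&id, &scale};

  gpuKernelParams params(args);
  params.As<unsigned int>(0) = 42;  // what a parameterSetter hook would do
  std::printf("id=%u scale=%.1f\n", id, params.As<float>(1));
  return 0;
}
```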
void KernelNodeLaunch(parameterSetter_t parameterSetter, - cudaKernelCallback_t cudakernelCallback); + gpuKernelCallback_t cudakernelCallback); std::vector GetParameterSettersForExecGraph( cudaGraph_t graph); - parameterSetter_t GetParameterSetter(const CUDAKernelParams ¶ms); + parameterSetter_t GetParameterSetter(const gpuKernelParams ¶ms); static CUDAGraphNodeLauncher &Instance() { static CUDAGraphNodeLauncher *launcher = new CUDAGraphNodeLauncher; @@ -185,7 +185,7 @@ class CUDAGraphNodeLauncher { #if CUDA_VERSION >= 10010 static void ThrowErrorIfNotSupportCUDAGraph() {} #else -enum cudaStreamCaptureMode { +enum gpuStreamCaptureMode { cudaStreamCaptureModeGlobal = 0, cudaStreamCaptureModeThreadLocal = 1, cudaStreamCaptureModeRelaxed = 2 @@ -262,7 +262,7 @@ class CUDAGraph { static void BeginCapture(phi::GPUPlace place, cudaStream_t stream, - cudaStreamCaptureMode mode); + gpuStreamCaptureMode mode); static std::unique_ptr EndCapture(); static void BeginSegmentCapture(); @@ -309,7 +309,7 @@ class CUDAGraph { } } - using SetSeedFunc = std::function; + using SetSeedFunc = std::function; static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) { std::lock_guard guard(capturing_graph_->func_mtx_); capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func)); @@ -324,7 +324,7 @@ class CUDAGraph { #if CUDA_VERSION >= 10010 std::vector graphs_; std::vector exec_graphs_; - cudaStreamCaptureMode capture_mode_; + gpuStreamCaptureMode capture_mode_; #endif cudaStream_t stream_{nullptr}; phi::GPUPlace place_; @@ -368,7 +368,7 @@ class CUDAGraphCaptureModeGuard { public: explicit CUDAGraphCaptureModeGuard( - cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) { + gpuStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) { if (UNLIKELY(CUDAGraph::IsCapturing())) { PADDLE_ENFORCE_GPU_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode)); // After cudaThreadExchangeStreamCaptureMode is called, @@ -385,7 +385,7 @@ class CUDAGraphCaptureModeGuard { } private: - cudaStreamCaptureMode old_mode_; + gpuStreamCaptureMode old_mode_; }; #else class CUDAGraphCaptureModeGuard { @@ -393,7 +393,7 @@ class CUDAGraphCaptureModeGuard { public: explicit CUDAGraphCaptureModeGuard( - cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {} + gpuStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {} }; #endif diff --git a/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h index 952dd355882e5..2d5810fbe1c9b 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h +++ b/paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h @@ -17,9 +17,13 @@ #include #include -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/phi/backends/context_pool.h" +#if defined(PADDLE_WITH_CUDA) #include "paddle/phi/backends/gpu/cuda/cuda_graph.h" +#else +#include "paddle/phi/backends/gpu/rocm/hip_graph.h" +#endif #include "paddle/phi/kernels/funcs/dropout_impl_util.h" #endif @@ -28,7 +32,7 @@ namespace backends { namespace gpu { inline bool IsCUDAGraphCapturing() { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return CUDAGraph::IsCapturing(); #else return false; @@ -39,7 +43,7 @@ inline bool IsCUDAGraphCapturing() { // Otherwise, invoke callback directly. 
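Putting the two callbacks described above together, a hedged usage sketch of KernelNodeLaunch follows. The kernel, the saved_state value, and the wrapper function are inventions for illustration; cudaGetFuncBySymbol and the launcher itself come from the header, and the snippet assumes a CUDA 11+ build inside the Paddle tree:

```cpp
// Usage sketch only, not code from this PR. Compile as a .cu file inside the
// Paddle tree; my_kernel and saved_state are made-up names.
#include <cuda_runtime.h>
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"

__global__ void my_kernel(unsigned int id, float* out, float value) {
  out[0] = value + static_cast<float>(id);
}

void LaunchThroughGraphNode(float* dev_out, float saved_state) {
  using phi::backends::gpu::CUDAGraphNodeLauncher;
  using phi::backends::gpu::gpuKernelParams;

  // Runs before every replay of the captured graph: rewrite argument slot 2.
  auto parameterSetter = [saved_state](gpuKernelParams& params) {
    params.As<float>(2) = saved_state;
  };

  // Launches the kernel once, tagged with `id`, and returns its function ref.
  auto kernelCallback = [=](unsigned int id) -> cudaFunction_t {
    cudaFunction_t cudaFunc;
    PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(
        &cudaFunc, reinterpret_cast<const void*>(&my_kernel)));
    my_kernel<<<1, 1>>>(id, dev_out, saved_state);
    return cudaFunc;
  };

  CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(parameterSetter,
                                                     kernelCallback);
}
```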
template inline void AddPostResetCallbackIfCapturingCUDAGraph(Callback &&callback) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (UNLIKELY(IsCUDAGraphCapturing())) { return CUDAGraph::AddPostResetCallbackDuringCapturing( std::forward(callback)); @@ -52,7 +56,7 @@ template inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) { static_assert(std::is_trivial::value, "T must be trivial type"); static_assert(!std::is_same::value, "T cannot be void"); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (UNLIKELY(IsCUDAGraphCapturing())) { size_t nbytes = size * sizeof(T); void *new_host_mem = new uint8_t[nbytes]; diff --git a/paddle/phi/backends/gpu/cuda/cuda_info.cc b/paddle/phi/backends/gpu/cuda/cuda_info.cc index 0af1beb782fcf..8ac492ea959f5 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_info.cc +++ b/paddle/phi/backends/gpu/cuda/cuda_info.cc @@ -28,7 +28,7 @@ namespace gpu { int DnnVersion() { if (!dynload::HasCUDNN()) return -1; - return dynload::cudnnGetVersion(); + return dynload::cudnnGetVersion(); // NOLINT } static int GetGPUDeviceCountImpl() { @@ -179,7 +179,7 @@ int GetCurrentDeviceId() { return device_id; } -std::array GetGpuMaxGridDimSize(int id) { +std::array GetGpuMaxGridDimSize(int id) { PADDLE_ENFORCE_LT( id, GetGPUDeviceCount(), @@ -187,7 +187,7 @@ std::array GetGpuMaxGridDimSize(int id) { "but received id is: %d. GPU count is: %d.", id, GetGPUDeviceCount())); - std::array ret; + std::array ret; int size; auto error_code_x = cudaDeviceGetAttribute(&size, cudaDevAttrMaxGridDimX, id); PADDLE_ENFORCE_GPU_SUCCESS(error_code_x); diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index 17e894529ca2b..fe952585f547d 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -753,7 +753,7 @@ struct GPUContext::Impl { int multi_process_; int max_threads_per_mp_; int max_threads_per_block_; - std::array max_grid_dim_size_; + std::array max_grid_dim_size_; CUDAStream* stream_{nullptr}; Eigen::GpuDevice* eigen_device_{nullptr}; @@ -873,7 +873,7 @@ int GPUContext::GetMaxThreadsPerBlock() const { return impl_->max_threads_per_block_; } -std::array GPUContext::GetCUDAMaxGridDimSize() const { +std::array GPUContext::GetCUDAMaxGridDimSize() const { return impl_->max_grid_dim_size_; } @@ -1024,7 +1024,7 @@ void GPUContext::SetMaxThreadsPerBlock(int val) { impl_->max_threads_per_block_ = val; } -void GPUContext::SetMaxGridDimSize(const std::array& val) { +void GPUContext::SetMaxGridDimSize(const std::array& val) { impl_->max_grid_dim_size_ = val; } diff --git a/paddle/phi/backends/gpu/gpu_context.h b/paddle/phi/backends/gpu/gpu_context.h index 8cd0d414bc105..7ccd365ee5f2c 100644 --- a/paddle/phi/backends/gpu/gpu_context.h +++ b/paddle/phi/backends/gpu/gpu_context.h @@ -69,7 +69,7 @@ class DnnWorkspaceHandle { void ResetWorkspace(); - void ReallocWorkspace(size_t required_workspace_bytes); + TEST_API void ReallocWorkspace(size_t required_workspace_bytes); DnnWorkspaceHandle(DnnWorkspaceHandle&&) = default; DnnWorkspaceHandle& operator=(DnnWorkspaceHandle&&) = delete; @@ -139,7 +139,7 @@ class PADDLE_API GPUContext : public DeviceContext, int GetMaxThreadsPerBlock() const; /*! \brief Return the max grid dim size in the device context */ - std::array GetCUDAMaxGridDimSize() const; + std::array GetCUDAMaxGridDimSize() const; /*! \brief Return eigen device in the device context. 
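GetGpuMaxGridDimSize above (and GetCUDAMaxGridDimSize with it) now report the per-axis grid limits as unsigned values, and the GetGpuLaunchConfig3D hunk that follows clamps DivUp(...) results against them with unsigned arithmetic. A small worked example of that clamping, with arbitrary sizes:

```cpp
// Arithmetic sketch of how a launch config is clamped against the maximum
// grid dimensions, in the style of GetGpuLaunchConfig3D. Sizes are arbitrary.
#include <algorithm>
#include <cstdio>

inline unsigned int DivUp(unsigned int a, unsigned int b) {
  return (a + b - 1) / b;  // round up to whole blocks
}

int main() {
  unsigned int width = 1000003;           // elements along x
  unsigned int block_x = 256;             // threads per block along x
  unsigned int max_grid_x = 2147483647u;  // e.g. cudaDevAttrMaxGridDimX

  unsigned int grid_x = std::min(max_grid_x, DivUp(width, block_x));
  std::printf("grid_x = %u\n", grid_x);   // prints 3907
  return 0;
}
```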
*/ Eigen::GpuDevice* eigen_device() const; @@ -254,7 +254,7 @@ class PADDLE_API GPUContext : public DeviceContext, void SetMaxThreadsPerBlock(int val); - void SetMaxGridDimSize(const std::array& val); + void SetMaxGridDimSize(const std::array& val); void SetDriverVersion(int val); diff --git a/paddle/phi/backends/gpu/gpu_info.cc b/paddle/phi/backends/gpu/gpu_info.cc index 96048de5c047c..32546f762c39e 100644 --- a/paddle/phi/backends/gpu/gpu_info.cc +++ b/paddle/phi/backends/gpu/gpu_info.cc @@ -66,7 +66,7 @@ size_t GpuAvailableMemToAlloc() { size_t available = 0; memory_utils::GpuMemoryUsage(&available, &total); size_t reserving = - static_cast(fraction_reserve_gpu_memory * available); + static_cast(fraction_reserve_gpu_memory * available); // NOLINT // If available size is less than minimum chunk size, no usable memory exists size_t available_to_alloc = available - reserving; size_t min_chunk_size = GpuMinChunkSize(); diff --git a/paddle/phi/backends/gpu/gpu_info.h b/paddle/phi/backends/gpu/gpu_info.h index ebf57bd06eb19..c6ea44b20fe1b 100644 --- a/paddle/phi/backends/gpu/gpu_info.h +++ b/paddle/phi/backends/gpu/gpu_info.h @@ -57,7 +57,7 @@ int GetGPUMaxThreadsPerBlock(int id); int GetCurrentDeviceId(); //! Get the maximum GridDim size for GPU buddy allocator. -std::array GetGpuMaxGridDimSize(int); +std::array GetGpuMaxGridDimSize(int); std::pair GetGpuStreamPriorityRange(); diff --git a/paddle/phi/backends/gpu/gpu_launch_config.h b/paddle/phi/backends/gpu/gpu_launch_config.h index 27384587f7f8f..3196a6832cfaa 100644 --- a/paddle/phi/backends/gpu/gpu_launch_config.h +++ b/paddle/phi/backends/gpu/gpu_launch_config.h @@ -216,10 +216,13 @@ inline GpuLaunchConfig GetGpuLaunchConfig3D(const phi::GPUContext& context, int block_y = std::min(GetLastPow2(height), max_threads / block_x); int block_z = std::min(num_img, max_threads / block_x / block_y); - std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); - int grid_x = std::min(max_grid_dim[0], DivUp(width, block_x)); - int grid_y = std::min(max_grid_dim[1], DivUp(height, block_y)); - int grid_z = std::min(max_grid_dim[2], DivUp(num_img, block_z * 4)); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + unsigned int grid_x = + std::min(max_grid_dim[0], DivUp(width, block_x)); + unsigned int grid_y = + std::min(max_grid_dim[1], DivUp(height, block_y)); + unsigned int grid_z = + std::min(max_grid_dim[2], DivUp(num_img, block_z * 4)); const int capability = context.GetComputeCapability(); GpuLaunchConfig config; diff --git a/paddle/phi/backends/gpu/gpu_resources.cc b/paddle/phi/backends/gpu/gpu_resources.cc index a29b5e110922a..f017bbe2b107e 100644 --- a/paddle/phi/backends/gpu/gpu_resources.cc +++ b/paddle/phi/backends/gpu/gpu_resources.cc @@ -51,7 +51,7 @@ void InitGpuProperties(Place place, int* multi_process, int* max_threads_per_mp, int* max_threads_per_block, - std::array* max_grid_dim_size) { + std::array* max_grid_dim_size) { backends::gpu::GPUDeviceGuard guard(place.GetDeviceId()); *compute_capability = backends::gpu::GetGPUComputeCapability(place.GetDeviceId()); diff --git a/paddle/phi/backends/gpu/gpu_resources.h b/paddle/phi/backends/gpu/gpu_resources.h index 7bec5eebf5886..f7fdc35653c28 100644 --- a/paddle/phi/backends/gpu/gpu_resources.h +++ b/paddle/phi/backends/gpu/gpu_resources.h @@ -27,7 +27,7 @@ void InitGpuProperties(Place place, int* multi_process, int* max_threads_per_mp, int* max_threads_per_block, - std::array* max_grid_dim_size); + std::array* max_grid_dim_size); void InitStream(gpuStream_t* stream); void 
DestoryStream(gpuStream_t stream); diff --git a/paddle/phi/backends/gpu/gpu_types.h b/paddle/phi/backends/gpu/gpu_types.h index fe4d6a6623a96..97f34de9a55a6 100644 --- a/paddle/phi/backends/gpu/gpu_types.h +++ b/paddle/phi/backends/gpu/gpu_types.h @@ -29,6 +29,9 @@ namespace phi { +// Note(qili93): CUDA Runtime API supported by HIP +// https://github.com/ROCm/HIPIFY/blob/master/doc/markdown/CUDA_Runtime_API_functions_supported_by_HIP.md + #ifdef PADDLE_WITH_HIP #define DECLARE_TYPE_FOR_GPU(GPU_TYPE, CUDA_TYPE, ROCM_TYPE) \ using GPU_TYPE = ROCM_TYPE; @@ -50,6 +53,20 @@ DECLARE_TYPE_FOR_GPU(dnnTensorFormat_t, DECLARE_TYPE_FOR_GPU(dnnActivationMode_t, cudnnActivationMode_t, miopenActivationMode_t); +DECLARE_TYPE_FOR_GPU(gpuGraph_t, cudaGraph_t, hipGraph_t); +DECLARE_TYPE_FOR_GPU(gpuFunction_t, cudaFunction_t, hipFunction_t); +DECLARE_TYPE_FOR_GPU(gpuGraphExec_t, cudaGraphExec_t, hipGraphExec_t); +DECLARE_TYPE_FOR_GPU(gpuGraphNode_t, cudaGraphNode_t, hipGraphNode_t); +DECLARE_TYPE_FOR_GPU(gpuGraphNodeType, cudaGraphNodeType, hipGraphNodeType); +DECLARE_TYPE_FOR_GPU(gpuKernelNodeParams, + cudaKernelNodeParams, + hipKernelNodeParams); +DECLARE_TYPE_FOR_GPU(gpuStreamCaptureMode, + cudaStreamCaptureMode, + hipStreamCaptureMode); +DECLARE_TYPE_FOR_GPU(gpuStreamCaptureStatus, + cudaStreamCaptureStatus, + hipStreamCaptureStatus); #undef DECLARE_TYPE_FOR_GPU @@ -76,8 +93,75 @@ DECLARE_CONSTANT_FOR_GPU(gpuMemcpyDeviceToHost, DECLARE_CONSTANT_FOR_GPU(gpuMemcpyDeviceToDevice, cudaMemcpyKind::cudaMemcpyDeviceToDevice, hipMemcpyKind::hipMemcpyDeviceToDevice); +DECLARE_CONSTANT_FOR_GPU(gpuEventDisableTiming, + cudaEventDisableTiming, + hipEventDisableTiming); +DECLARE_CONSTANT_FOR_GPU(gpuStreamNonBlocking, + cudaStreamNonBlocking, + hipStreamNonBlocking); +DECLARE_CONSTANT_FOR_GPU(gpuStreamCaptureModeThreadLocal, + cudaStreamCaptureModeThreadLocal, + hipStreamCaptureModeThreadLocal); +DECLARE_CONSTANT_FOR_GPU(gpuStreamCaptureModeRelaxed, + cudaStreamCaptureModeRelaxed, + hipStreamCaptureModeRelaxed); +DECLARE_CONSTANT_FOR_GPU(gpuStreamCaptureStatusActive, + cudaStreamCaptureStatusActive, + hipStreamCaptureStatusActive); +DECLARE_CONSTANT_FOR_GPU(gpuGraphNodeTypeKernel, + cudaGraphNodeTypeKernel, + hipGraphNodeTypeKernel); #undef DECLARE_CONSTANT_FOR_GPU + +#ifdef PADDLE_WITH_HIP +#define DECLARE_FUNCTION_FOR_GPU(GPU_FUNC, CUDA_FUNC, ROCM_FUNC) \ + const auto GPU_FUNC = ROCM_FUNC; +#else // PADDLE_WITH_CUDA +#define DECLARE_FUNCTION_FOR_GPU(GPU_FUNC, CUDA_FUNC, ROCM_FUNC) \ + const auto GPU_FUNC = CUDA_FUNC; +#endif + +DECLARE_FUNCTION_FOR_GPU(gpuGraphGetNodes, cudaGraphGetNodes, hipGraphGetNodes); +DECLARE_FUNCTION_FOR_GPU(gpuGraphGetEdges, cudaGraphGetEdges, hipGraphGetEdges); +DECLARE_FUNCTION_FOR_GPU(gpuGraphLaunch, cudaGraphLaunch, hipGraphLaunch); +DECLARE_FUNCTION_FOR_GPU(gpuGraphDestroy, cudaGraphDestroy, hipGraphDestroy); +DECLARE_FUNCTION_FOR_GPU(gpuGraphExecDestroy, + cudaGraphExecDestroy, + hipGraphExecDestroy); +DECLARE_FUNCTION_FOR_GPU(gpuGraphNodeGetType, + cudaGraphNodeGetType, + hipGraphNodeGetType); +DECLARE_FUNCTION_FOR_GPU(gpuGraphExecKernelNodeSetParams, + cudaGraphExecKernelNodeSetParams, + hipGraphExecKernelNodeSetParams); +DECLARE_FUNCTION_FOR_GPU(gpuGraphKernelNodeGetParams, + cudaGraphKernelNodeGetParams, + hipGraphKernelNodeGetParams); +DECLARE_FUNCTION_FOR_GPU(gpuStreamCreateWithPriority, + cudaStreamCreateWithPriority, + hipStreamCreateWithPriority); +DECLARE_FUNCTION_FOR_GPU(gpuStreamBeginCapture, + cudaStreamBeginCapture, + hipStreamBeginCapture); 
+DECLARE_FUNCTION_FOR_GPU(gpuStreamEndCapture, + cudaStreamEndCapture, + hipStreamEndCapture); +DECLARE_FUNCTION_FOR_GPU(gpuStreamGetCaptureInfo, + cudaStreamGetCaptureInfo, + hipStreamGetCaptureInfo); +DECLARE_FUNCTION_FOR_GPU(gpuEventCreateWithFlags, + cudaEventCreateWithFlags, + hipEventCreateWithFlags); +DECLARE_FUNCTION_FOR_GPU(gpuEventRecord, cudaEventRecord, hipEventRecord); +DECLARE_FUNCTION_FOR_GPU(gpuEventDestroy, cudaEventDestroy, hipEventDestroy); +DECLARE_FUNCTION_FOR_GPU(gpuEventQuery, cudaEventQuery, hipEventQuery); +DECLARE_FUNCTION_FOR_GPU(gpuEventSynchronize, + cudaEventSynchronize, + hipEventSynchronize); + +#undef DECLARE_FUNCTION_FOR_GPU + } // namespace phi #endif // defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/paddle/phi/backends/gpu/rocm/hip_graph.cc b/paddle/phi/backends/gpu/rocm/hip_graph.cc new file mode 100644 index 0000000000000..781cb41ae6983 --- /dev/null +++ b/paddle/phi/backends/gpu/rocm/hip_graph.cc @@ -0,0 +1,365 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/gpu/rocm/hip_graph.h" +#include "glog/logging.h" +#include "paddle/common/flags.h" + +COMMON_DECLARE_bool(use_cuda_malloc_async_allocator); +COMMON_DECLARE_bool(auto_free_cudagraph_allocations_on_launch); + +namespace phi { +namespace backends { +namespace gpu { + +std::unique_ptr CUDAGraph::capturing_graph_{nullptr}; +paddle::optional CUDAGraph::capturing_thread_id_{paddle::none}; + +static std::vector ToposortCUDAGraph(hipGraph_t graph) { + size_t num_nodes; + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphGetNodes(graph, nullptr, &num_nodes)); + std::vector nodes(num_nodes); + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphGetNodes(graph, nodes.data(), &num_nodes)); + + size_t num_edges; + PADDLE_ENFORCE_GPU_SUCCESS( + hipGraphGetEdges(graph, nullptr, nullptr, &num_edges)); + std::vector from(num_edges), to(num_edges); + PADDLE_ENFORCE_GPU_SUCCESS( + hipGraphGetEdges(graph, from.data(), to.data(), &num_edges)); + + std::unordered_map> + in_edges, out_edges; + for (auto node : nodes) { + in_edges[node]; + out_edges[node]; + } + + for (size_t i = 0; i < num_edges; ++i) { + in_edges[to[i]].insert(from[i]); + out_edges[from[i]].insert(to[i]); + } + + std::queue q; + for (const auto &pair : in_edges) { + if (pair.second.empty()) { + q.push(pair.first); + } + } + + nodes.clear(); + while (!q.empty()) { + auto cur = q.front(); + q.pop(); + nodes.push_back(cur); + + for (auto out_node : out_edges.at(cur)) { + auto &in_nodes = in_edges.at(out_node); + in_nodes.erase(cur); + if (in_nodes.empty()) { + q.push(out_node); + } + } + } + PADDLE_ENFORCE_EQ( + nodes.size(), + num_nodes, + phi::errors::InvalidArgument("Toposort error, this may be a bug.")); + return nodes; +} + +CUDAGraphID CUDAGraph::UniqueID() { + static std::atomic id; + return id.fetch_add(1); +} + +int64_t CUDAGraph::UniqueMemoryPoolID() { + static std::atomic id(CUDAGraph::kDefaultPoolID + 1); + return id.fetch_add(1); +} + +void 
CUDAGraph::Reset() { + if (is_reset_) return; +#if defined(PADDLE_WITH_HIP) + for (auto graph : graphs_) { + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphDestroy(graph)); + } + graphs_.clear(); + for (auto exec_graph : exec_graphs_) { + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphExecDestroy(exec_graph)); + } + exec_graphs_.clear(); +#endif + // callback should be called in reverse order because the latter added + // callback may rely on the former added callback. + for (auto iter = cudagraph_post_reset_callbacks_.rbegin(); + iter != cudagraph_post_reset_callbacks_.rend(); + ++iter) { + (*iter)(); + } + cudagraph_post_reset_callbacks_.clear(); + is_reset_ = true; +} + +void CUDAGraph::Replay() { +#if defined(PADDLE_WITH_HIP) + PADDLE_ENFORCE_EQ(is_reset_, + false, + phi::errors::PermissionDenied( + "Cannot replay the CUDA Graph after reset is called.")); + size_t n = exec_graphs_.size(); + for (size_t i = 0; i < n; ++i) { + if (!is_first_run_) { + for (auto &hook : cudagraph_pre_replay_callbacks_[i]) { + hook(exec_graphs_[i]); + } + } + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphLaunch(exec_graphs_[i], stream_)); + } + is_first_run_ = false; +#endif +} + +void CUDAGraph::BeginSegmentCapture() { + ThrowErrorIfNotSupportCUDAGraph(); +#if defined(PADDLE_WITH_HIP) + PADDLE_ENFORCE_EQ(IsCapturing(), + true, + phi::errors::PermissionDenied( + "BeginSegmentCapture should be called when CUDA " + "Graph is capturing.")); + if (IsThreadLocalCapturing()) { + PADDLE_ENFORCE_EQ(IsThisThreadCapturing(), + true, + phi::errors::PermissionDenied( + "When capturing CUDA Graph in the thread local mode, " + "you cannot begin segmented capturing in the thread " + "which is not the one that starts the capturing.")); + } + PADDLE_ENFORCE_GPU_SUCCESS(hipStreamBeginCapture( + capturing_graph_->stream_, capturing_graph_->capture_mode_)); + PADDLE_ENFORCE_EQ( + IsValidCapturing(), + true, + phi::errors::PermissionDenied("CUDA Graph should not be invalidated.")); + VLOG(10) << "Begin to capture CUDA Graph with ID " << capturing_graph_->id_ + << ", segment id " << capturing_graph_->graphs_.size() + << ", memory pool id " << capturing_graph_->pool_id_; +#endif +} + +void CUDAGraph::BeginCapture(phi::GPUPlace place, + gpuStream_t stream, + hipStreamCaptureMode mode) { + ThrowErrorIfNotSupportCUDAGraph(); +#if defined(PADDLE_WITH_HIP) + PADDLE_ENFORCE_EQ(IsCapturing(), + false, + phi::errors::PermissionDenied( + "CUDA Graph can only captured one by one.")); + PADDLE_ENFORCE_NOT_NULL( + stream, + phi::errors::PermissionDenied( + "CUDA Graph cannot be captured in default CUDA stream 0.")); + capturing_graph_.reset(new CUDAGraph()); + capturing_graph_->place_ = place; + capturing_graph_->stream_ = stream; + capturing_graph_->capture_mode_ = mode; + if (mode == hipStreamCaptureModeThreadLocal) { + capturing_thread_id_ = std::this_thread::get_id(); + VLOG(10) << "Capturing CUDA Graph in thread local mode, thread id: " + << capturing_thread_id_; + } + BeginSegmentCapture(); +#endif +} + +void CUDAGraph::EndSegmentCapture() { + ThrowErrorIfNotSupportCUDAGraph(); +#if defined(PADDLE_WITH_HIP) + PADDLE_ENFORCE_EQ( + IsCapturing(), + true, + phi::errors::PermissionDenied("No CUDA Graph is capturing.")); + hipGraph_t graph; + PADDLE_ENFORCE_GPU_SUCCESS( + hipStreamEndCapture(capturing_graph_->stream_, &graph)); + auto num_nodes = static_cast(-1); + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphGetNodes(graph, nullptr, &num_nodes)); + if (num_nodes == 0) { + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphDestroy(graph)); + VLOG(10) << "Skip empty CUDA Graph with ID " << 
capturing_graph_->id_ + << ", segment id " << capturing_graph_->graphs_.size() + << ", memory pool id " << capturing_graph_->pool_id_; + return; + } + + for (auto &cudagraph_post_capture_callback : + capturing_graph_->cudagraph_post_capture_callbacks_) { + cudagraph_post_capture_callback(); + } + capturing_graph_->cudagraph_post_capture_callbacks_.clear(); + + capturing_graph_->cudagraph_pre_replay_callbacks_.emplace_back( + CUDAGraphNodeLauncher::Instance().GetParameterSettersForExecGraph(graph)); + + // if forward graph is registered, this graph is a backward graph + // we check whether there is remain blocks that is unreleased by this + hipGraphExec_t exec_graph; + if (FLAGS_use_cuda_malloc_async_allocator && + FLAGS_auto_free_cudagraph_allocations_on_launch) { +#if defined(PADDLE_WITH_HIP) + VLOG(1) << "hipGraphInstantiateFlagAutoFreeOnLaunch is enabled!"; + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphInstantiateWithFlags( + &exec_graph, graph, hipGraphInstantiateFlagAutoFreeOnLaunch)); +#else + PADDLE_THROW(phi::errors::Unimplemented( + "The cudaGraphInstantiateFlagAutoFreeOnLaunch is only supported when " + "CUDA version >= 11.4.0")); +#endif + } else { +#if defined(PADDLE_WITH_HIP) + PADDLE_ENFORCE_GPU_SUCCESS( + hipGraphInstantiate(&exec_graph, graph, nullptr, nullptr, 0)); +#endif + } + VLOG(10) << "End to capture CUDA Graph with ID " << capturing_graph_->id_ + << ", segment id " << capturing_graph_->graphs_.size() + << ", memory pool id " << capturing_graph_->pool_id_; + capturing_graph_->graphs_.emplace_back(graph); + capturing_graph_->exec_graphs_.emplace_back(exec_graph); +#endif +} + +std::unique_ptr CUDAGraph::EndCapture() { + EndSegmentCapture(); + capturing_thread_id_ = paddle::none; + return std::move(capturing_graph_); +} + +bool CUDAGraph::IsValidCapturing() { +#if defined(PADDLE_WITH_HIP) + if (!IsCapturing()) return false; + hipStreamCaptureStatus status; + CUDAGraphID id; + PADDLE_ENFORCE_GPU_SUCCESS( + hipStreamGetCaptureInfo(capturing_graph_->stream_, &status, &id)); + return status == hipStreamCaptureStatusActive; +#else + return false; +#endif +} + +static std::string ConcatPath(const std::string &dirname, + const std::string &filename) { +#ifdef _WIN32 + const std::array kFileSep = {"\\"}; +#else + const std::array kFileSep = {"/"}; +#endif + if (!dirname.empty() && dirname.back() == kFileSep[0]) { + return dirname + filename; + } else { + return dirname + kFileSep.data() + filename; + } +} + +void CUDAGraph::PrintToDotFiles(const std::string &dirname, + unsigned int flags) { + ThrowErrorIfNotSupportCUDAGraph(); + PADDLE_THROW(phi::errors::Unimplemented( + "The print_to_dot_files() method is not supported on ROCm/HIP")); +} + +#if defined(PADDLE_WITH_HIP) +void CUDAGraphNodeLauncher::KernelNodeLaunch( + parameterSetter_t parameterSetter, gpuKernelCallback_t cudakernelCallback) { + if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) { + unsigned int id = GenerateIdentifier(); + auto cudaFunc = cudakernelCallback(id); + + parameterSetters[cudaFunc][id] = parameterSetter; + VLOG(10) << "[KernelNodeLaunch] Launch kernel with cudaFunc = " << cudaFunc + << " id = " << id; + } else { + cudakernelCallback(0); + } +} + +std::vector +CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(hipGraph_t graph) { + size_t num_nodes; + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphGetNodes(graph, nullptr, &num_nodes)); + std::vector nodes(num_nodes); + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphGetNodes(graph, nodes.data(), &num_nodes)); + + std::vector> hooks; + for (auto node : nodes) { 
+ hipGraphNode_t gpuNode = node; + hipGraphNodeType pType; + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphNodeGetType(gpuNode, &pType)); + if (pType == hipGraphNodeTypeKernel) { + hipKernelNodeParams gpuParams; + PADDLE_ENFORCE_GPU_SUCCESS( + gpuGraphKernelNodeGetParams(gpuNode, &gpuParams)); + gpuKernelParams kernel_params(gpuParams.kernelParams); + auto kernel = + parameterSetters.find(static_cast(gpuParams.func)); + VLOG(10) << "[GetParameterSettersForExecGraph] gpuParams.func = " + << gpuParams.func; + // There exists a parameter setter + if (kernel != parameterSetters.end()) { + auto launchSequence = kernel->second; + unsigned int id = kernel_params.As(0); + + VLOG(10) << "[GetParameterSettersForExecGraph] Find launch kernel id = " + << id; + auto parameterSetter = launchSequence.find(id); + if (parameterSetter != launchSequence.end()) { + auto setter = parameterSetter->second; + hooks.emplace_back( + [setter, gpuNode, gpuParams](hipGraphExec_t exec_graph) { + gpuKernelParams kernel_params(gpuParams.kernelParams); + setter(kernel_params); + PADDLE_ENFORCE_GPU_SUCCESS(hipGraphExecKernelNodeSetParams( + exec_graph, gpuNode, &gpuParams)); + }); + } else { + PADDLE_THROW( + phi::errors::InvalidArgument("Error: does not find launch id")); + } + } + } + } + + return hooks; +} +#else +void CUDAGraphNodeLauncher::KernelNodeLaunch( + hipFunction_t cudaFunc, + parameterSetter_t parameterSetter, + gpuKernelCallback_t cudakernelCallback) { + cudakernelCallback(0); +} + +std::vector +CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(hipGraph_t graph) { + PADDLE_THROW(phi::errors::Unimplemented( + "CUDAGraphNodeLauncher is only supported when CUDA version >= 11.0")); +} +#endif + +} // namespace gpu +} // namespace backends +} // namespace phi diff --git a/paddle/phi/backends/gpu/rocm/hip_graph.h b/paddle/phi/backends/gpu/rocm/hip_graph.h new file mode 100644 index 0000000000000..cb92275227254 --- /dev/null +++ b/paddle/phi/backends/gpu/rocm/hip_graph.h @@ -0,0 +1,393 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
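For reference, ToposortCUDAGraph in hip_graph.cc above is Kahn's algorithm applied to the node and edge arrays returned by hipGraphGetNodes/hipGraphGetEdges. The same algorithm over plain integer node ids, with a made-up graph and no HIP types:

```cpp
// Kahn's algorithm on a tiny made-up DAG, mirroring ToposortCUDAGraph:
// repeatedly emit nodes whose remaining in-degree is zero.
#include <cstdio>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

std::vector<int> Toposort(const std::vector<int>& nodes,
                          const std::vector<std::pair<int, int>>& edges) {
  std::unordered_map<int, std::unordered_set<int>> in_edges, out_edges;
  for (int n : nodes) {
    in_edges[n];
    out_edges[n];
  }
  for (const auto& e : edges) {
    in_edges[e.second].insert(e.first);
    out_edges[e.first].insert(e.second);
  }

  std::queue<int> q;
  for (const auto& p : in_edges) {
    if (p.second.empty()) q.push(p.first);
  }

  std::vector<int> order;
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    order.push_back(cur);
    for (int next : out_edges.at(cur)) {
      auto& ins = in_edges.at(next);
      ins.erase(cur);
      if (ins.empty()) q.push(next);
    }
  }
  return order;  // size < nodes.size() means the graph had a cycle
}

int main() {
  std::vector<int> order =
      Toposort({0, 1, 2, 3}, {{0, 1}, {0, 2}, {1, 3}, {2, 3}});
  for (int n : order) std::printf("%d ", n);
  std::printf("\n");  // a valid order, e.g. 0 1 2 3 or 0 2 1 3
  return 0;
}
```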
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/common/errors.h" +#include "paddle/common/macros.h" +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/backends/device_code.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/common/place.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/utils/optional.h" + +namespace phi { +namespace backends { +namespace gpu { + +class CUDAGraphContextManager { + public: + using DeviceContextMap = + std::map>>; + + static CUDAGraphContextManager &Instance() { + static CUDAGraphContextManager *cuda_graph_ctx_manager = + new CUDAGraphContextManager; + return *cuda_graph_ctx_manager; + } + + DeviceContext *Get(int64_t pool_id, const Place &place, int stream_priority) { + std::lock_guard lk(ctx_mtx_); + VLOG(6) << "Get cuda graph device context for " << place; + + DeviceContextMap &ctxs = cuda_graph_ctx_pool_[pool_id]; + if (ctxs.find(place) == ctxs.end()) { + phi::memory_utils::EmplaceDeviceContexts( + &ctxs, + {place}, + /*disable_setting_default_stream_for_allocator=*/true, + stream_priority); + } + return ctxs[place].get().get(); + } + + void RecordCapturingDeviceContext(DeviceContext *dev_ctx) { + capturing_ctxs_.insert(dev_ctx); + } + + std::set GetAllCapturingDeviceContexts() const { + return capturing_ctxs_; + } + + void ClearDeviceContextsRecords() { capturing_ctxs_.clear(); } + + private: + CUDAGraphContextManager() {} + DISABLE_COPY_AND_ASSIGN(CUDAGraphContextManager); + + std::mutex ctx_mtx_; + std::unordered_map cuda_graph_ctx_pool_; + std::set capturing_ctxs_; +}; + +class gpuKernelParams { + public: + explicit gpuKernelParams(void **params) : kernelParams(params) {} + + template + T &As(size_t idx) const { + return *reinterpret_cast(kernelParams[idx]); + } + + void **getParams() const { return kernelParams; } + + private: + void **kernelParams; +}; + +using cudaGraphExecuterSetter_t = std::function; + +// ** class CUDAGraphNodeLauncher +// +// This class offers a interface for launching CUDA kernels in CUDA Graph, we +// utilize the `cudaGraphExecKernelNodeSetParams` function for parameter setup. +// Launching kernels via this class ensures proper management. +// +// NOTE: It's essential that the first parameter for any kernel launched +// through this class is an `unsigned int` identifier. This identifier plays a +// crucial role in linking the CUDA kernel to its corresponding CUDA graph +// node. We tag each kernel launch with a unique identifier to maintain +// structured linkage with its CUDA graph node. +// +// NOTE: This class use a singleton design pattern ensures there's only a +// single global instance accessible via the `Instance()` method. +class CUDAGraphNodeLauncher { + public: + // [Parameter Setter Callback] + // Sets the kernel's parameters BEFORE activating the CUDA graph. It enables + // dynamic determination and setup of kernel arguments. + // + // parameterSetter_t parameterSetter = [saved_state](gpuKernelParams + // ¶m){ + // // Code to compute and the parameter values from the saved_state + // // ... + // param.As(idx) = calculated_value; + // }; + using parameterSetter_t = std::function; + + // [CUDA Kernel Callback] + // Acts as the launcher for the kernel. It accepts an `unsigned int` + // identifier and uses it for the kernel launch. 
+ // The `cudaGetFuncBySymbol` method can be used to fetch the `cudaFunction_t` + // reference of the kernel from the kernel pointer. + // gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { + // // cudaFunction_t is REQUIRED to get here + // cudaFunction_t cudaFunc; + // PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, &kernel)); + // + // kernel<<<>>>(id, ...); // Launching the kernel with id + // return cudaFunc; + // }; + using gpuKernelCallback_t = std::function; + + // [Kernel Launch] + // With the callbacks defined and the CUDA function obtained, the kernel can + // be launched using the `KernelNodeLaunch` method. + void KernelNodeLaunch(parameterSetter_t parameterSetter, + gpuKernelCallback_t cudakernelCallback); + + std::vector GetParameterSettersForExecGraph( + hipGraph_t graph); + + parameterSetter_t GetParameterSetter(const gpuKernelParams ¶ms); + + static CUDAGraphNodeLauncher &Instance() { + static CUDAGraphNodeLauncher *launcher = new CUDAGraphNodeLauncher; + return *launcher; + } + + private: + CUDAGraphNodeLauncher() : id(0) {} + DISABLE_COPY_AND_ASSIGN(CUDAGraphNodeLauncher); + + unsigned int GenerateIdentifier() { return id++; } + + unsigned int id; + std::unordered_map> + parameterSetters; +}; + +#if defined(PADDLE_WITH_HIP) +static void ThrowErrorIfNotSupportCUDAGraph() {} +#else +enum gpuStreamCaptureMode { + hipStreamCaptureModeGlobal = 0, + hipStreamCaptureModeThreadLocal = 1, + hipStreamCaptureModeRelaxed = 2 +}; +static void ThrowErrorIfNotSupportCUDAGraph() { + PADDLE_THROW(phi::errors::Unimplemented( + "CUDA Graph is only supported when CUDA version >= 10.1")); +} +#endif + +using CUDAGraphID = unsigned long long; // NOLINT + +// NOTE: Currently, we do not support to capture CUDA graph in parallel +// NOTE: Do not use this class directly because it should be used with +// the memory pool. +class CUDAGraph { + DISABLE_COPY_AND_ASSIGN(CUDAGraph); + + // Since the constructor would throw error is CUDA_VERSION < 10010. + // The non-static method of CUDAGraph need not check CUDA_VERSION + // again. 
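The capture lifecycle implemented in hip_graph.cc above is exposed by this class as static BeginCapture/EndCapture plus per-instance Replay/Reset. A hedged usage sketch, with the caller, stream, and captured work left abstract (on a ROCm build gpuStreamCaptureMode is hipStreamCaptureMode):

```cpp
// Hedged usage sketch of the capture lifecycle (BeginCapture / EndCapture /
// Replay / Reset) declared by this class; the wrapper function and captured
// work are illustrative, not taken from this PR.
#include <memory>
#include "paddle/phi/backends/gpu/rocm/hip_graph.h"

void CaptureAndReplay(phi::GPUPlace place, gpuStream_t stream) {
  using phi::backends::gpu::CUDAGraph;

  // Work enqueued on `stream` between Begin and End is recorded, not run.
  // A non-default stream is required.
  CUDAGraph::BeginCapture(place, stream, hipStreamCaptureModeThreadLocal);
  // ... launch kernels / memcpys on `stream` here ...
  std::unique_ptr<CUDAGraph> graph = CUDAGraph::EndCapture();

  graph->Replay();  // run the recorded work; pre-replay hooks fire first
  graph->Reset();   // destroy the graphs and run post-reset callbacks
}
```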
+ CUDAGraph() { + ThrowErrorIfNotSupportCUDAGraph(); + id_ = UniqueID(); + } + + public: + static constexpr int64_t kDefaultPoolID = 0; + static constexpr int64_t kInvalidPoolID = -1; + + ~CUDAGraph() { Reset(); } + + CUDAGraphID ID() const { return id_; } + + static int64_t SetMemoryPoolID(int64_t pool_id) { + auto &pool_id_ = capturing_graph_->pool_id_; + PADDLE_ENFORCE_EQ( + pool_id_, + kInvalidPoolID, + phi::errors::InvalidArgument("Cannot reset memory pool id twice, the " + "former memory pool id is %d.", + pool_id_)); + if (pool_id <= kInvalidPoolID) { + pool_id_ = UniqueMemoryPoolID(); + } else { + PADDLE_ENFORCE_GE( + pool_id, + kDefaultPoolID, + phi::errors::InvalidArgument("Invalid memory pool id %d.", pool_id)); + pool_id_ = pool_id; + } + return pool_id_; + } + + int64_t PoolID() const { return pool_id_; } + + static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; } + + void Replay(); + + void Reset(); + + void AddPostResetCallback(std::function callback) { + std::lock_guard guard(mtx_); + cudagraph_post_reset_callbacks_.push_back(std::move(callback)); + } + + void AddPostCaptureCallback(std::function callback) { + std::lock_guard guard(mtx_); + cudagraph_post_capture_callbacks_.push_back(std::move(callback)); + } + + void PrintToDotFiles(const std::string &dirname, unsigned int flags); + + static void BeginCapture(phi::GPUPlace place, + gpuStream_t stream, + gpuStreamCaptureMode mode); + static std::unique_ptr EndCapture(); + + static void BeginSegmentCapture(); + static void EndSegmentCapture(); + + static void AddPostResetCallbackDuringCapturing( + std::function callback) { + capturing_graph_->AddPostResetCallback(std::move(callback)); + } + + static void AddPostCaptureCallbackDuringCapturing( + std::function callback) { + capturing_graph_->AddPostCaptureCallback(std::move(callback)); + } + + // No need to add CUDA_VERSION macro because capturing_graph_ would + // always be nullptr (constructor throws error) + static bool IsCapturing() { return capturing_graph_ != nullptr; } + + static CUDAGraphID CapturingID() { return capturing_graph_->id_; } + + static phi::GPUPlace CapturingPlace() { return capturing_graph_->place_; } + + // This API can be used to debug which GPU operation is not + // supported during capturing CUDA Graph. + static bool IsValidCapturing(); + + static bool IsThreadLocalCapturing() { +#if defined(PADDLE_WITH_HIP) + return IsCapturing() && + capturing_graph_->capture_mode_ == hipStreamCaptureModeThreadLocal; +#else + return false; +#endif + } + + static bool IsThisThreadCapturing() { + if (UNLIKELY(IsCapturing())) { + return IsThreadLocalCapturing() + ? capturing_thread_id_.get() == std::this_thread::get_id() + : true; + } else { + return false; + } + } + + using SetSeedFunc = std::function; + static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) { + std::lock_guard guard(capturing_graph_->func_mtx_); + capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func)); + } + + static int64_t UniqueMemoryPoolID(); + + private: + static CUDAGraphID UniqueID(); + + private: +#if defined(PADDLE_WITH_HIP) + std::vector graphs_; + std::vector exec_graphs_; + gpuStreamCaptureMode capture_mode_; +#endif + gpuStream_t stream_{nullptr}; + phi::GPUPlace place_; + CUDAGraphID id_; + int64_t pool_id_{kInvalidPoolID}; + bool is_reset_{false}; + std::mutex mtx_; + + std::vector set_seed_funcs_; + + // Holds callbacks that are triggered after the CUDA graph is reset. 
These + // callbacks are used for operations that need to be performed following the + // reset of a CUDA graph. + std::vector> cudagraph_post_reset_callbacks_; + + // Contains callbacks that are invoked after the CUDA graph has been captured. + // These callbacks are crucial for managing memory allocations related to the + // CUDA graph. They ensure that memory blocks not associated with a graph (as + // detailed in cuda_malloc_async_allocator) are not erroneously released + // during the graph's lifecycle. + std::vector> cudagraph_post_capture_callbacks_; + + // Maintains a collection of 'pre-hooks' - functions that are executed before + // the CUDA graph is replayed. These pre-hooks are essential for setting up + // the necessary conditions or states required for the correct execution of + // the CUDA graph. + std::vector> + cudagraph_pre_replay_callbacks_; + + std::mutex func_mtx_; + + bool is_first_run_{true}; + + static paddle::optional capturing_thread_id_; + static std::unique_ptr capturing_graph_; +}; + +#if defined(PADDLE_WITH_HIP) +class CUDAGraphCaptureModeGuard { + DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); + + public: + explicit CUDAGraphCaptureModeGuard( + gpuStreamCaptureMode mode = hipStreamCaptureModeRelaxed) { + if (UNLIKELY(CUDAGraph::IsCapturing())) { + PADDLE_ENFORCE_GPU_SUCCESS(hipThreadExchangeStreamCaptureMode(&mode)); + // After cudaThreadExchangeStreamCaptureMode is called, + // the variable "mode" would be set to the old capturing mode. + old_mode_ = mode; + } + } + + ~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW { + if (UNLIKELY(CUDAGraph::IsCapturing())) { + PADDLE_ENFORCE_GPU_SUCCESS( + hipThreadExchangeStreamCaptureMode(&old_mode_)); + } + } + + private: + gpuStreamCaptureMode old_mode_; +}; +#else +class CUDAGraphCaptureModeGuard { + DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard); + + public: + explicit CUDAGraphCaptureModeGuard( + gpuStreamCaptureMode mode = hipStreamCaptureModeRelaxed) {} +}; +#endif + +} // namespace gpu +} // namespace backends +} // namespace phi diff --git a/paddle/phi/backends/gpu/rocm/rocm_info.cc b/paddle/phi/backends/gpu/rocm/rocm_info.cc index edc23479c9238..b8ddea98b5c9e 100644 --- a/paddle/phi/backends/gpu/rocm/rocm_info.cc +++ b/paddle/phi/backends/gpu/rocm/rocm_info.cc @@ -173,7 +173,7 @@ int GetCurrentDeviceId() { return device_id; } -std::array GetGpuMaxGridDimSize(int id) { +std::array GetGpuMaxGridDimSize(int id) { PADDLE_ENFORCE_LT( id, GetGPUDeviceCount(), @@ -181,7 +181,7 @@ std::array GetGpuMaxGridDimSize(int id) { "but received id is: %d. GPU count is: %d.", id, GetGPUDeviceCount())); - std::array ret; + std::array ret; int size; auto error_code_x = hipDeviceGetAttribute(&size, hipDeviceAttributeMaxGridDimX, id); diff --git a/paddle/phi/backends/onednn/onednn_helper.h b/paddle/phi/backends/onednn/onednn_helper.h index 60c531c7b7443..82fd76e725a3b 100644 --- a/paddle/phi/backends/onednn/onednn_helper.h +++ b/paddle/phi/backends/onednn/onednn_helper.h @@ -220,8 +220,9 @@ inline std::string CreateKey(const OneDNNContext& dev_ctx UNUSED, ArgTypes&&... 
args) { std::string key; key.reserve(64); - using expand_type = int[]; - expand_type{0, (AppendKey(&key, std::forward(args)), 0)...}; + // using expand_type = int[]; + // expand_type{0, (AppendKey(&key, std::forward(args)), 0)...}; + ((void)AppendKey(&key, std::forward(args)), ...); key += OneDNNContext::tls().get_key_suffix(); return key; } diff --git a/paddle/phi/backends/xpu/enforce_xpu.h b/paddle/phi/backends/xpu/enforce_xpu.h index e4fc15f4cb747..e89857728da25 100644 --- a/paddle/phi/backends/xpu/enforce_xpu.h +++ b/paddle/phi/backends/xpu/enforce_xpu.h @@ -92,7 +92,7 @@ inline const char* xpuGetErrorString(int stat) { case XPUERR_INTERRUPTED: return "Execution interrupted by user"; default: - return "unknown error"; + return "Unknown error"; } } diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 55aae9f24c1a6..07972469a32b1 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -448,6 +448,29 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::FLOAT32})}, {"flip", XPUKernelSet({phi::DataType::FLOAT32})}, + {"full", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT64, + phi::DataType::FLOAT16})}, + {"full_batch_size_like", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT16})}, + {"full_like", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::FLOAT16})}, + {"full_batch_size_like", + XPUKernelSet({phi::DataType::INT64, + phi::DataType::INT32, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::FLOAT16})}, {"full_batch_size_like", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, @@ -1174,10 +1197,14 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fused_gemm_epilogue_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fused_bias_residual_layernorm", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fused_attention", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fused_attention_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"fused_bias_act", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fused_feedforward", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"fused_feedforward_grad", @@ -1196,6 +1223,12 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32})}, {"sine_pos_xpu", XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"roformer_relative_embedding_xpu", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"variable_length_memory_efficient_attention", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"flash_attn_unpadded", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, }; return s_xpu2_kernels; diff --git a/paddle/phi/backends/xpu/xpu3_op_list.cc b/paddle/phi/backends/xpu/xpu3_op_list.cc index 39e79ba0c4934..48dc5d8334193 100644 --- a/paddle/phi/backends/xpu/xpu3_op_list.cc +++ b/paddle/phi/backends/xpu/xpu3_op_list.cc @@ -1048,8 +1048,10 @@ XPUOpMap& get_kl3_ops() { phi::DataType::INT64, phi::DataType::BOOL, phi::DataType::FLOAT64, - phi::DataType::FLOAT32})}, - {"tile_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + phi::DataType::FLOAT32, + phi::DataType::BFLOAT16})}, + {"tile_grad", + 
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::BFLOAT16})}, {"transpose2_grad", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, diff --git a/paddle/phi/backends/xpu/xpu_context.cc b/paddle/phi/backends/xpu/xpu_context.cc index 9de9744393d4a..050ed1693220b 100644 --- a/paddle/phi/backends/xpu/xpu_context.cc +++ b/paddle/phi/backends/xpu/xpu_context.cc @@ -31,31 +31,16 @@ namespace xpu = baidu::xpu::api; namespace phi { struct XPUContext::Impl { - void SetL3Cache(int l3_size = 14155776) { - const int MAX_XPU_NUM = 16; - static void* l3ptrs[MAX_XPU_NUM] = {nullptr}; - - if (std::getenv("XPU_PADDLE_L3_SIZE") != nullptr) { - l3_size = atoi(std::getenv("XPU_PADDLE_L3_SIZE")); - } - - auto selected_xpus = backends::xpu::GetXPUSelectedDevices(); - for (unsigned int i = 0; i < selected_xpus.size(); i++) { - if (place_.GetDeviceId() == selected_xpus[i]) { - if (l3ptrs[place_.GetDeviceId()] != nullptr) { - xpu_free(l3ptrs[place_.GetDeviceId()]); - l3ptrs[place_.GetDeviceId()] = nullptr; - } - xpu_malloc(static_cast(&l3ptrs[place_.GetDeviceId()]), - l3_size, - XPU_MEM_L3); - if (l3ptrs[place_.GetDeviceId()] != nullptr) { - context_->_l3_mgr.set(l3ptrs[place_.GetDeviceId()], l3_size); - VLOG(3) << "xpu place " << static_cast(place_.GetDeviceId()) - << " set l3 size " << l3_size; - } - break; - } + void SetL3Cache(int64_t l3_size = 1024) { + PADDLE_ENFORCE_XPU_SUCCESS(xpu_wait(context_->xpu_stream)); + context_->_l3_mgr.set(nullptr, 0, true); // free origin l3 + void* l3_ptr = nullptr; + xpu_malloc(static_cast(&l3_ptr), l3_size, XPU_MEM_L3); + + if (l3_ptr != nullptr) { + VLOG(3) << "xpu place " << static_cast(place_.GetDeviceId()) + << "context " << context_ << " set l3 size " << l3_size; + context_->_l3_mgr.set(l3_ptr, l3_size, true); } } @@ -145,23 +130,26 @@ struct XPUContext::Impl { } } - void Init() { + void Init(int64_t gm_default_size = 1024, int64_t l3_default_size = 1024) { owned_ = true; backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId()); LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: " << static_cast(place_.device); + context_ = xpu::create_context(); - // Setup XPU GM Buffer - if (std::getenv("XPUAPI_DEFAULT_SIZE") != nullptr) { - context_->set_option("XPUAPI_DEFAULT_SIZE", - std::getenv("XPUAPI_DEFAULT_SIZE")); - } else { - // Optimization described in - // https://github.com/PaddlePaddle/Paddle/pull/54674 - context_->set_option("XPUAPI_DEFAULT_SIZE", "1"); + context_->set_option("XPUAPI_DEFAULT_SIZE", + std::to_string(gm_default_size).c_str()); + VLOG(3) << "xpu place " << static_cast(place_.GetDeviceId()) + << "context " << context_ << " set xpuapi_default_size " + << gm_default_size; + + if (std::getenv("XPU_CDNN_CLUSTER_PARALLEL") != nullptr) { + XPUStream s; + xpu_stream_create(&s); + context_->set_stream(s); } xpu_version_ = backends::xpu::get_xpu_version(place_.device); - SetL3Cache(); + SetL3Cache(l3_default_size); } void SetXContext(xpu::Context* context) { @@ -234,58 +222,117 @@ struct XPUContext::Impl { xpu::BKCLContext_t bkcl_context_{nullptr}; }; -XPUContext::XPUContext() : DeviceContext(), impl_(std::make_unique()) { - impl_->Init(); +static int64_t get_gm_size(int i) { + int64_t default_size = 1024; + if (std::getenv("XPUAPI_DEFAULT_SIZE") != nullptr) { + default_size = std::atoll(std::getenv("XPUAPI_DEFAULT_SIZE")); + } + std::string cur_env = std::string("XPUAPI_DEFAULT_SIZE") + std::to_string(i); + if (std::getenv(cur_env.c_str()) != nullptr) { + default_size = std::atoll(std::getenv(cur_env.c_str())); + } + return default_size; } 
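The get_gm_size helper above (and the companion get_l3_size added just below) gives each XPU context its own buffer size: a per-index variable such as XPUAPI_DEFAULT_SIZE0 overrides the global XPUAPI_DEFAULT_SIZE, which overrides the built-in default of 1024. A minimal standalone sketch of that precedence, not part of the patch; the variable names follow the diff and the main driver is purely illustrative.

// Standalone illustration of the env-var precedence implemented by
// get_gm_size()/get_l3_size(): per-context override > global override >
// built-in default of 1024.
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <string>

static int64_t LookupSize(const char* base_env, int index) {
  int64_t size = 1024;  // built-in default used by the patch
  if (const char* v = std::getenv(base_env)) size = std::atoll(v);
  std::string indexed = std::string(base_env) + std::to_string(index);
  if (const char* v = std::getenv(indexed.c_str())) size = std::atoll(v);
  return size;
}

int main() {
  // With XPU_CDNN_CLUSTER_PARALLEL set, the XPUContext constructor builds one
  // Impl per stream (4 by default) and sizes each one independently.
  for (int i = 0; i < 4; ++i) {
    std::printf("impl %d: GM=%lld L3=%lld\n",
                i,
                static_cast<long long>(LookupSize("XPUAPI_DEFAULT_SIZE", i)),
                static_cast<long long>(LookupSize("XPU_PADDLE_L3_SIZE", i)));
  }
  return 0;
}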
-XPUContext::XPUContext(const XPUPlace& place) - : DeviceContext(), impl_(std::make_unique(place)) { - impl_->Init(); +static int64_t get_l3_size(int i) { + int64_t default_size = 1024; + if (std::getenv("XPU_PADDLE_L3_SIZE") != nullptr) { + default_size = std::atoll(std::getenv("XPU_PADDLE_L3_SIZE")); + } + std::string cur_env = std::string("XPU_PADDLE_L3_SIZE") + std::to_string(i); + if (std::getenv(cur_env.c_str()) != nullptr) { + default_size = std::atoll(std::getenv(cur_env.c_str())); + } + return default_size; +} + +XPUContext::XPUContext() : DeviceContext() { + if (std::getenv("XPU_CDNN_CLUSTER_PARALLEL") != nullptr) { + int default_num_stream = 4; + if (std::getenv("XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER") != nullptr) { + default_num_stream = + atoi(std::getenv("XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER")); + } + for (int i = 0; i < default_num_stream; i++) { + impls_.push_back(std::make_unique()); + impls_[i]->Init(get_gm_size(i), get_l3_size(i)); + } + } else { + impls_.push_back(std::make_unique()); + impls_[0]->Init(get_gm_size(0), get_l3_size(0)); + } +} + +XPUContext::XPUContext(const XPUPlace& place) : DeviceContext() { + if (std::getenv("XPU_CDNN_CLUSTER_PARALLEL") != nullptr) { + int default_num_stream = 4; + if (std::getenv("XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER") != nullptr) { + default_num_stream = + atoi(std::getenv("XPU_CDNN_CLUSTER_PARALLEL_STREAM_NUMBER")); + } + for (int i = 0; i < default_num_stream; i++) { + impls_.push_back(std::make_unique(place)); + impls_[i]->Init(get_gm_size(i), get_l3_size(i)); + } + } else { + impls_.push_back(std::make_unique(place)); + impls_[0]->Init(get_gm_size(0), get_l3_size(0)); + } } XPUContext::~XPUContext() = default; -const Place& XPUContext::GetPlace() const { return impl_->GetPlace(); } +const Place& XPUContext::GetPlace() const { return impls_[0]->GetPlace(); } -XPUStream XPUContext::stream() const { return impl_->stream(); } +XPUStream XPUContext::stream(int i) const { return impls_[i]->stream(); } -void XPUContext::SetStream(void* stream) { impl_->SetStream(stream); } +void XPUContext::SetStream(void* stream, int i) { + impls_[i]->SetStream(stream); +} void XPUContext::SetXpuVersion(int version) { - impl_->xpu_version_ = static_cast(version); + impls_[0]->xpu_version_ = static_cast(version); } void XPUContext::SetRuntimeVersion(int version) { - impl_->runtime_version_ = version; + impls_[0]->runtime_version_ = version; } void XPUContext::SetDriverVersion(int version) { - impl_->driver_version_ = version; + impls_[0]->driver_version_ = version; } backends::xpu::XPUVersion XPUContext::xpu_version() const { - return impl_->xpu_version_; + return impls_[0]->xpu_version_; } -xpu::Context* XPUContext::x_context() const { return impl_->GetXContext(); } +xpu::Context* XPUContext::x_context(int i) const { + return impls_[i]->GetXContext(); +} xpu::BKCLContext_t XPUContext::bkcl_context() const { - return impl_->GetBkclContext(); + return impls_[0]->GetBkclContext(); } -void XPUContext::Wait() const { impl_->Wait(); } +void XPUContext::Wait() const { + for (uint64_t i = 0; i < impls_.size(); i++) { + impls_[i]->Wait(); + } +} -void XPUContext::SetXContext(xpu::Context* context) { - impl_->SetXContext(context); +void XPUContext::SetXContext(xpu::Context* context, int i) { + impls_[i]->SetXContext(context); } -void XPUContext::SetL3Cache(int l3_size) { impl_->SetL3Cache(l3_size); } +void XPUContext::SetL3Cache(int64_t l3_size, int i) { + impls_[i]->SetL3Cache(l3_size); +} void XPUContext::SetBkclContext(xpu::BKCLContext_t context) { - 
impl_->SetBkclContext(context); + impls_[0]->SetBkclContext(context); } -void XPUContext::CreateStream() { impl_->CreateStream(); } +void XPUContext::CreateStream(int i) { impls_[i]->CreateStream(); } -void XPUContext::Init() { impl_->Init(); } +void XPUContext::Init() { impls_[0]->Init(); } } // namespace phi diff --git a/paddle/phi/backends/xpu/xpu_context.h b/paddle/phi/backends/xpu/xpu_context.h index 3e734a064b916..59dfb0c137832 100644 --- a/paddle/phi/backends/xpu/xpu_context.h +++ b/paddle/phi/backends/xpu/xpu_context.h @@ -17,6 +17,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU #include +#include #include "paddle/phi/backends/xpu/forwards.h" #include "paddle/phi/backends/xpu/xpu_header.h" @@ -45,15 +46,15 @@ class XPUContext : public DeviceContext, backends::xpu::XPUVersion xpu_version() const; - xpu::Context* x_context() const; + xpu::Context* x_context(int i = 0) const; // Return bkcl context. xpu::BKCLContext_t bkcl_context() const; void SetBkclContext(xpu::BKCLContext_t context); - void CreateStream(); + void CreateStream(int i = 0); // For share external stream. - void SetStream(void* stream); + void SetStream(void* stream, int i = 0); // Wait for all operations completion in the stream. void Wait() const override; @@ -68,9 +69,9 @@ class XPUContext : public DeviceContext, // NOTE: External users manage resources. Used in inference scenarios. // The Set interface is for inference only, DeviceContext will mark the // resource as external, and will not delete any resource when destructing. - void SetXContext(xpu::Context*); + void SetXContext(xpu::Context*, int i = 0); - void SetL3Cache(int l3_size = 14155776); + void SetL3Cache(int64_t l3_size = 1024, int i = 0); void SetXpuVersion(int version); @@ -80,13 +81,13 @@ class XPUContext : public DeviceContext, Eigen::DefaultDevice* eigen_device() const { return nullptr; } - XPUStream stream() const; + XPUStream stream(int i = 0) const; static const char* name() { return "XPUContext"; } private: struct Impl; - std::unique_ptr impl_; + std::vector> impls_; }; // KPS (Kernel PrimitiveS API) needs to exist as a kind of backend, diff --git a/paddle/phi/backends/xpu/xpu_l3_strategy.cc b/paddle/phi/backends/xpu/xpu_l3_strategy.cc index eab256a3edaa1..a117a9b88beaf 100644 --- a/paddle/phi/backends/xpu/xpu_l3_strategy.cc +++ b/paddle/phi/backends/xpu/xpu_l3_strategy.cc @@ -14,12 +14,14 @@ limitations under the License. 
*/ #include "paddle/phi/backends/xpu/xpu_l3_strategy.h" #include "glog/logging.h" +#include "paddle/phi/backends/xpu/enforce_xpu.h" namespace phi { void XPUL3CacheBlock::Set(void* addr, size_t size) { if (addr == nullptr || size == 0) { - LOG(FATAL) << "Set XPUL3CacheBlock Size as Zero"; + PADDLE_THROW( + phi::errors::InvalidArgument("Set XPUL3CacheBlock Size as Zero")); } addr_ = addr; size_ = size; diff --git a/paddle/phi/capi/include/c_meta_tensor.h b/paddle/phi/capi/include/c_meta_tensor.h index 08f01084c6abf..f4c9a541e526a 100644 --- a/paddle/phi/capi/include/c_meta_tensor.h +++ b/paddle/phi/capi/include/c_meta_tensor.h @@ -39,6 +39,13 @@ int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor, size_t index, PD_Status *status); +int64_t PD_MetaTensorGetNumStrides(const PD_MetaTensor *tensor, + PD_Status *status); + +int64_t PD_MetaTensorGetStride(const PD_MetaTensor *tensor, + size_t index, + PD_Status *status); + bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status); void PD_MetaTensorSetDims(PD_MetaTensor *tensor, @@ -46,6 +53,11 @@ void PD_MetaTensorSetDims(PD_MetaTensor *tensor, const int64_t *dims, PD_Status *status); +void PD_MetaTensorSetStrides(PD_MetaTensor *tensor, + int64_t nstrides, + const int64_t *strides, + PD_Status *status); + void PD_MetaTensorSetDataType(PD_MetaTensor *tensor, PD_DataType dtype, PD_Status *status); diff --git a/paddle/phi/capi/include/c_tensor.h b/paddle/phi/capi/include/c_tensor.h index c4f706c70ccfb..2df292c6b946b 100644 --- a/paddle/phi/capi/include/c_tensor.h +++ b/paddle/phi/capi/include/c_tensor.h @@ -41,6 +41,12 @@ int64_t PD_TensorGetDim(const PD_Tensor *tensor, size_t index, PD_Status *status); +int64_t PD_TensorGetNumStrides(const PD_Tensor *tensor, PD_Status *status); + +int64_t PD_TensorGetStride(const PD_Tensor *tensor, + size_t index, + PD_Status *status); + void PD_TensorGetLoD(const PD_Tensor *tensor, PD_List *data, PD_List *offset, @@ -52,11 +58,22 @@ bool PD_TensorIsValid(const PD_Tensor *tensor, PD_Status *status); void *PD_TensorGetHolder(const PD_Tensor *tensor, PD_Status *status); +size_t PD_TensorGetOffset(const PD_Tensor *tensor, PD_Status *status); + void PD_TensorSetDims(PD_Tensor *tensor, int64_t ndims, const int64_t *dims, PD_Status *status); +void PD_TensorSetOffset(PD_Tensor *tensor, + const int64_t offset, + PD_Status *status); + +void PD_TensorSetStrides(PD_Tensor *tensor, + int64_t nstrides, + const int64_t *strides, + PD_Status *status); + void PD_TensorSetDataType(PD_Tensor *tensor, PD_DataType dtype, PD_Status *status); diff --git a/paddle/phi/capi/include/wrapper_base.h b/paddle/phi/capi/include/wrapper_base.h index 061561008a95e..75f3e2d9e350e 100644 --- a/paddle/phi/capi/include/wrapper_base.h +++ b/paddle/phi/capi/include/wrapper_base.h @@ -72,6 +72,19 @@ inline std::vector PD_TensorGetDims(PD_Tensor* tensor, return std::vector(); } +inline std::vector PD_TensorGetStrides(PD_Tensor* tensor, + PD_Status* status) { + int64_t nstrides = PD_TensorGetNumStrides(tensor, status); + if (nstrides > 0) { + std::vector shape(nstrides); + for (int64_t i = 0; i < nstrides; ++i) { + shape[i] = PD_TensorGetStride(tensor, i, status); + } + return shape; + } + return std::vector(); +} + inline std::vector PD_MetaTensorGetDims(PD_MetaTensor* tensor, PD_Status* status) { int64_t ndims = PD_MetaTensorGetNumDims(tensor, status); @@ -85,6 +98,19 @@ inline std::vector PD_MetaTensorGetDims(PD_MetaTensor* tensor, return std::vector(); } +inline std::vector PD_MetaTensorGetStrides(PD_MetaTensor* tensor, + PD_Status* 
status) { + int64_t nstrides = PD_MetaTensorGetNumStrides(tensor, status); + if (nstrides > 0) { + std::vector shape(nstrides); + for (int64_t i = 0; i < nstrides; ++i) { + shape[i] = PD_MetaTensorGetStride(tensor, i, status); + } + return shape; + } + return std::vector(); +} + template class WrapperBase { public: @@ -134,6 +160,13 @@ class DenseTensor : public WrapperBase { return holder; } + size_t offset() const { + C_Status status; + auto offset = PD_TensorGetOffset(raw_data(), &status); + PD_CHECK_STATUS(status); + return offset; + } + std::vector dims() const { C_Status status; auto dimension = PD_TensorGetDims(raw_data(), &status); @@ -141,6 +174,13 @@ class DenseTensor : public WrapperBase { return dimension; } + std::vector strides() const { + C_Status status; + auto strides = PD_TensorGetStrides(raw_data(), &status); + PD_CHECK_STATUS(status); + return strides; + } + PD_DataType dtype() const { C_Status status; auto data_type = PD_TensorGetPDDataType(raw_data(), &status); @@ -207,6 +247,18 @@ class DenseTensor : public WrapperBase { PD_CHECK_STATUS(status); } + void set_offset(const int64_t& offset) { + C_Status status; + PD_TensorSetOffset(raw_data(), offset, &status); + PD_CHECK_STATUS(status); + } + + void set_strides(const std::vector& strides) { + C_Status status; + PD_TensorSetStrides(raw_data(), strides.size(), strides.data(), &status); + PD_CHECK_STATUS(status); + } + void set_dtype(PD_DataType data_type) { C_Status status; PD_TensorSetDataType(raw_data(), data_type, &status); @@ -513,6 +565,13 @@ class MetaTensor : WrapperBase { return dimension; } + std::vector strides() const { + C_Status status; + auto strides = PD_MetaTensorGetStrides(raw_data(), &status); + PD_CHECK_STATUS(status); + return strides; + } + PD_DataType dtype() const { C_Status status; auto data_type = PD_MetaTensorGetPDDataType(raw_data(), &status); @@ -540,6 +599,13 @@ class MetaTensor : WrapperBase { PD_CHECK_STATUS(status); } + void set_strides(const std::vector& strides) { + C_Status status; + PD_MetaTensorSetStrides( + raw_data(), strides.size(), strides.data(), &status); + PD_CHECK_STATUS(status); + } + void set_dtype(PD_DataType data_type) { C_Status status; PD_MetaTensorSetDataType(raw_data(), data_type, &status); diff --git a/paddle/phi/capi/lib/c_meta_tensor.cc b/paddle/phi/capi/lib/c_meta_tensor.cc index 6ea6eda1a7f23..f436ba9d3cde0 100644 --- a/paddle/phi/capi/lib/c_meta_tensor.cc +++ b/paddle/phi/capi/lib/c_meta_tensor.cc @@ -88,6 +88,36 @@ int64_t PD_MetaTensorGetDim(const PD_MetaTensor *tensor, return cc_tensor->dims()[index]; } +int64_t PD_MetaTensorGetNumStrides(const PD_MetaTensor *tensor, + PD_Status *status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return 0; + } + *status = C_SUCCESS; + } + + auto cc_tensor = reinterpret_cast(tensor); + return cc_tensor->strides().size(); +} + +int64_t PD_MetaTensorGetStride(const PD_MetaTensor *tensor, + size_t index, + PD_Status *status) { + auto cc_tensor = reinterpret_cast(tensor); + + if (status) { + if (!tensor || index >= static_cast(cc_tensor->strides().size())) { + *status = C_FAILED; + return 0; + } + *status = C_SUCCESS; + } + + return cc_tensor->strides()[index]; +} + bool PD_MetaTensorIsValid(const PD_MetaTensor *tensor, PD_Status *status) { if (status) { if (!tensor) { @@ -117,6 +147,22 @@ void PD_MetaTensorSetDims(PD_MetaTensor *tensor, cc_tensor->set_dims(common::make_ddim(shape)); } +void PD_MetaTensorSetStrides(PD_MetaTensor *tensor, + int64_t nstrides, + const int64_t *strides, + PD_Status *status) { + if 
(status) { + if (!tensor) { + *status = C_FAILED; + return; + } + *status = C_SUCCESS; + } + auto cc_tensor = reinterpret_cast(tensor); + std::vector shape(strides, strides + nstrides); + cc_tensor->set_strides(common::make_ddim(shape)); +} + void PD_MetaTensorSetDataType(PD_MetaTensor *tensor, PD_DataType dtype, PD_Status *status) { diff --git a/paddle/phi/capi/lib/c_tensor.cc b/paddle/phi/capi/lib/c_tensor.cc index 31a724447b7c7..eb8c8c6f4eb47 100644 --- a/paddle/phi/capi/lib/c_tensor.cc +++ b/paddle/phi/capi/lib/c_tensor.cc @@ -111,6 +111,35 @@ int64_t PD_TensorGetDim(const PD_Tensor* tensor, return cc_tensor->dims()[index]; } +int64_t PD_TensorGetNumStrides(const PD_Tensor* tensor, PD_Status* status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return 0; + } + *status = C_SUCCESS; + } + + auto cc_tensor = reinterpret_cast(tensor); + return cc_tensor->strides().size(); +} + +int64_t PD_TensorGetStride(const PD_Tensor* tensor, + size_t index, + PD_Status* status) { + auto cc_tensor = reinterpret_cast(tensor); + + if (status) { + if (!tensor || index >= static_cast(cc_tensor->strides().size())) { + *status = C_FAILED; + return 0; + } + *status = C_SUCCESS; + } + + return cc_tensor->strides()[index]; +} + void PD_TensorGetLoD(const PD_Tensor* tensor, PD_List* data, PD_List* offset, @@ -185,6 +214,19 @@ void* PD_TensorGetHolder(const PD_Tensor* tensor, PD_Status* status) { return cc_tensor->Holder().get(); } +size_t PD_TensorGetOffset(const PD_Tensor* tensor, PD_Status* status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return 0; + } + *status = C_SUCCESS; + } + + auto cc_tensor = reinterpret_cast(tensor); + return cc_tensor->offset(); +} + void PD_TensorSetDims(PD_Tensor* tensor, int64_t ndims, const int64_t* dims, @@ -201,6 +243,36 @@ void PD_TensorSetDims(PD_Tensor* tensor, cc_tensor->Resize(common::make_ddim(shape)); } +void PD_TensorSetOffset(PD_Tensor* tensor, + const int64_t offset, + PD_Status* status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return; + } + *status = C_SUCCESS; + } + auto cc_tensor = reinterpret_cast(tensor); + cc_tensor->set_offset(offset); +} + +void PD_TensorSetStrides(PD_Tensor* tensor, + int64_t nstrides, + const int64_t* strides, + PD_Status* status) { + if (status) { + if (!tensor) { + *status = C_FAILED; + return; + } + *status = C_SUCCESS; + } + auto cc_tensor = reinterpret_cast(tensor); + std::vector shape(strides, strides + nstrides); + cc_tensor->set_strides(common::make_ddim(shape)); +} + void PD_TensorSetDataType(PD_Tensor* tensor, PD_DataType dtype, PD_Status* status) { diff --git a/paddle/phi/common/CMakeLists.txt b/paddle/phi/common/CMakeLists.txt index 5fe96a2a682fb..d4c02b69ce9f2 100644 --- a/paddle/phi/common/CMakeLists.txt +++ b/paddle/phi/common/CMakeLists.txt @@ -1 +1,8 @@ -collect_srcs(common_srcs SRCS place.cc scalar.cc int_array.cc memory_utils.cc) +collect_srcs( + common_srcs + SRCS + place.cc + scalar.cc + int_array.cc + memory_utils.cc + port.cc) diff --git a/paddle/phi/common/place.h b/paddle/phi/common/place.h index 2d32297e74903..9d68821af1d6b 100644 --- a/paddle/phi/common/place.h +++ b/paddle/phi/common/place.h @@ -136,7 +136,6 @@ class GPUPlace : public Place { GPUPlace() : Place(AllocationType::GPU, 0) {} explicit GPUPlace(int device_id) : Place(AllocationType::GPU, device_id) {} - GPUPlace(const GPUPlace&) = default; GPUPlace(const Place& place) // NOLINT : Place(AllocationType::GPU, place.GetDeviceId()) {} }; diff --git a/paddle/phi/backends/dynload/port.cc 
b/paddle/phi/common/port.cc similarity index 98% rename from paddle/phi/backends/dynload/port.cc rename to paddle/phi/common/port.cc index bcda44a745360..8c94232260aef 100644 --- a/paddle/phi/backends/dynload/port.cc +++ b/paddle/phi/common/port.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include #include #include diff --git a/paddle/phi/backends/dynload/port.h b/paddle/phi/common/port.h similarity index 94% rename from paddle/phi/backends/dynload/port.h rename to paddle/phi/common/port.h index 03a2863e4dc4e..a56479e7a471a 100644 --- a/paddle/phi/backends/dynload/port.h +++ b/paddle/phi/common/port.h @@ -15,6 +15,7 @@ #pragma once #include +#include "paddle/utils/test_macros.h" #define GLOG_NO_ABBREVIATED_SEVERITIES // msvc conflict logging with windows.h @@ -38,7 +39,7 @@ #define S_ISDIR(mode) (((mode)&S_IFMT) == S_IFDIR) #endif // S_ISDIR -void *dlsym(void *handle, const char *symbol_name); +TEST_API void *dlsym(void *handle, const char *symbol_name); void *dlopen(const char *filename, int flag); diff --git a/paddle/phi/common/scalar.h b/paddle/phi/common/scalar.h index 12de9149a96af..e97f918b0f6a5 100644 --- a/paddle/phi/common/scalar.h +++ b/paddle/phi/common/scalar.h @@ -226,6 +226,44 @@ class ScalarBase { return !operator==(other); } + ScalarBase operator-() const { + DataType data_type = this->dtype(); + switch (data_type) { + case DataType::BOOL: + return ScalarBase(-(this->data_.b)); + case DataType::INT8: + return ScalarBase(-(this->data_.i8)); + case DataType::UINT8: + return ScalarBase(-(this->data_.ui8)); + case DataType::INT16: + return ScalarBase(-(this->data_.i16)); + case DataType::UINT16: + return ScalarBase(-(this->data_.ui16)); + case DataType::INT32: + return ScalarBase(-(this->data_.i32)); + case DataType::UINT32: + return ScalarBase(-(this->data_.ui32)); + case DataType::INT64: + return ScalarBase(-(this->data_.i64)); + case DataType::UINT64: + return ScalarBase(-(this->data_.ui64)); + case DataType::FLOAT16: + return ScalarBase(-(this->data_.f16)); + case DataType::BFLOAT16: + return ScalarBase(-(this->data_.bf16)); + case DataType::FLOAT32: + return ScalarBase(-(this->data_.f32)); + case DataType::FLOAT64: + return ScalarBase(-(this->data_.f64)); + case DataType::COMPLEX64: + return ScalarBase(-(this->data_.c64)); + case DataType::COMPLEX128: + return ScalarBase(-(this->data_.c128)); + default: + PD_THROW("Invalid tensor data type `", dtype_, "`."); + } + } + std::string ToRawString() const { std::stringstream ss; switch (dtype_) { @@ -356,9 +394,9 @@ void CopyScalar(const ScalarBase& src, ScalarBase* dst) { } using Scalar = paddle::experimental::ScalarBase; -bool operator==(const Scalar& lhs, const Scalar& rhs); +TEST_API bool operator==(const Scalar& lhs, const Scalar& rhs); -std::ostream& operator<<(std::ostream& os, const Scalar& s); +TEST_API std::ostream& operator<<(std::ostream& os, const Scalar& s); template std::vector ExtractPlainVector( diff --git a/paddle/phi/core/compat/convert_utils.cc b/paddle/phi/core/compat/convert_utils.cc index d4c5de0dbe6dc..37053cc0c09ec 100644 --- a/paddle/phi/core/compat/convert_utils.cc +++ b/paddle/phi/core/compat/convert_utils.cc @@ -63,6 +63,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) { return phi::Place(); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) case phi::Backend::GPU: + case phi::Backend::GPUDNN: return phi::GPUPlace( set_device_id ? 
phi::backends::gpu::GetCurrentDeviceId() : 0); #endif @@ -70,11 +71,6 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) { case phi::Backend::ONEDNN: // NOLINT return phi::CPUPlace(); #endif -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - case phi::Backend::GPUDNN: - return phi::GPUPlace( - set_device_id ? phi::backends::gpu::GetCurrentDeviceId() : 0); -#endif #if defined(PADDLE_WITH_XPU) case phi::Backend::XPU: return phi::XPUPlace( diff --git a/paddle/phi/core/compat/convert_utils.h b/paddle/phi/core/compat/convert_utils.h index 632b7a6d17ef2..320338fbc8edd 100644 --- a/paddle/phi/core/compat/convert_utils.h +++ b/paddle/phi/core/compat/convert_utils.h @@ -29,7 +29,7 @@ namespace phi { const std::string& TransToPhiKernelName(const std::string& fluid_op_name); const std::string& TransToFluidOpName(const std::string& phi_kernel_name); -Backend TransToPhiBackend(const phi::Place& place); +TEST_API Backend TransToPhiBackend(const phi::Place& place); phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id = true); #ifdef PADDLE_WITH_DNNL diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index b2c334d89023d..12a419e5d6fcc 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -29,11 +29,6 @@ namespace phi { const static std::string deprecated_kernel_name = "deprecated"; // NOLINT -const std::unordered_set standard_kernel_suffixs({ - "sr", // SelectedRows kernel - "raw" // fallback kernel of original fluid op -}); - /** * Some fluid ops are no longer used under the corresponding official API * system of 2.0. These names need to correspond to the official API names diff --git a/paddle/phi/core/cuda_stream.h b/paddle/phi/core/cuda_stream.h index b27770b081433..b6900cdabf2b3 100644 --- a/paddle/phi/core/cuda_stream.h +++ b/paddle/phi/core/cuda_stream.h @@ -155,7 +155,7 @@ class CUDAStream { private: Place place_; Stream stream_; - bool owned_{false}; // whether the stream is created and onwed by self + bool owned_{false}; // whether the stream is created and owned by self }; } // namespace phi diff --git a/paddle/phi/core/custom_kernel.cc b/paddle/phi/core/custom_kernel.cc index bc737fa398baf..3f694518d2dcc 100644 --- a/paddle/phi/core/custom_kernel.cc +++ b/paddle/phi/core/custom_kernel.cc @@ -55,12 +55,12 @@ void CustomKernelMap::RegisterCustomKernels() { kernels[pair.first][info_pair.first] = info_pair.second; - VLOG(3) << "Successed in registering kernel [" << pair.first << ":" + VLOG(3) << "Succeed in registering kernel [" << pair.first << ":" << info_pair.first << "] to Paddle. 
It will be used like native ones."; } } - LOG(INFO) << "Successed in loading " << kernels_.size() + LOG(INFO) << "Succeed in loading " << kernels_.size() << " custom kernel(s) from loaded lib(s), will be " << "used like native ones."; kernels_.clear(); diff --git a/paddle/phi/core/dense_tensor.cc b/paddle/phi/core/dense_tensor.cc index d15cc4eeafda1..dbadf69cc8cdf 100644 --- a/paddle/phi/core/dense_tensor.cc +++ b/paddle/phi/core/dense_tensor.cc @@ -53,11 +53,10 @@ DenseTensor::DenseTensor(const std::shared_ptr& holder, const DenseTensorMeta& meta) : meta_(meta), holder_(holder) {} -DenseTensor::DenseTensor(const DenseTensor& other) { +DenseTensor::DenseTensor(const DenseTensor& other) { // NOLINT this->meta_ = other.meta(); holder_ = other.holder_; - storage_properties_ = - std::move(CopyStorageProperties(other.storage_properties_)); + storage_properties_ = CopyStorageProperties(other.storage_properties_); inplace_version_counter_ = other.inplace_version_counter_; } @@ -67,8 +66,7 @@ DenseTensor& DenseTensor::operator=(const DenseTensor& other) { } meta_ = other.meta(); holder_ = other.holder_; - storage_properties_ = - std::move(CopyStorageProperties(other.storage_properties_)); + storage_properties_ = CopyStorageProperties(other.storage_properties_); inplace_version_counter_ = other.inplace_version_counter_; return *this; } diff --git a/paddle/phi/core/dense_tensor_impl.cc b/paddle/phi/core/dense_tensor_impl.cc index 97d50dd8179a4..366949a5ec64b 100644 --- a/paddle/phi/core/dense_tensor_impl.cc +++ b/paddle/phi/core/dense_tensor_impl.cc @@ -415,16 +415,14 @@ DenseTensor& DenseTensor::ShareDataWith(const DenseTensor& src) { meta_.offset = src.meta_.offset; meta_.use_gpudnn = src.meta_.use_gpudnn; meta_.strides = src.meta_.strides; - storage_properties_ = - std::move(CopyStorageProperties(src.storage_properties_)); + storage_properties_ = CopyStorageProperties(src.storage_properties_); return *this; } DenseTensor& DenseTensor::ShareDataNoCheckWith(const DenseTensor& src) { holder_ = src.holder_; set_meta(src.meta()); - storage_properties_ = - std::move(CopyStorageProperties(src.storage_properties_)); + storage_properties_ = CopyStorageProperties(src.storage_properties_); return *this; } diff --git a/paddle/phi/core/device_context.cc b/paddle/phi/core/device_context.cc index 3804802e84260..6cf80c350cd04 100644 --- a/paddle/phi/core/device_context.cc +++ b/paddle/phi/core/device_context.cc @@ -14,8 +14,10 @@ #include "paddle/phi/core/device_context.h" -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) #include "paddle/phi/backends/gpu/cuda/cuda_graph.h" +#elif defined(PADDLE_WITH_HIP) +#include "paddle/phi/backends/gpu/rocm/hip_graph.h" #endif #include "paddle/phi/core/dense_tensor.h" @@ -70,7 +72,7 @@ struct DeviceContext::Impl { pinned_allocator_ = allocator; } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void SetCUDAGraphAllocator(const Allocator* allocator) { // NOTE (Yuang): cuda graph allocator can be set to nullptr, so don't check // validation of the allocator here @@ -163,7 +165,7 @@ struct DeviceContext::Impl { (fake_alloc || tensor->numel() == 0) && requested_size == 0 ? zero_allocator_ : (pinned ? 
pinned_allocator_ : device_allocator_); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) bool must_cuda_graph_allocator = (!fake_alloc && tensor->numel() != 0) && !pinned; if (must_cuda_graph_allocator && @@ -289,7 +291,7 @@ struct DeviceContext::Impl { const Allocator* zero_allocator_{nullptr}; const Allocator* host_zero_allocator_{nullptr}; const Allocator* pinned_allocator_{nullptr}; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) const Allocator* cuda_graph_allocator_{nullptr}; #endif Generator* device_generator_{nullptr}; @@ -309,7 +311,7 @@ DeviceContext::DeviceContext(const DeviceContext& other) { impl_->SetPinnedAllocator(&other.GetPinnedAllocator()); impl_->SetHostGenerator(other.GetHostGenerator()); impl_->SetGenerator(other.GetGenerator()); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (other.IsCUDAGraphAllocatorValid()) { impl_->SetCUDAGraphAllocator(&other.GetCUDAGraphAllocator()); } @@ -340,7 +342,7 @@ const Allocator& DeviceContext::GetHostAllocator() const { return impl_->GetHostAllocator(); } -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) void DeviceContext::SetCUDAGraphAllocator(const Allocator* allocator) { impl_->SetCUDAGraphAllocator(allocator); } @@ -415,7 +417,7 @@ T* DeviceContext::HostAlloc(TensorBase* tensor, size_t requested_size) const { } #define DEVICE_CONTEXT_MEMBER_FUNC_INSTANTIATION(dtype) \ - template dtype* DeviceContext::Alloc( \ + template TEST_API dtype* DeviceContext::Alloc( \ TensorBase* tensor, size_t requested_size, bool pinned) const; \ template dtype* DeviceContext::HostAlloc(TensorBase* tensor, \ size_t requested_size) const; diff --git a/paddle/phi/core/device_context.h b/paddle/phi/core/device_context.h index b2b9e79725d85..9ead0e2c32b23 100644 --- a/paddle/phi/core/device_context.h +++ b/paddle/phi/core/device_context.h @@ -115,7 +115,7 @@ class PADDLE_API DeviceContext { const Allocator& GetPinnedAllocator() const; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) /** * @brief Set the CUDA graph Allocator object. * @@ -152,9 +152,9 @@ class PADDLE_API DeviceContext { bool fake_alloc = false) const; template - T* Alloc(TensorBase* tensor, - size_t requested_size = 0, - bool pinned = false) const; + TEST_API T* Alloc(TensorBase* tensor, + size_t requested_size = 0, + bool pinned = false) const; /** * @brief Allocate host memory for tensor. diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index 0e6ab882910a2..f45052ece6632 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -304,5 +304,11 @@ void* DistTensor::AllocateFrom(Allocator* allocator, return nullptr; } +void DistTensor::clear() { + if (value_) { + value_->clear(); + } +} + } // namespace distributed } // namespace phi diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h index bf5b083aa6e6f..8ad8cfb437f39 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h @@ -79,7 +79,7 @@ class DistTensor final const Placements& placements); /// \brief Construct a empty dist tensor (for infer spmd) - /// \param dims The global dimension of the currnet Tensor. 
+ /// \param dims The global dimension of the current Tensor. /// \param dist_attr The distributed attributes of the current tensor. DistTensor(const DDim& dims, const TensorDistAttr& dist_attr); @@ -178,6 +178,8 @@ class DistTensor final size_t requested_size = 0, bool fake_alloc = false) override; + void clear(); + private: friend class ReshardFunction; diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h index 71395507a0951..d2c22bcd08db0 100644 --- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h @@ -107,7 +107,7 @@ struct InferSpmdFnImpl { } }; - // for vecotr slot + // for vector slot template struct InferSpmdFnCallHelper&, Tail...> { diff --git a/paddle/phi/core/distributed/auto_parallel/proto_helper.cc b/paddle/phi/core/distributed/auto_parallel/proto_helper.cc index e8e4197a63c08..fad63c15d63bd 100644 --- a/paddle/phi/core/distributed/auto_parallel/proto_helper.cc +++ b/paddle/phi/core/distributed/auto_parallel/proto_helper.cc @@ -35,8 +35,8 @@ auto_parallel::ProcessMeshProto to_proto(const ProcessMesh& process_mesh) { } auto_parallel::DeviceCapabilityProto to_proto( - const auto_parallel::DeviceCapability& device_capibilty) { - TO_PROTO_HELPER(device_capibilty, auto_parallel::DeviceCapabilityProto); + const auto_parallel::DeviceCapability& device_capability) { + TO_PROTO_HELPER(device_capability, auto_parallel::DeviceCapabilityProto); } auto_parallel::DeviceProto to_proto(const auto_parallel::Device& device) { @@ -44,8 +44,8 @@ auto_parallel::DeviceProto to_proto(const auto_parallel::Device& device) { } auto_parallel::LinkCapabilityProto to_proto( - const auto_parallel::LinkCapability& link_capibilty) { - TO_PROTO_HELPER(link_capibilty, auto_parallel::LinkCapabilityProto); + const auto_parallel::LinkCapability& link_capability) { + TO_PROTO_HELPER(link_capability, auto_parallel::LinkCapabilityProto); } auto_parallel::LinkProto to_proto(const auto_parallel::Link& link) { diff --git a/paddle/phi/core/distributed/auto_parallel/proto_helper.h b/paddle/phi/core/distributed/auto_parallel/proto_helper.h index 66bdf2af74406..840c0eb95f89e 100644 --- a/paddle/phi/core/distributed/auto_parallel/proto_helper.h +++ b/paddle/phi/core/distributed/auto_parallel/proto_helper.h @@ -30,10 +30,10 @@ auto_parallel::TensorDistAttrProto to_proto(const TensorDistAttr& dist_attr); auto_parallel::ProcessMeshProto to_proto(const ProcessMesh& dist_attr); auto_parallel::DeviceCapabilityProto to_proto( - const auto_parallel::DeviceCapability& device_capibilty); + const auto_parallel::DeviceCapability& device_capability); auto_parallel::DeviceProto to_proto(const auto_parallel::Device& device); auto_parallel::LinkCapabilityProto to_proto( - const auto_parallel::LinkCapability& link_capibilty); + const auto_parallel::LinkCapability& link_capability); auto_parallel::LinkProto to_proto(const auto_parallel::Link& link); auto_parallel::DeviceMeshProto to_proto(const auto_parallel::DeviceMesh& link); auto_parallel::DistributedMapperProto to_proto( diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc index b7a6679590e63..222e918ae540b 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/nd_mesh_reshard_function.cc @@ -40,9 +40,11 @@ ProcessMesh 
GetSubProcessMesh(const ProcessMesh& mesh, int64_t axis) { std::vector process_ids; for (int64_t i = 0; i < shape_of_axis; ++i) { coord[axis] = i; - int64_t rank = coord.back(); - for (int64_t j = static_cast(coord.size() - 2); j >= 0; --j) { - rank += coord[j] * mesh.dim_size(j + 1); + int64_t rank = 0; + int64_t degree = 1; + for (int64_t j = static_cast(coord.size() - 1); j >= 0; --j) { + rank += coord[j] * degree; + degree *= mesh.dim_size(j); } process_ids.emplace_back(mesh.process_ids()[rank]); } @@ -228,7 +230,7 @@ void SameNdMeshReshardFunction::Eval(phi::DeviceContext* dev_ctx, bool is_partial = in_partial_status.count(out_mesh_axis) != 0; VLOG(3) << "Step4: out_mesh axis : " << out_mesh_axis - << "; paratial state :" << is_partial; + << "; partial state :" << is_partial; // 4.1 Calculate the dist_attr after this transform TensorDistAttr real_out_dist_attr(out->dist_attr()); std::vector real_dims_mapping = diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc index 0acf5abf3eec8..c55bf91083ef8 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/p_to_s_reshard_function.cc @@ -20,7 +20,10 @@ #include "paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h" #include "paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.h" #include "paddle/phi/core/distributed/store/store_utils.h" +#include "paddle/phi/kernels/concat_kernel.h" +#include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/reduce_scatter_kernel.h" +#include "paddle/phi/kernels/split_kernel.h" #include "paddle/phi/kernels/transpose_kernel.h" namespace phi { @@ -43,51 +46,132 @@ bool PToSReshardFunction::IsSuitable(const DistTensor& in, return true; } -void PToSReshardFunction::Eval(DeviceContext* dev_ctx, - const DistTensor& in, - const TensorDistAttr& out_dist_attr, - DistTensor* out) { - VLOG(3) << "Call " << Name(); - const auto& in_dist_attr = in.dist_attr(); - const auto& in_process_mesh = in_dist_attr.process_mesh(); - const auto& in_process_ids = in_process_mesh.process_ids(); - auto dtype = in.dtype(); - const auto& logical_ddim = in.dims(); - - int out_split_axis = - GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first; - +void ReshardPToSWithPadding(DeviceContext* dev_ctx, + int64_t split_axis, + const std::vector& process_ids, + const DenseTensor& in, + int64_t padding_nums, + DenseTensor* out) { DenseTensor in_reduce_scatter; std::vector axis; - if (out_split_axis != 0) { + const auto& logical_ddim = in.dims(); + auto dtype = in.dtype(); + + if (split_axis != 0) { for (size_t i = 0; i < common::vectorize(logical_ddim).size(); ++i) { axis.emplace_back(i); } - std::swap(axis[0], axis[out_split_axis]); - RESHARD_FUNCTOR( - dev_ctx, Transpose, dtype, in.value(), axis, &in_reduce_scatter); + std::swap(axis[0], axis[split_axis]); + RESHARD_FUNCTOR(dev_ctx, Transpose, dtype, in, axis, &in_reduce_scatter); } else { - in_reduce_scatter.ShareDataWith(in.value()); + in_reduce_scatter.ShareDataWith(in); } DenseTensor out_reduce_scatter; RESHARD_FUNCTOR_WITH_COMM(dev_ctx, ReduceScatter, dtype, - in_process_ids, + process_ids, in_reduce_scatter, - static_cast(in_process_ids.size()), + static_cast(process_ids.size()), &out_reduce_scatter); - if (out_split_axis != 0) { + DenseTensor out_result; + if (split_axis != 0) { + RESHARD_FUNCTOR( + dev_ctx, 
Transpose, dtype, out_reduce_scatter, axis, &out_result); + } else { + out_result.ShareDataNoCheckWith(out_reduce_scatter); + } + + int64_t cur_global_rank = GetCurGlobalRank(); + if (cur_global_rank == process_ids.back() && padding_nums != 0) { + std::vector tmp_out_vec; + IntArray tmp_sections(std::vector{ + out_result.dims()[split_axis] - padding_nums, padding_nums}); RESHARD_FUNCTOR(dev_ctx, - Transpose, + Split, dtype, - out_reduce_scatter, - axis, - GetMutableTensor(out)); + out_result, + tmp_sections, + split_axis, + &tmp_out_vec); + // TODO(liyurui): Since we can not seperate local tensor with [0, 10] shape + // and uninitialized tensor, here we use a tricky solution. + // Give local tensor which has, for example [0, 10] shape, a little + // allocation, to make it difference from uninitialized tensor in pipelline + // strategy. + if (tmp_out_vec[0].dims()[split_axis] == 0) { + tmp_out_vec[0].mutable_data(tmp_out_vec[0].place(), 4); + } + out->ShareDataNoCheckWith(tmp_out_vec[0]); + } else { + out->ShareDataNoCheckWith(out_result); + } +} + +void PToSReshardFunction::Eval(DeviceContext* dev_ctx, + const DistTensor& in, + const TensorDistAttr& out_dist_attr, + DistTensor* out) { + VLOG(3) << "Call " << Name(); + const auto& in_dist_attr = in.dist_attr(); + const auto& in_process_mesh = in_dist_attr.process_mesh(); + const auto& in_process_ids = in_process_mesh.process_ids(); + + int out_split_axis = + GetSplitAxisWithDimsMapping(out_dist_attr.dims_mapping()).begin()->first; + int64_t num_of_process = in_process_mesh.size(); + int64_t num_of_padding = in.dims()[out_split_axis] % num_of_process; + bool is_balanced_split = (num_of_padding == 0); + + if (is_balanced_split) { + VLOG(3) << "Balanced reshard from partial to shard"; + ReshardPToSWithPadding(dev_ctx, + out_split_axis, + in_process_ids, + in.value(), + /*padding_nums*/ 0, + GetMutableTensor(out)); } else { - SetValue(out, out_reduce_scatter); + VLOG(3) << "Unbalanced reshard from partial to shard"; + int64_t avg_size_on_split_axis = + (in.dims()[out_split_axis] + num_of_process - 1) / num_of_process; + int64_t padding_nums = + avg_size_on_split_axis * num_of_process - in.dims()[out_split_axis]; + + DDim concat_local_shape = in.local_dims(); + concat_local_shape[out_split_axis] = padding_nums; + IntArray concat_local_shape_int_array(concat_local_shape.Get(), + concat_local_shape.size()); + auto dtype = in.dtype(); + + DenseTensor concat_local_tensor; + RESHARD_FUNCTOR(dev_ctx, + Full, + dtype, + concat_local_shape_int_array, + 0, + &concat_local_tensor); + + DenseTensor in_local_tensor = in.value(); + std::vector concat_input_vec = {&in_local_tensor, + &concat_local_tensor}; + + DenseTensor concat_result; + RESHARD_FUNCTOR(dev_ctx, + Concat, + dtype, + concat_input_vec, + out_split_axis, + &concat_result); + + ReshardPToSWithPadding(dev_ctx, + out_split_axis, + in_process_ids, + concat_result, + padding_nums, + GetMutableTensor(out)); } SetDistProps(out, in.dims(), out_dist_attr); diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc index a2a769ef3a2d4..73a367fac273d 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc @@ -147,10 +147,12 @@ std::map GetSplitAxisWithDimsMapping( } std::vector BalancedSplit(int64_t total_nums, int64_t num_of_pieces) { - std::vector result(num_of_pieces, total_nums / num_of_pieces); - int64_t remain_nums 
= total_nums % num_of_pieces; - for (int64_t i = 0; i < remain_nums; ++i) { - result[i] += 1; + bool has_remainder = (total_nums % num_of_pieces != 0); + std::vector result(num_of_pieces, + (total_nums + num_of_pieces - 1) / num_of_pieces); + if (has_remainder) { + int64_t& last_value = result.back(); + last_value = last_value - (last_value * num_of_pieces - total_nums); } return result; } diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc index fbbcd8eebb9e5..dbfbf1df8d284 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/s_to_r_reshard_function.cc @@ -35,7 +35,7 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx, int64_t split_axis, const std::vector& process_ids, const DenseTensor& in, - int64_t num_of_padding, + int64_t padding_nums, DenseTensor* out) { int64_t num_of_process = process_ids.size(); auto dtype = in.dtype(); @@ -46,7 +46,7 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx, RESHARD_FUNCTOR_WITH_COMM( dev_ctx, AllGather, dtype, process_ids, in, num_of_process, out); - if (split_axis != 0 || num_of_padding != 0) { + if (split_axis != 0 || padding_nums != 0) { IntArray sections(std::vector(num_of_process, in.dims()[0])); std::vector split_out_vec; @@ -58,20 +58,18 @@ void ReshardSToRWithPadding(DeviceContext* dev_ctx, /*split_axis*/ 0, &split_out_vec); - if (num_of_padding != 0) { - for (int64_t i = num_of_padding; i < num_of_process; ++i) { - std::vector tmp_out_vec; - IntArray tmp_sections( - std::vector{in.dims()[split_axis] - 1, 1}); - RESHARD_FUNCTOR(dev_ctx, - Split, - dtype, - split_out_vec[i], - tmp_sections, - split_axis, - &tmp_out_vec); - split_out_vec[i] = tmp_out_vec[0]; - } + if (padding_nums != 0) { + std::vector tmp_out_vec; + IntArray tmp_sections(std::vector{ + in.dims()[split_axis] - padding_nums, padding_nums}); + RESHARD_FUNCTOR(dev_ctx, + Split, + dtype, + split_out_vec[num_of_process - 1], + tmp_sections, + split_axis, + &tmp_out_vec); + split_out_vec[num_of_process - 1] = tmp_out_vec[0]; } // Concat the result after split on correct axis. 
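The BalancedSplit change above switches from distributing the remainder over the leading pieces (10 over 4 gave {3, 3, 2, 2}) to ceiling-sized pieces with a short last piece ({3, 3, 3, 1}), which is the layout the padding logic in the S-to-R and P-to-S reshard functions assumes. A small self-contained re-implementation of that arithmetic, for illustration only and not part of the patch:

// Illustrative re-implementation (not the phi function) of the new
// BalancedSplit behaviour: every piece gets the ceiling size and the last
// piece absorbs the shortfall; padding_nums in the reshard functions is the
// amount the last rank must be padded to reach that ceiling size.
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int64_t> BalancedSplitSketch(int64_t total, int64_t pieces) {
  std::vector<int64_t> result(pieces, (total + pieces - 1) / pieces);
  if (total % pieces != 0) {
    result.back() -= result.back() * pieces - total;
  }
  return result;
}

int main() {
  // Splitting a dimension of size 10 across 4 ranks yields {3, 3, 3, 1};
  // padding_nums = ceil(10/4) * 4 - 10 = 2, so the last rank pads by 2 before
  // the collective and the padding is split off again afterwards.
  for (int64_t p : BalancedSplitSketch(10, 4)) {
    std::printf("%lld ", static_cast<long long>(p));
  }
  std::printf("\n");
  return 0;
}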
@@ -124,15 +122,19 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx, split_axis, in_process_ids, in.value(), - num_of_padding, + /*padding_nums*/ 0, GetMutableTensor(out)); } else { VLOG(3) << "Unbalanced reshard from shard to replicated"; - bool need_padding = - (in.dims()[split_axis] / num_of_process == in.local_dims()[split_axis]); + int64_t avg_size_on_split_axis = + (in.dims()[split_axis] + num_of_process - 1) / num_of_process; + int64_t padding_nums = + avg_size_on_split_axis * num_of_process - in.dims()[split_axis]; + bool need_padding = (in.local_dims()[split_axis] != avg_size_on_split_axis); + if (need_padding) { DDim concat_local_shape = in.local_dims(); - concat_local_shape[split_axis] = 1; + concat_local_shape[split_axis] = padding_nums; IntArray concat_local_shape_int_array(concat_local_shape.Get(), concat_local_shape.size()); auto dtype = in.dtype(); @@ -156,14 +158,14 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx, split_axis, in_process_ids, concat_result, - num_of_padding, + padding_nums, GetMutableTensor(out)); } else { ReshardSToRWithPadding(dev_ctx, split_axis, in_process_ids, in.value(), - num_of_padding, + padding_nums, GetMutableTensor(out)); } } @@ -173,7 +175,6 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx, bool SToRReshardFunctionCrossMesh::IsSuitable( const DistTensor& in, const TensorDistAttr& out_dist_attr) { const auto& in_dist_attr = in.dist_attr(); - const auto& in_dims_mapping = in_dist_attr.dims_mapping(); RESHARD_SHORTCUT_IF_FALSE(in_dist_attr.is_shard()); RESHARD_SHORTCUT_IF_FALSE(out_dist_attr.is_replicated()); @@ -181,16 +182,6 @@ bool SToRReshardFunctionCrossMesh::IsSuitable( const auto& in_process_mesh = in_dist_attr.process_mesh(); const auto& out_process_mesh = out_dist_attr.process_mesh(); - int64_t cur_global_rank = GetCurGlobalRank(); - if (in_process_mesh.contains(cur_global_rank)) { - int split_axis = - GetSplitAxisWithDimsMapping(in_dims_mapping).begin()->first; - int64_t num_of_process = in_process_mesh.size(); - RESHARD_SHORTCUT_IF_FALSE(in.local_dims()[static_cast(split_axis)] * - num_of_process == - in.dims()[static_cast(split_axis)]); - } - RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.ndim() == 1); RESHARD_SHORTCUT_IF_FALSE(out_process_mesh.ndim() == 1); RESHARD_SHORTCUT_IF_FALSE(in_process_mesh.shape() == diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc b/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc index 2869951addffc..0a86275203b51 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/same_status_reshard_function.cc @@ -91,7 +91,7 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx, if (src == cur_global_rank) { VLOG(3) << "Send from src " << src << " to dst " << dst; int64_t dst_local_rank = GetLocalRankInParticipate(all_process_ids, dst); - // Sice send kernel only has input, so we don't need to infermeta + // Since send kernel only has input, so we don't need to infermeta // actually. According to this reason, just use the kernel directly. 
RESHARD_FUNCTOR_WITH_COMM(dev_ctx, PSendKernel, diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 5fd7861cc52b2..9e3be85222c61 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -62,7 +62,8 @@ void CommContextManager::CreateNCCLCommContext( int rank, int size, const std::string& hash_key, - const P2POption* p2p_opt) { + const P2POption* p2p_opt, + int nccl_comm_init_option) { auto& comm_context_manager = CommContextManager::GetInstance(); if (comm_context_manager.Has(unique_comm_key)) { return; @@ -91,8 +92,8 @@ void CommContextManager::CreateNCCLCommContext( << ", unique_comm_key: " << unique_comm_key << ", unique_key: " << unique_key << ", nccl_id: " << SerializeNCCLUniqueId(nccl_id); - auto nccl_comm_context = - std::make_unique(rank, size, nccl_id); + auto nccl_comm_context = std::make_unique( + rank, size, nccl_id, nccl_comm_init_option); if (CommContextManager::device_id != -1) { std::unique_ptr dev_ctx( new phi::GPUContext(phi::GPUPlace(CommContextManager::device_id))); @@ -233,12 +234,10 @@ CommContext* CommContextManager::Get(const std::string& unique_comm_key) const { #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) int CommContextManager::GetRingId(const ncclComm_t& comm) const { - for (auto iter = id_to_comm_context_.begin(); - iter != id_to_comm_context_.end(); - ++iter) { - if (static_cast(iter->second.get()) + for (const auto& iter : id_to_comm_context_) { + if (static_cast(iter.second.get()) ->GetNcclComm() == comm) { - return std::stoi(iter->first); + return std::stoi(iter.first); } } return -1; diff --git a/paddle/phi/core/distributed/comm_context_manager.h b/paddle/phi/core/distributed/comm_context_manager.h index 8c4d802294986..9e0cb8e5ec3d7 100644 --- a/paddle/phi/core/distributed/comm_context_manager.h +++ b/paddle/phi/core/distributed/comm_context_manager.h @@ -77,7 +77,8 @@ class CommContextManager { int rank, int size, const std::string& hash_key = "", - const P2POption* opt = nullptr); + const P2POption* opt = nullptr, + int nccl_comm_init_option = 0); #endif #if defined(PADDLE_WITH_GLOO) diff --git a/paddle/phi/core/distributed/nccl_comm_context.cc b/paddle/phi/core/distributed/nccl_comm_context.cc index 8da676e74d911..bfa9a494b327a 100644 --- a/paddle/phi/core/distributed/nccl_comm_context.cc +++ b/paddle/phi/core/distributed/nccl_comm_context.cc @@ -30,10 +30,22 @@ namespace distributed { // set this flag to `true` and recompile to enable dynamic checks constexpr bool FLAGS_enable_nccl_dynamic_check = false; -NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id) +NCCLCommContext::NCCLCommContext(int rank, + int size, + ncclUniqueId nccl_id, + int nccl_comm_init_option) : CommContext(rank, size) { - NCCL_CHECK( - phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_)); + if (nccl_comm_init_option > 0 && phi::dynload::ncclCommInitRank2.IsValid()) { + LOG(WARNING) << "Creating modified qp with ncclCommInitRank2."; + NCCL_CHECK(phi::dynload::ncclCommInitRank2( + &nccl_comm_, size_, nccl_id, rank_, nccl_comm_init_option)); + } else { + if (nccl_comm_init_option > 0) { + LOG(WARNING) << "ncclCommInitRank2 is not supported."; + } + NCCL_CHECK( + phi::dynload::ncclCommInitRank(&nccl_comm_, size_, nccl_id, rank_)); + } NCCL_CHECK(phi::dynload::ncclGetVersion(&nccl_version_)); } diff --git a/paddle/phi/core/distributed/nccl_comm_context.h 
b/paddle/phi/core/distributed/nccl_comm_context.h index 609b5e0defe07..e11c9709976d3 100644 --- a/paddle/phi/core/distributed/nccl_comm_context.h +++ b/paddle/phi/core/distributed/nccl_comm_context.h @@ -39,7 +39,10 @@ namespace distributed { class NCCLCommContext final : public CommContext { public: - NCCLCommContext(int rank, int size, ncclUniqueId nccl_id); + NCCLCommContext(int rank, + int size, + ncclUniqueId nccl_id, + int nccl_comm_init_option = 0); ~NCCLCommContext() override = default; int GetNcclVersion(); diff --git a/paddle/phi/core/distributed/nccl_comm_task.cc b/paddle/phi/core/distributed/nccl_comm_task.cc index 4e2efea0068eb..9ac1c75fc204a 100644 --- a/paddle/phi/core/distributed/nccl_comm_task.cc +++ b/paddle/phi/core/distributed/nccl_comm_task.cc @@ -249,9 +249,6 @@ void NCCLCommTask::AbortComm() { } std::string NCCLCommTask::GetTraceMsg() { - auto current_timepoint = std::chrono::steady_clock::now(); - auto time_elapsed = std::chrono::duration_cast( - current_timepoint - start_time_); auto global_ranks = phi::distributed::CommContextManager::GetInstance().GetGroupRanks( group_key_); diff --git a/paddle/phi/core/distributed/nccl_comm_task.h b/paddle/phi/core/distributed/nccl_comm_task.h index fca9004cf0b2d..706ce1cf112c2 100644 --- a/paddle/phi/core/distributed/nccl_comm_task.h +++ b/paddle/phi/core/distributed/nccl_comm_task.h @@ -46,7 +46,7 @@ class NCCLCommTask : public CommTask { gpuStream_t = nullptr, CommType comm_type = CommType::UNKNOWN, int64_t timeout = DefaultTimeout); - ~NCCLCommTask() = default; + ~NCCLCommTask() override = default; // check whether the nccl kernel started bool IsStarted() override; @@ -59,8 +59,8 @@ class NCCLCommTask : public CommTask { std::string GetCommErrors() override; void AbortComm() override; - void StartRecord(); - void EndRecord(); + void StartRecord() override; + void EndRecord() override; void ClearRecord() override; bool CudaEventQuery(gpuEvent_t event); diff --git a/paddle/phi/core/distributed/nccl_tools.cc b/paddle/phi/core/distributed/nccl_tools.cc index a5388796d1f45..d79466922976a 100644 --- a/paddle/phi/core/distributed/nccl_tools.cc +++ b/paddle/phi/core/distributed/nccl_tools.cc @@ -29,17 +29,20 @@ namespace distributed { ncclRedOp_t ToNCCLRedType(ReduceOp reduction) { static const std::unordered_map red_type = { - {ReduceOp::MIN, ncclMin}, - {ReduceOp::MAX, ncclMax}, - {ReduceOp::SUM, ncclSum}, - {ReduceOp::PRODUCT, ncclProd}, + {ReduceOp::MIN, ncclMin}, + {ReduceOp::MAX, ncclMax}, + {ReduceOp::SUM, ncclSum}, + {ReduceOp::PRODUCT, ncclProd}, +#if NCCL_VERSION_CODE >= 21000 + {ReduceOp::AVG, ncclAvg}, +#endif }; auto it = red_type.find(reduction); PADDLE_ENFORCE_EQ(it != red_type.end(), true, phi::errors::InvalidArgument( "Invalid nccl reduction. 
Must be ncclMin | ncclMax | " - "ncclProd | ncclSum")); + "ncclProd | ncclSum | ncclAvg.")); return it->second; } diff --git a/paddle/phi/core/distributed/store/tcp_store.cc b/paddle/phi/core/distributed/store/tcp_store.cc index 067450de210f9..9c4d5bc7eaa6e 100644 --- a/paddle/phi/core/distributed/store/tcp_store.cc +++ b/paddle/phi/core/distributed/store/tcp_store.cc @@ -241,8 +241,12 @@ void MasterDaemon::ProcessCommands(std::vector* p_fds) { #else _sockets.erase(_sockets.begin() + i - 2); #endif - - VLOG(5) << "Meet some exceptions during run:" << ex.what(); + std::string s(ex.what()); + if (s.find("TCP connection reset by peer") != std::string::npos) { + VLOG(5) << "TCP connection reset by peer"; + } else { + VLOG(5) << "Meet some exceptions during run:" << ex.what(); + } } } } @@ -399,11 +403,11 @@ void TCPStore::waitWorkers() { std::this_thread::sleep_for(std::chrono::milliseconds(10)); if (_timeout != 0 && elapsed.count() > _timeout) { - LOG(FATAL) << paddle::string::Sprintf( + PADDLE_THROW(phi::errors::Fatal(paddle::string::Sprintf( "_timeout:%d elapsed:%d (elapsed > _timeout)=%d", _timeout, elapsed.count(), - elapsed.count() > _timeout); + elapsed.count() > _timeout))); PADDLE_ENFORCE_EQ( completed, diff --git a/paddle/phi/core/distributed/store/tcp_utils.h b/paddle/phi/core/distributed/store/tcp_utils.h index af11ad27f0425..fdc6f8d06048f 100644 --- a/paddle/phi/core/distributed/store/tcp_utils.h +++ b/paddle/phi/core/distributed/store/tcp_utils.h @@ -100,12 +100,16 @@ void receive_bytes(SocketType socket, T* buffer, size_t len) { while (to_recv > 0) { auto byte_received = ::recv(socket, ptr, to_recv, 0); - PADDLE_ENFORCE_GT( + PADDLE_ENFORCE_GE( byte_received, 0, phi::errors::InvalidArgument("TCP receive error. Details: %s.", socket_error().message())); - + if (byte_received == 0) { + PADDLE_THROW(phi::errors::InvalidArgument( + "TCP connection reset by peer. 
Details: %s.", + socket_error().message())); + } to_recv -= byte_received; ptr += byte_received; } diff --git a/paddle/phi/core/distributed/xccl_comm_context.cc b/paddle/phi/core/distributed/xccl_comm_context.cc index 3e3608e4d88a5..4dd2bcc48857c 100644 --- a/paddle/phi/core/distributed/xccl_comm_context.cc +++ b/paddle/phi/core/distributed/xccl_comm_context.cc @@ -81,7 +81,7 @@ void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor, phi::DeviceManager::CCLBroadcast(place_.GetDeviceType(), const_cast(in_tensor.data()), in_tensor.numel(), - phi::ccl::ToCCLDataType(in_tensor.dtype()), + in_tensor.dtype(), root, xccl_comm_, stream); @@ -89,7 +89,7 @@ void XCCLCommContext::Broadcast(phi::DenseTensor* out_tensor, phi::DeviceManager::CCLBroadcast(place_.GetDeviceType(), out_tensor->data(), out_tensor->numel(), - phi::ccl::ToCCLDataType(in_tensor.dtype()), + in_tensor.dtype(), root, xccl_comm_, stream); @@ -110,7 +110,7 @@ void XCCLCommContext::AllGather(phi::DenseTensor* out_tensor, const_cast(in_tensor.data()), out_tensor->data(), in_tensor.numel(), - phi::ccl::ToCCLDataType(in_tensor.dtype()), + in_tensor.dtype(), xccl_comm_, stream); } @@ -125,15 +125,14 @@ void XCCLCommContext::ReduceScatter(phi::DenseTensor* out_tensor, /*cur_rank*/ rank_, size_, phi::AllocationType::CUSTOM); - phi::DeviceManager::CCLReduceScatter( - place_.GetDeviceType(), - const_cast(in_tensor.data()), - out_tensor->data(), - out_tensor->numel(), - phi::ccl::ToCCLDataType(in_tensor.type()), - reduce_type, - xccl_comm_, - stream); + phi::DeviceManager::CCLReduceScatter(place_.GetDeviceType(), + const_cast(in_tensor.data()), + out_tensor->data(), + out_tensor->numel(), + in_tensor.dtype(), + reduce_type, + xccl_comm_, + stream); } void XCCLCommContext::Send(const phi::DenseTensor& in_tensor, @@ -145,7 +144,7 @@ void XCCLCommContext::Send(const phi::DenseTensor& in_tensor, phi::DeviceManager::CCLSend(place_.GetDeviceType(), const_cast(in_tensor.data()), count, - phi::ccl::ToCCLDataType(in_tensor.type()), + in_tensor.dtype(), peer, xccl_comm_, stream); @@ -162,7 +161,7 @@ void XCCLCommContext::Recv(phi::DenseTensor* out_tensor, phi::DeviceManager::CCLRecv(place_.GetDeviceType(), out_tensor->data(), count, - phi::ccl::ToCCLDataType(out_tensor->type()), + out_tensor->dtype(), peer, xccl_comm_, stream); @@ -184,7 +183,7 @@ void XCCLCommContext::AllReduce(phi::DenseTensor* out_tensor, const_cast(in_tensor.data()), out_tensor->data(), in_tensor.numel(), - phi::ccl::ToCCLDataType(in_tensor.type()), + in_tensor.dtype(), reduce_type, xccl_comm_, stream); @@ -205,7 +204,7 @@ void XCCLCommContext::Reduce(phi::DenseTensor* out_tensor, const_cast(in_tensor.data()), out_tensor->data(), in_tensor.numel(), - phi::ccl::ToCCLDataType(in_tensor.type()), + in_tensor.dtype(), reduce_type, root, xccl_comm_, diff --git a/paddle/phi/core/enforce.h b/paddle/phi/core/enforce.h index c74e0ea52cfd3..8ffeb74896ec6 100644 --- a/paddle/phi/core/enforce.h +++ b/paddle/phi/core/enforce.h @@ -79,41 +79,6 @@ limitations under the License. 
*/ namespace phi { namespace enforce { -namespace details { -template -inline constexpr bool IsArithmetic() { - return std::is_arithmetic::value; -} - -template -struct TypeConverterImpl { - using Type1 = typename std::common_type::type; - using Type2 = Type1; -}; - -template -struct TypeConverterImpl { - using Type1 = T1; - using Type2 = T2; -}; - -template -struct TypeConverter { - static constexpr bool kIsArithmetic = - IsArithmetic() && IsArithmetic(); - using Type1 = typename TypeConverterImpl::Type1; - using Type2 = typename TypeConverterImpl::Type2; -}; - -template -using CommonType1 = typename std::add_lvalue_reference< - typename std::add_const::Type1>::type>::type; - -template -using CommonType2 = typename std::add_lvalue_reference< - typename std::add_const::Type2>::type>::type; -} // namespace details - template std::string GetCompleteTraceBackString(StrType&& what, const char* file, @@ -131,14 +96,6 @@ inline bool is_error(bool stat) { return !stat; } void ThrowWarnInternal(const std::string& message); -#define PADDLE_THROW(...) \ - do { \ - HANDLE_THE_ERROR \ - throw ::common::enforce::EnforceNotMet( \ - ::common::ErrorSummary(__VA_ARGS__), __FILE__, __LINE__); \ - END_HANDLE_THE_ERROR \ - } while (0) - #if defined(__CUDA_ARCH__) // For cuda, the assertions can affect performance and it is therefore // recommended to disable them in production code @@ -359,7 +316,7 @@ DEFINE_EXTERNAL_API_TYPE(ncclResult_t, ncclSuccess); } // namespace details template -std::string GetExternalErrorMsg(T status); +TEST_API std::string GetExternalErrorMsg(T status); /*************** CUDA ERROR ***************/ inline bool is_error(cudaError_t e) { return e != cudaSuccess; } diff --git a/paddle/phi/core/kernel_context.h b/paddle/phi/core/kernel_context.h index b40978edf1225..947af3af1d089 100644 --- a/paddle/phi/core/kernel_context.h +++ b/paddle/phi/core/kernel_context.h @@ -114,6 +114,10 @@ class KernelContext { return paddle::none; } + const TensorBase* MutableIutputAt(size_t idx) const { + return inputs_.at(idx); + } + template TensorType* MutableOutputAt(size_t idx) { return static_cast(outputs_.at(idx)); diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 35ac9e1e0db95..32644cfe8bf63 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -30,7 +30,7 @@ PHI_DEFINE_EXPORTED_bool(use_stride_kernel, true, - "Whether to use strdie kernel if op support stride."); + "Whether to use stride kernel if op support stride."); COMMON_DECLARE_int32(low_precision_op_list); COMMON_DECLARE_bool(enable_api_kernel_fallback); @@ -177,6 +177,22 @@ bool KernelFactory::HasKernel(const std::string& kernel_name, phi::errors::NotFound("The kernel `%s` is not registered.", kernel_name)); auto kernel_iter = iter->second.find(kernel_key); + if (kernel_iter == iter->second.end() && + kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) { + phi::KernelKey any_layout_kernel_key( + kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()); + kernel_iter = iter->second.find(any_layout_kernel_key); + } + +#if defined(PADDLE_WITH_CUSTOM_DEVICE) + if (kernel_iter == iter->second.end() && + kernel_key.backend() > phi::Backend::NUM_BACKENDS) { + kernel_iter = iter->second.find({phi::Backend::CUSTOM, + phi::DataLayout::ALL_LAYOUT, + kernel_key.dtype()}); + } +#endif + if (kernel_iter == iter->second.end()) { return false; } @@ -233,6 +249,17 @@ KernelResult KernelFactory::SelectKernelOrThrowError( if (stride_kernel_iter != iter->second.end()) { 
return {stride_kernel_iter->second, false, true}; } +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (stride_kernel_iter == iter->second.end() && + const_kernel_key.backend() > phi::Backend::NUM_BACKENDS) { + stride_kernel_iter = iter->second.find({phi::Backend::CUSTOM, + phi::DataLayout::STRIDED, + const_kernel_key.dtype()}); + if (stride_kernel_iter != iter->second.end()) { + return {stride_kernel_iter->second, false, true}; + } + } +#endif } KernelKey kernel_key = KernelKey(const_kernel_key.backend(), diff --git a/paddle/phi/core/kernel_registry.cc b/paddle/phi/core/kernel_registry.cc index fa9d531b6534d..6ce1af187e9a3 100644 --- a/paddle/phi/core/kernel_registry.cc +++ b/paddle/phi/core/kernel_registry.cc @@ -47,139 +47,159 @@ void SetKernelArgsDef(const std::vector& args_type, ) { #endif // do nothing, skip context arg now - } else if (arg_type == std::type_index(typeid(const DenseTensor&))) { + } else if (arg_type == + std::type_index(typeid(const DenseTensor&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); } else if (arg_type == - std::type_index(typeid(const paddle::optional&))) { + std::type_index( + typeid(const paddle::optional&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); } else if (arg_type == - std::type_index(typeid( - const paddle::optional>&))) { + std::type_index( + typeid(const paddle::optional< + std::vector>&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); } else if (arg_type == - std::type_index(typeid(const paddle::optional&))) { + std::type_index( + typeid(const paddle::optional&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid( - const std::vector&))) { + } else if (arg_type == + std::type_index( + typeid(const std::vector&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); } else if (arg_type == - std::type_index(typeid(const phi::ExtendedTensor&))) { + std::type_index(typeid(const phi::ExtendedTensor&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid( - const std::vector&))) { + } else if (arg_type == + std::type_index(typeid( + const std::vector&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid( - const std::vector&))) { + } else if (arg_type == + std::type_index(typeid( + const std::vector&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); } else if (arg_type == - std::type_index(typeid(const std::vector&))) { + std::type_index( + typeid(const std::vector&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid( - const std::vector&))) { + } else if (arg_type == + std::type_index( + typeid(const std::vector&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(const SelectedRows&))) { + } else if (arg_type == + std::type_index(typeid(const SelectedRows&))) { // NOLINT 
args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(const StringTensor&))) { + } else if (arg_type == + std::type_index(typeid(const StringTensor&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(const SparseCooTensor&))) { + } else if (arg_type == + std::type_index(typeid(const SparseCooTensor&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid( - paddle::optional))) { + } else if (arg_type == + std::type_index(typeid( + paddle::optional))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(const SparseCsrTensor&))) { + } else if (arg_type == + std::type_index(typeid(const SparseCsrTensor&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid( - paddle::optional))) { + } else if (arg_type == + std::type_index(typeid( + paddle::optional))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(const TensorArray&))) { + } else if (arg_type == + std::type_index(typeid(const TensorArray&))) { // NOLINT args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(DenseTensor*))) { + } else if (arg_type == std::type_index(typeid(DenseTensor*))) { // NOLINT args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(std::vector))) { + } else if (arg_type == + std::type_index(typeid(std::vector))) { // NOLINT args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(SelectedRows*))) { + } else if (arg_type == std::type_index(typeid(SelectedRows*))) { // NOLINT args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(TensorArray*))) { + } else if (arg_type == std::type_index(typeid(TensorArray*))) { // NOLINT args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(SparseCooTensor*))) { + } else if (arg_type == + std::type_index(typeid(SparseCooTensor*))) { // NOLINT args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(SparseCsrTensor*))) { + } else if (arg_type == + std::type_index(typeid(SparseCsrTensor*))) { // NOLINT args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(StringTensor*))) { + } else if (arg_type == std::type_index(typeid(StringTensor*))) { // NOLINT args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); - } else if (arg_type == std::type_index(typeid(ExtendedTensor*))) { + } else if (arg_type == + std::type_index(typeid(ExtendedTensor*))) { // NOLINT 
args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index 715b4f76392d8..801a69498b4c9 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -117,8 +117,8 @@ namespace phi { static_assert(out_idx == 0, \ "Kernel's Input should appear before Outputs."); \ const std::pair& range = ctx->InputRangeAt(in_idx); \ - std::vector arg = std::move( \ - ctx->InputsBetween(range.first, range.second)); \ + std::vector arg = \ + ctx->InputsBetween(range.first, range.second); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ @@ -202,22 +202,22 @@ namespace phi { } \ } -#define PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(tensor_type) \ - template \ - struct KernelCallHelper, Tail...> { \ - template \ - static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ - const std::pair& range = ctx->OutputRangeAt(out_idx); \ - std::vector arg = std::move( \ - ctx->MutableOutputBetween(range.first, range.second)); \ - KernelCallHelper:: \ - template Compute( \ - ctx, pargs..., arg); \ - } \ +#define PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(tensor_type) \ + template \ + struct KernelCallHelper, Tail...> { \ + template \ + static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ + const std::pair& range = ctx->OutputRangeAt(out_idx); \ + std::vector arg = \ + ctx->MutableOutputBetween(range.first, range.second); \ + KernelCallHelper:: \ + template Compute( \ + ctx, pargs..., arg); \ + } \ } #define PD_SPECIALIZE_KernelCallHelper_FOR_TENSOR_SCALAR(attr_type) \ diff --git a/paddle/phi/core/lod_utils.h b/paddle/phi/core/lod_utils.h index a366f82c0ddf3..fdfe65f223827 100644 --- a/paddle/phi/core/lod_utils.h +++ b/paddle/phi/core/lod_utils.h @@ -16,6 +16,8 @@ #include #include +#include "paddle/utils/test_macros.h" + namespace phi { using LoD = std::vector>; @@ -24,7 +26,7 @@ using LoD = std::vector>; */ LoD ToAbsOffset(const LoD& in); -void AppendLoD(LoD* lod, const LoD& lod_length); +TEST_API void AppendLoD(LoD* lod, const LoD& lod_length); /* * Convert between length-based LoD and offset-based LoD. @@ -36,6 +38,6 @@ void AppendLoD(LoD* lod, const LoD& lod_length); * If offset_lod = [[0, 2, 3],[0, 3, 5, 9]] * then length_lod = [[2, 1], [3, 2, 4]] */ -LoD ConvertToLengthBasedLoD(const LoD& offset_lod); +TEST_API LoD ConvertToLengthBasedLoD(const LoD& offset_lod); } // namespace phi diff --git a/paddle/phi/core/os_info.h b/paddle/phi/core/os_info.h index eb93590669da3..1d44ecb46a29d 100644 --- a/paddle/phi/core/os_info.h +++ b/paddle/phi/core/os_info.h @@ -20,7 +20,7 @@ limitations under the License. */ #ifdef _POSIX_C_SOURCE #include #endif -#include "paddle/phi/backends/dynload/port.h" +#include "paddle/phi/common/port.h" namespace phi { @@ -54,7 +54,7 @@ ThreadId GetCurrentThreadId(); // Return the map from StdTid to ThreadId // Returns current snapshot of all threads. Make sure there is no thread -// create/destory when using it. +// create/destroy when using it. std::unordered_map GetAllThreadIds(); static constexpr const char* kDefaultThreadName = "unnamed"; @@ -63,7 +63,7 @@ std::string GetCurrentThreadName(); // Return the map from StdTid to ThreadName // Returns current snapshot of all threads. Make sure there is no thread -// create/destory when using it. +// create/destroy when using it. std::unordered_map GetAllThreadNames(); // Thread name is immutable, only the first call will succeed. 
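As a quick illustration of the offset/length LoD relationship documented in the lod_utils.h hunk above: this is a minimal standalone sketch of that conversion, not Paddle's ConvertToLengthBasedLoD implementation; the LoD alias simply mirrors the std::vector<std::vector<size_t>> declared in lod_utils.h.

#include <cstddef>
#include <vector>

using LoD = std::vector<std::vector<std::size_t>>;

// Each offset level stores prefix sums; the matching length level is the
// difference of adjacent offsets, e.g. {0, 2, 3} -> {2, 1}.
LoD OffsetToLengthLoD(const LoD& offset_lod) {
  LoD length_lod;
  for (const auto& level : offset_lod) {
    std::vector<std::size_t> lengths;
    for (std::size_t i = 1; i < level.size(); ++i) {
      lengths.push_back(level[i] - level[i - 1]);
    }
    length_lod.push_back(lengths);
  }
  return length_lod;
}

// OffsetToLengthLoD({{0, 2, 3}, {0, 3, 5, 9}}) yields {{2, 1}, {3, 2, 4}},
// matching the example in the lod_utils.h comment above.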
diff --git a/paddle/phi/core/selected_rows.h b/paddle/phi/core/selected_rows.h index 7674a8e8722bc..145f7e7d3b2e4 100644 --- a/paddle/phi/core/selected_rows.h +++ b/paddle/phi/core/selected_rows.h @@ -42,7 +42,8 @@ class SelectedRows : public TensorBase, * */ public: - SelectedRows(const std::vector& rows, const int64_t& height); + TEST_API SelectedRows(const std::vector& rows, + const int64_t& height); TEST_API SelectedRows(); diff --git a/paddle/phi/core/selected_rows_impl.cc b/paddle/phi/core/selected_rows_impl.cc index ff96342940d92..afa20cc1a46c2 100644 --- a/paddle/phi/core/selected_rows_impl.cc +++ b/paddle/phi/core/selected_rows_impl.cc @@ -188,7 +188,7 @@ void SelectedRowsImpl::Get(const phi::DenseTensor& ids, value->numel() / value->dims()[0], phi::errors::InvalidArgument( "Output tensor should have the same shape with table " - "except the first dimmension, excepted value width not counting " + "except the first dimension, excepted value width not counting " "the first dimension is %d, actual value width is %d.", value_width, value->numel() / value->dims()[0])); diff --git a/paddle/phi/core/sparse_coo_tensor.cc b/paddle/phi/core/sparse_coo_tensor.cc index dfd519250aa37..d6f41168981aa 100644 --- a/paddle/phi/core/sparse_coo_tensor.cc +++ b/paddle/phi/core/sparse_coo_tensor.cc @@ -51,7 +51,7 @@ SparseCooTensor::SparseCooTensor(DenseTensor&& non_zero_indices, meta_.dtype = non_zero_elements.dtype(); } -SparseCooTensor::SparseCooTensor(const SparseCooTensor& other) { +SparseCooTensor::SparseCooTensor(const SparseCooTensor& other) { // NOLINT this->non_zero_indices_ = other.non_zero_indices_; this->non_zero_elements_ = other.non_zero_elements_; this->coalesced_ = other.coalesced_; diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h index d0759bedcf557..61c8b0c3d2a5b 100644 --- a/paddle/phi/core/sparse_coo_tensor.h +++ b/paddle/phi/core/sparse_coo_tensor.h @@ -127,7 +127,7 @@ class SparseCooTensor : public TensorBase, /// \brief Test whether the non_zero_elements_ storage is allocated. /// In special cases, when nnz=0, non_zero_elements_ will not need to be - /// initialized, but it is neccessary to return true here, otherwise the + /// initialized, but it is necessary to return true here, otherwise the /// gradient will be None. return Whether the non_zero_elements_ storage is /// allocated. bool initialized() const override { @@ -189,7 +189,7 @@ class SparseCooTensor : public TensorBase, /// \brief get the sparse dim int32_t sparse_dim() const; - /// \brief get the dnese dim + /// \brief get the dense dim int32_t dense_dim() const; /// \brief Returns the meta information of the tensor. 
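For context on the sparse_dim/dense_dim accessors touched in the sparse_coo_tensor.h hunk above, here is a hedged, library-free sketch of how a COO layout pairs an indices matrix with the non-zero values; it mirrors only the general format, not SparseCooTensor's actual storage or API.

#include <cstdint>
#include <vector>

// COO form of the 3x3 matrix
//   [[0, 2, 0],
//    [0, 0, 3],
//    [4, 0, 0]]
// indices is sparse_dim x nnz (here 2 x 3); values holds the nnz entries.
struct CooExample {
  std::vector<std::vector<int64_t>> indices{{0, 1, 2},   // row of each non-zero
                                            {1, 2, 0}};  // col of each non-zero
  std::vector<float> values{2.0f, 3.0f, 4.0f};
};

// With dims = {3, 3} and both axes indexed, sparse_dim would be 2 and
// dense_dim 0; trailing dense axes carried inside values would count toward
// dense_dim instead.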
diff --git a/paddle/phi/core/sparse_csr_tensor.cc b/paddle/phi/core/sparse_csr_tensor.cc index 525f38cd8263d..f4373f528d217 100644 --- a/paddle/phi/core/sparse_csr_tensor.cc +++ b/paddle/phi/core/sparse_csr_tensor.cc @@ -66,7 +66,7 @@ SparseCsrTensor::SparseCsrTensor(const DenseTensor& non_zero_crows, meta_.dtype = non_zero_elements.dtype(); } -SparseCsrTensor::SparseCsrTensor(const SparseCsrTensor& other) { +SparseCsrTensor::SparseCsrTensor(const SparseCsrTensor& other) { // NOLINT this->non_zero_crows_ = other.non_zero_crows_; this->non_zero_cols_ = other.non_zero_cols_; this->non_zero_elements_ = other.non_zero_elements_; diff --git a/paddle/phi/core/sparse_csr_tensor.h b/paddle/phi/core/sparse_csr_tensor.h index 1901b824f5686..b746694475ade 100644 --- a/paddle/phi/core/sparse_csr_tensor.h +++ b/paddle/phi/core/sparse_csr_tensor.h @@ -42,7 +42,7 @@ class SparseCsrTensor : public TensorBase, SparseCsrTensor(const SparseCsrTensor& other); /// \brief create the sparse csr tensor. - /// \param non_zero_crows The compresessed row index of non zero elements in + /// \param non_zero_crows The compressed row index of non zero elements in /// original dense tensor. /// \param non_zero_cols The column index of non zero elements in original /// dense tensor. @@ -132,7 +132,7 @@ class SparseCsrTensor : public TensorBase, /// \brief Test whether the non_zero_elements_ storage is allocated. /// In special cases, when nnz=0, non_zero_elements_ will not need to be - /// initialized, but it is neccessary to return true here, otherwise the + /// initialized, but it is necessary to return true here, otherwise the /// gradient will be None. return Whether the non_zero_elements_ storage is /// allocated. bool initialized() const override { @@ -145,7 +145,7 @@ class SparseCsrTensor : public TensorBase, void Resize(const DDim& dense_dims, const int64_t non_zero_num); /// \brief set the member of sparse csr tensor. - /// \param non_zero_crows The compresessed row index of non zero elements in + /// \param non_zero_crows The compressed row index of non zero elements in /// original dense tensor. /// \param non_zero_cols The column index of non zero elements in original /// dense tensor. @@ -157,7 +157,7 @@ class SparseCsrTensor : public TensorBase, const DDim& dims); /// \brief set the member of sparse csr tensor. - /// \param non_zero_crows The compresessed row index of non zero elements in + /// \param non_zero_crows The compressed row index of non zero elements in /// original dense tensor. /// \param non_zero_cols The column index of non zero elements in original /// dense tensor. 
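The sparse_csr_tensor.h comments corrected above describe non_zero_crows_ as the compressed row index of the non-zero elements. A small illustrative sketch of the standard CSR triple (crows/cols/values), assuming nothing about SparseCsrTensor's internals:

#include <cstdint>
#include <vector>

// CSR form of the 3x3 matrix
//   [[1, 0, 2],
//    [0, 0, 3],
//    [4, 5, 0]]
// crows has rows + 1 entries; crows[i + 1] - crows[i] is the number of
// non-zeros in row i, and cols/values list those entries row by row.
struct CsrExample {
  std::vector<int64_t> crows{0, 2, 3, 5};
  std::vector<int64_t> cols{0, 2, 2, 0, 1};
  std::vector<float> values{1.0f, 2.0f, 3.0f, 4.0f, 5.0f};
};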
diff --git a/paddle/phi/core/storage_properties.h b/paddle/phi/core/storage_properties.h index ac64875452bf8..550a9ef152db0 100644 --- a/paddle/phi/core/storage_properties.h +++ b/paddle/phi/core/storage_properties.h @@ -63,7 +63,7 @@ struct XPUStorageProperties }; #endif -// Add OneDNNStorageProperties firstly for unittest covergae +// Add OneDNNStorageProperties firstly for unittest coverage #ifdef PADDLE_WITH_DNNL struct OneDNNStorageProperties : public StorageProperties, diff --git a/paddle/phi/core/stream.h b/paddle/phi/core/stream.h index 593bee67ef876..f8f9f8f2d4b3d 100644 --- a/paddle/phi/core/stream.h +++ b/paddle/phi/core/stream.h @@ -26,7 +26,7 @@ class Stream final { StreamId id() const { return id_; } private: - StreamId id_{0}; // not onwed the stream + StreamId id_{0}; // not owned the stream }; } // namespace phi diff --git a/paddle/phi/core/string_tensor.cc b/paddle/phi/core/string_tensor.cc index d370be21f4cac..bb7d06825fdbb 100644 --- a/paddle/phi/core/string_tensor.cc +++ b/paddle/phi/core/string_tensor.cc @@ -37,7 +37,7 @@ StringTensor::StringTensor(const std::shared_ptr& holder, const StringTensorMeta& meta) : meta_(meta), holder_(holder) {} -StringTensor::StringTensor(const StringTensor& other) { +StringTensor::StringTensor(const StringTensor& other) { // NOLINT this->meta_ = other.meta(); holder_ = other.holder_; } diff --git a/paddle/phi/core/tensor_array.h b/paddle/phi/core/tensor_array.h index 69995c016ac33..3c17217bf0d6d 100644 --- a/paddle/phi/core/tensor_array.h +++ b/paddle/phi/core/tensor_array.h @@ -54,13 +54,13 @@ class TensorArray : public TensorBase, /// \return The name of the class. static const char* name() { return "TensorArray"; } - /// \brief This overrided function is not used in TensorArray. + /// \brief This overridden function is not used in TensorArray. TEST_API int64_t numel() const override; - /// \brief This overrided function is not used in TensorArray. + /// \brief This overridden function is not used in TensorArray. TEST_API const DDim& dims() const override; - /// \brief This overrided function is not used in TensorArray. + /// \brief This overridden function is not used in TensorArray. TEST_API const Place& place() const override; TEST_API DataType dtype() const override; @@ -75,7 +75,7 @@ class TensorArray : public TensorBase, void set_layout(const DataLayout layout); #endif - /// \brief This overrided function is not used in TensorArray. + /// \brief This overridden function is not used in TensorArray. TEST_API bool valid() const override; /// \brief Test whether the tensor's storage in TensorArray is allocated. 
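The stream.h hunk above adjusts the comment on Stream::id_, which records a StreamId without owning the underlying stream. As a rough sketch of that non-owning-handle pattern (class and type names here are illustrative, not Paddle's API):

#include <cstdint>

using StreamId = uint64_t;  // assumption: an opaque integer handle

// A non-owning handle: it remembers which stream to use but never creates or
// destroys it, so copying the handle is cheap and has no lifetime effect.
class StreamHandle {
 public:
  explicit StreamHandle(StreamId id) : id_(id) {}
  StreamId id() const { return id_; }

 private:
  StreamId id_{0};  // not owned; whoever created the stream releases it
};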
diff --git a/paddle/phi/core/tensor_meta.h b/paddle/phi/core/tensor_meta.h index 4c7c9ace49d32..f493e0249d7bf 100644 --- a/paddle/phi/core/tensor_meta.h +++ b/paddle/phi/core/tensor_meta.h @@ -121,7 +121,7 @@ struct SparseTensorMeta { bool valid() const noexcept; DDim dims; - DataType dtype; + DataType dtype{DataType::UNDEFINED}; DataLayout layout{DataLayout::NCHW}; }; diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h index 4d9b50d34f8f5..5d82fdfce976c 100644 --- a/paddle/phi/core/tensor_utils.h +++ b/paddle/phi/core/tensor_utils.h @@ -134,7 +134,8 @@ void TensorToVector(const phi::DenseTensor& src, const phi::DeviceContext& ctx, std::vector* dst); -phi::DenseTensor ReshapeToMatrix(const phi::DenseTensor& src, int num_col_dims); +TEST_API phi::DenseTensor ReshapeToMatrix(const phi::DenseTensor& src, + int num_col_dims); template T GetValue(const phi::DenseTensor* x); diff --git a/paddle/phi/core/threadpool.cc b/paddle/phi/core/threadpool.cc index 713ac4c0751f6..8ae9c5b4bf363 100644 --- a/paddle/phi/core/threadpool.cc +++ b/paddle/phi/core/threadpool.cc @@ -54,7 +54,7 @@ void ThreadPool::Init() { ThreadPool::ThreadPool(int num_threads) : running_(true) { threads_.resize(num_threads); for (auto& thread : threads_) { - // TODO(Yancey1989): binding the thread on the specify CPU numberw + // TODO(Yancey1989): binding the thread on the specify CPU number thread = std::make_unique([this] { ThreadPool::TaskLoop(); }); } } diff --git a/paddle/phi/core/threadpool.h b/paddle/phi/core/threadpool.h index 110a6a459186f..7dd9b79b07c06 100644 --- a/paddle/phi/core/threadpool.h +++ b/paddle/phi/core/threadpool.h @@ -56,7 +56,7 @@ class ThreadPool { std::packaged_task()>; // Returns the singleton of ThreadPool. - static ThreadPool* GetInstance(); + TEST_API static ThreadPool* GetInstance(); ~ThreadPool(); @@ -80,7 +80,7 @@ class ThreadPool { new common::enforce::EnforceNotMet(ex)); } catch (const std::exception& e) { PADDLE_THROW(phi::errors::Fatal( - "Unexpected exception is catched in thread pool. All " + "Unexpected exception is caught in thread pool. All " "throwable exception in Paddle should be an EnforceNotMet." "The exception is:\n %s.", e.what())); @@ -129,7 +129,7 @@ class ThreadPoolIO : ThreadPool { static void InitIO(); private: - // NOTE: threadpool in base will be inhereted here. + // NOTE: threadpool in base will be inherited here. static std::unique_ptr io_threadpool_; static std::once_flag io_init_flag_; }; diff --git a/paddle/phi/core/utils/intrusive_ref_counter.h b/paddle/phi/core/utils/intrusive_ref_counter.h index 1681f88af054f..6b2a3e989a840 100644 --- a/paddle/phi/core/utils/intrusive_ref_counter.h +++ b/paddle/phi/core/utils/intrusive_ref_counter.h @@ -57,7 +57,7 @@ inline void intrusive_ptr_release( const intrusive_ref_counter* p) noexcept { if (p->ref_.load(std::memory_order_acquire) == 0 || p->ref_.fetch_sub(1) == 0) { - delete static_cast(p); + delete static_cast(p); // NOLINT } } diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h index 7ee12e26d7d0e..ad30da4ddcd6f 100644 --- a/paddle/phi/core/visit_type.h +++ b/paddle/phi/core/visit_type.h @@ -471,4 +471,20 @@ namespace phi { } \ }() +#define PD_VISIT_KERNEL( \ + kernel_name, kernel_key, kernel_signature, use_strided_kernel, ...) 
\ + [&] { \ + auto kernel_result = \ + phi::KernelFactory::Instance().SelectKernelOrThrowError( \ + kernel_name, kernel_key, use_strided_kernel); \ + const auto& kernel = kernel_result.kernel; \ + auto* kernel_fn = kernel.GetVariadicKernelFn(); \ + if (kernel_result.has_fallback_cpu) { \ + PADDLE_THROW(phi::errors::NotFound( \ + "The kernel with key %s of kernel `%s` is not registered.", \ + kernel_key, \ + kernel_name)); \ + } \ + (*kernel_fn)(__VA_ARGS__); \ + }() } // namespace phi diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc index 845a8e6835729..9ba70ce824b39 100644 --- a/paddle/phi/infermeta/backward.cc +++ b/paddle/phi/infermeta/backward.cc @@ -39,6 +39,21 @@ void AngleGradInferMeta(const MetaTensor& x, UnchangedInferMeta(x, x_grad); } +void BatchFCGradInferMeta(const MetaTensor& input, + const MetaTensor& w, + const MetaTensor& bias, + const MetaTensor& out_grad, + MetaTensor* input_grad, + MetaTensor* w_grad, + MetaTensor* bias_grad) { + input_grad->set_dims(input.dims()); + input_grad->set_dtype(input.dtype()); + w_grad->set_dims(w.dims()); + w_grad->set_dtype(w.dtype()); + bias_grad->set_dims(bias.dims()); + bias_grad->set_dtype(bias.dtype()); +} + void BilinearGradInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, @@ -843,12 +858,23 @@ void NanmedianGradInferMeta(const MetaTensor& x, const MetaTensor& out_grad, const IntArray& axes, bool keep_dim, + const std::string& mode, MetaTensor* x_grad) { auto x_dims = x.dims(); x_grad->set_dims(x_dims); x_grad->set_dtype(x.dtype()); } +void PartialConcatGradInferMeta(const std::vector& xs, + std::vector x_grads) { + auto input_num = xs.size(); + for (size_t i = 0; i < input_num; i++) { + auto x_dims = xs[i]->dims(); + x_grads[i]->set_dims(x_dims); + x_grads[i]->set_dtype(xs[i]->dtype()); + } +} + void NceGradInferMeta(const MetaTensor& input, const MetaTensor& bias, const MetaTensor& weight, @@ -876,6 +902,16 @@ void NceGradInferMeta(const MetaTensor& input, } } +void PartialSumGradInferMeta(const std::vector& xs, + std::vector x_grads) { + auto input_num = xs.size(); + for (size_t i = 0; i < input_num; i++) { + auto x_dims = xs[i]->dims(); + x_grads[i]->set_dims(x_dims); + x_grads[i]->set_dtype(xs[i]->dtype()); + } +} + void NllLossGradInferMeta(const MetaTensor& x, const MetaTensor& label, const MetaTensor& weight, @@ -1008,6 +1044,19 @@ void PsroiPoolGradInferMeta(const MetaTensor& x, dx->share_meta(x); } +void RankAttentionGradInferMeta(const MetaTensor& x, + const MetaTensor& rank_offset, + const MetaTensor& rank_param, + const MetaTensor& input_help, + const MetaTensor& ins_rank, + const MetaTensor& out_grad, + int max_rank, + int max_size, + MetaTensor* rank_param_grad) { + rank_param_grad->set_dims(rank_param.dims()); + rank_param_grad->set_dtype(rank_param.dtype()); +} + void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx) { dx->set_dims(out_grad.dims()); dx->set_dtype(dtype::ToComplex(out_grad.dtype())); @@ -1180,16 +1229,16 @@ void TransposeGradInferMeta(const MetaTensor& x, const std::vector& axis, MetaTensor* out) { size_t x_rank = x.dims().size(); - std::vector formated_axis = axis; + std::vector formatted_axis = axis; for (size_t i = 0; i < axis.size(); i++) { if (axis[i] < 0) { - formated_axis[i] = static_cast(axis[i] + x_rank); + formatted_axis[i] = static_cast(axis[i] + x_rank); } } std::vector reversed_axis(axis); - for (int i = 0; i < static_cast(formated_axis.size()); i++) { - reversed_axis[formated_axis[i]] = i; + for (int i = 0; i < 
static_cast(formatted_axis.size()); i++) { + reversed_axis[formatted_axis[i]] = i; } TransposeInferMeta(x, reversed_axis, out); diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h index bde9c57ff245a..278b4ba970ff1 100644 --- a/paddle/phi/infermeta/backward.h +++ b/paddle/phi/infermeta/backward.h @@ -36,6 +36,14 @@ void AngleGradInferMeta(const MetaTensor& x, const MetaTensor& out_grad, MetaTensor* x_grad); +void BatchFCGradInferMeta(const MetaTensor& input, + const MetaTensor& w, + const MetaTensor& bias, + const MetaTensor& out_grad, + MetaTensor* intput_grad, + MetaTensor* w_grad, + MetaTensor* bias_grad); + void BilinearGradInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& weight, @@ -370,8 +378,15 @@ void NanmedianGradInferMeta(const MetaTensor& x, const MetaTensor& out_grad, const IntArray& axes, bool keep_dim, + const std::string& mode, MetaTensor* x_grad); +void PartialConcatGradInferMeta(const std::vector& xs, + std::vector x_grads); + +void PartialSumGradInferMeta(const std::vector& xs, + std::vector x_grads); + void NceGradInferMeta(const MetaTensor& input, const MetaTensor& bias, const MetaTensor& weight, @@ -415,6 +430,16 @@ void PsroiPoolGradInferMeta(const MetaTensor& x, float spatial_scale, MetaTensor* dx); +void RankAttentionGradInferMeta(const MetaTensor& x, + const MetaTensor& rank_offset, + const MetaTensor& rank_param, + const MetaTensor& input_help, + const MetaTensor& ins_rank, + const MetaTensor& out_grad, + int max_rank, + int max_size, + MetaTensor* rank_param_grad); + void RealAndImagGradInferMeta(const MetaTensor& out_grad, MetaTensor* dx); void ReshapeDoubleGradInferMeta(const MetaTensor& out_grad, diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index fdef52a5fb6e1..63d1d1c9b32d0 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -166,8 +166,8 @@ void ArrayReadInferMeta(const MetaTensor& array, out->set_dims({-1}); } else { double index = i.to(); - out->set_dims(array.dims(index)); - out->share_lod(array, index); + out->set_dims(array.dims(index)); // NOLINT + out->share_lod(array, index); // NOLINT } out->set_dtype(array.dtype()); out->set_layout(array.layout()); @@ -1201,6 +1201,60 @@ void DistributeFpnProposalsInferMeta( } } +void DistributedFusedLambInitInferMeta( + const std::vector& param, + const std::vector& grad, + float beta1, + float beta2, + const std::vector& apply_weight_decay, + int alignment, + int rank, + int nranks, + MetaTensor* fp32_fused_param, + MetaTensor* fp32_fused_grad, + MetaTensor* fp16_fused_param, + MetaTensor* fp16_fused_grad, + MetaTensor* moment1, + MetaTensor* moment2, + MetaTensor* beta1_pow, + MetaTensor* beta2_pow, + MetaTensor* fused_param_offsets, + MetaTensor* fp32_shard_fused_param_offsets, + MetaTensor* fp16_shard_fused_param_offsets, + MetaTensor* param_info, + MetaTensor* param_order, + std::vector param_out, + std::vector master_param_out, + std::vector grad_out, + MetaTensor* global_scale, + MetaTensor* step) { + fp32_fused_param->set_dtype(DataType::FLOAT32); + fp32_fused_grad->set_dtype(DataType::FLOAT32); + fp16_fused_param->set_dtype(DataType::FLOAT16); + fp16_fused_grad->set_dtype(DataType::FLOAT16); + moment1->set_dtype(DataType::FLOAT32); + moment2->set_dtype(DataType::FLOAT32); + beta1_pow->set_dtype(DataType::FLOAT32); + beta2_pow->set_dtype(DataType::FLOAT32); + fused_param_offsets->set_dtype(DataType::INT32); + fp32_shard_fused_param_offsets->set_dtype(DataType::INT32); + 
fp16_shard_fused_param_offsets->set_dtype(DataType::INT32); + param_info->set_dtype(DataType::INT32); + param_order->set_dtype(DataType::INT32); + + for (size_t i = 0; i < param.size(); ++i) { + param_out[i]->set_dtype(param[i]->dtype()); + master_param_out[i]->set_dtype(DataType::FLOAT32); + } + + for (size_t i = 0; i < grad.size(); ++i) { + grad_out[i]->set_dtype(grad[i]->dtype()); + } + + global_scale->set_dtype(DataType::FLOAT32); + step->set_dtype(DataType::INT64); +} + void DropoutInferMeta(const MetaTensor& x, const MetaTensor& seed_tensor, const Scalar& p, @@ -1478,7 +1532,7 @@ void ExpandAsInferMeta(const MetaTensor& x, const MetaTensor& y, const std::vector& target_shape, MetaTensor* out) { -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 auto x_dims = x.dims(); PADDLE_ENFORCE_GE( target_shape.size(), @@ -2113,6 +2167,15 @@ void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) { out->set_dtype(x.dtype()); } +void LimitByCapacityInferMeta(const MetaTensor& expert_count, + const MetaTensor& capacity, + int n_worker, + MetaTensor* out) { + out->share_dims(expert_count); + out->share_lod(expert_count); + out->set_dtype(expert_count.dtype()); +} + void LogLossInferMeta(const MetaTensor& input, const MetaTensor& label, float epsilon, @@ -2801,6 +2864,35 @@ void PriorBoxInferMeta(const MetaTensor& input, var->set_dims(common::make_ddim(dim_vec)); } +void PruneGateByCapacityInferMeta(const MetaTensor& gate_idx, + const MetaTensor& expert_count, + int64_t n_expert, + int64_t n_worker, + MetaTensor* new_gate_idx) { + auto expert_count_dims = expert_count.dims(); + + int64_t expert_count_num_ele = 1; + for (int i = 0; i < static_cast(expert_count_dims.size()); i++) { + expert_count_num_ele *= expert_count_dims[i]; + } + + PADDLE_ENFORCE_EQ( + expert_count_num_ele, + n_expert * n_worker, + phi::errors::Unavailable( + "The number of elements for expert_count is ( %ld ) incorrect. " + "Because the number of expert_count must equal the " + "product of n_worker ( %ld ) and n_expert ( %ld ). 
" + "Please input appropriate expert_count again!", + expert_count_num_ele, + n_worker, + n_expert)); + + auto gate_idx_in_dims = gate_idx.dims(); + new_gate_idx->set_dims(gate_idx_in_dims); + new_gate_idx->set_dtype(gate_idx.dtype()); +} + void RepeatInterleaveWithTensorIndexInferMeta(const MetaTensor& x, const MetaTensor& repeats, int dim, @@ -3557,8 +3649,8 @@ void WeightDequantizeInferMeta(const MetaTensor& x, dim_scale[0], (x.dims()[1] + (group_size - 1)) / group_size)); } - int n = x.dims()[1]; - int k = x.dims()[0]; + int n = static_cast(x.dims()[1]); + int k = static_cast(x.dims()[0]); out->set_dims(common::make_ddim({n, k})); out->set_dtype(out_dtype); } diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 79b46c1d5ba80..77bc925197013 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -210,6 +210,34 @@ void DistributeFpnProposalsInferMeta( MetaTensor* restore_index, MetaConfig config = MetaConfig()); +void DistributedFusedLambInitInferMeta( + const std::vector& param, + const std::vector& grad, + float beta1, + float beta2, + const std::vector& apply_weight_decay, + int alignment, + int rank, + int nranks, + MetaTensor* fp32_fused_param, + MetaTensor* fp32_fused_grad, + MetaTensor* fp16_fused_param, + MetaTensor* fp16_fused_grad, + MetaTensor* moment1, + MetaTensor* moment2, + MetaTensor* beta1_pow, + MetaTensor* beta2_pow, + MetaTensor* fused_param_offsets, + MetaTensor* fp32_shard_fused_param_offsets, + MetaTensor* fp16_shard_fused_param_offsets, + MetaTensor* param_info, + MetaTensor* param_order, + std::vector param_out, + std::vector master_param_out, + std::vector grad_out, + MetaTensor* global_scale, + MetaTensor* step); + void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); void DropoutInferMeta(const MetaTensor& x, @@ -352,6 +380,11 @@ void IndexAddInferMeta(const MetaTensor& x, void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); +void LimitByCapacityInferMeta(const MetaTensor& expert_count, + const MetaTensor& capacity, + int n_worker, + MetaTensor* out); + void LogicalBinaryInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); @@ -464,6 +497,12 @@ void PriorBoxInferMeta(const MetaTensor& input, MetaTensor* out, MetaTensor* var); +void PruneGateByCapacityInferMeta(const MetaTensor& gate_idx, + const MetaTensor& expert_count, + int64_t n_expert, + int64_t n_worker, + MetaTensor* new_gate_idx); + void SearchsortedInferMeta(const MetaTensor& sorted_sequence, const MetaTensor& value, bool out_int32, diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 6e85754335ce9..b56e7fab0bfe6 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -116,6 +116,108 @@ void AddLayernormXPUInferMeta(const MetaTensor& x, out->share_lod(x); } +void FusedMultiTransformerInferMeta( + const MetaTensor& x, + const std::vector& ln_scales, + const std::vector& ln_biases, + const std::vector& qkv_weights, + const paddle::optional>& qkv_biases, + const paddle::optional>& cache_kvs, + const paddle::optional>& pre_caches, + const MetaTensor& rotary_tensor, + const MetaTensor& time_step, + const MetaTensor& seq_lengths, + const MetaTensor& src_mask, + const std::vector& out_linear_weights, + const paddle::optional>& out_linear_biases, + const std::vector& ffn_ln_scales, + const std::vector& ffn_ln_biases, + const std::vector& ffn1_weights, + const paddle::optional>& ffn1_biases, + const std::vector& ffn2_weights, + const 
paddle::optional>& ffn2_biases, + bool pre_layer_norm, + float epsilon, + float dropout_rate, + int rotary_emb_dims, + bool is_test, + const std::string& dropout_implementation, + const std::string& act_method, + bool trans_qkvw, + int ring_id, + std::vector cache_kv_outs, + MetaTensor* out) { + // x: qkv's input [batch_size, seq_len, dim_embed] + // y: qkv's weight: [3, num_head, dim_head, dim_embed] + auto x_dim = x.dims(); + auto y_dim = qkv_weights[0]->dims(); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 3, + phi::errors::InvalidArgument("The dimensions of x must be 3" + "(batch_size, seq_len, dim_embed)," + "but received dimensions of" + "Input is [%d]", + x_dim.size())); + PADDLE_ENFORCE_EQ( + y_dim.size(), + 4, + phi::errors::InvalidArgument("The dimensions of qkv_weight must be 4" + "(3, num_head, dim_head, dim_embed)," + "but received dimensions of" + "Input is [%d]", + y_dim.size())); + PADDLE_ENFORCE_EQ( + x_dim[2], + trans_qkvw ? y_dim[3] : y_dim[0], + phi::errors::InvalidArgument( + "ShapeError: the dimension of x_dim[2] and y_dim[3](trans_qkvw is " + "true) or y_dim[0](trans_qkvw is false)" + "must be equal. But received: the shape " + "of input x = [%s], and the shape of " + "input qkv_weight = [%s]", + x_dim, + y_dim)); + + if (cache_kvs && cache_kvs->size() > 0) { + // [2, batch_size, num_head, max_seq_len, head_size] + const auto& c_dim = cache_kvs.get()[0]->dims(); + + PADDLE_ENFORCE_EQ( + c_dim.size(), + 5, + phi::errors::InvalidArgument("The CacheKV must be 5 dims, but got %d", + c_dim.size())); + PADDLE_ENFORCE_EQ(c_dim[0], + 2, + phi::errors::InvalidArgument( + "The first dim of CacheKV must be 2, but got %d", + c_dim[0])); // 2 + PADDLE_ENFORCE_EQ(c_dim[1], + x_dim[0], + phi::errors::InvalidArgument( + "The second dim of CacheKV must be equal with " + "batch size %d, but got %d", + x_dim[0], + c_dim[1])); // batch_size + PADDLE_ENFORCE_EQ(c_dim[2], + trans_qkvw ? y_dim[1] : y_dim[2], + phi::errors::InvalidArgument( + "The third dim of CacheKV must be equal with num " + "head %d, but got %d", + trans_qkvw ? y_dim[1] : y_dim[2], + c_dim[2])); // num_head + PADDLE_ENFORCE_EQ(c_dim[4], + trans_qkvw ? y_dim[2] : y_dim[3], + phi::errors::InvalidArgument( + "The fifth dim of CacheKV must be equal with head " + "size %d, but got %d", + trans_qkvw ? 
y_dim[2] : y_dim[3], + c_dim[4])); // head_size + } + out->set_dims(x.dims()); +} + void BlockMultiheadAttentionInferMeta(const MetaTensor& qkv, const MetaTensor& key_cache, const MetaTensor& value_cache, @@ -975,7 +1077,6 @@ void FusedBiasDropoutResidualLnInferMeta( } void FusedBiasDropoutResidualLnGradInferMeta( - const MetaTensor& y_grad, const MetaTensor& x, const MetaTensor& residual, const MetaTensor& bias, @@ -985,6 +1086,7 @@ void FusedBiasDropoutResidualLnGradInferMeta( const MetaTensor& ln_variance, const MetaTensor& bias_dropout_residual_out, const MetaTensor& dropout_mask_out, + const MetaTensor& y_grad, const float dropout_rate, const bool is_test, const bool dropout_fix_seed, @@ -1447,6 +1549,7 @@ void MultiEncoderXPUInferMeta( const std::vector& ln_scale, const std::vector& ln_bias, const std::vector& smooth_scale_weight, + const std::vector& roformer_embedding, const MetaTensor& mask, const MetaTensor& seq_lod, const MetaTensor& max_seq_len, @@ -1460,6 +1563,7 @@ void MultiEncoderXPUInferMeta( int relative_type, int slice_idx, bool is_per_channel, + int max_pos_len, const std::vector& softmax_max_value, const std::vector& quant_types, MetaTensor* out, @@ -3829,4 +3933,216 @@ void MultiGruInferMeta( hidden->set_dims(out_dims); hidden->share_lod(x); } + +void FusionLstmInferMeta(const MetaTensor& x, + const MetaTensor& weight_x, + const MetaTensor& weight_h, + const MetaTensor& bias, + const MetaTensor& h0, + const MetaTensor& c0, + const bool use_peepholes, + const bool is_reverse, + const bool use_seq, + const std::string& gate_activation, + const std::string& cell_activation, + const std::string& candidate_activation, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + MetaTensor* hidden, + MetaTensor* cell, + MetaTensor* xx, + MetaTensor* batched_input, + MetaTensor* batched_hidden, + MetaTensor* batched_cell, + MetaTensor* reordered_h0, + MetaTensor* reordered_c0, + MetaTensor* checked_cell) { + auto x_dims = x.dims(); + PADDLE_ENFORCE_EQ(x_dims.size(), + 2, + phi::errors::InvalidArgument( + "Input(X)'s rank must be 2, but received x's rank " + "is:%d, x dim is:[%s]", + x_dims.size(), + x_dims)); + + if (h0.initialized()) { + PADDLE_ENFORCE_EQ( + c0.initialized(), + true, + phi::errors::InvalidArgument( + "fusion_lstm must has h0 and c0 input at the same time.")); + auto h_dims = h0.dims(); + auto c_dims = c0.dims(); + PADDLE_ENFORCE_EQ(h_dims, + c_dims, + phi::errors::InvalidArgument( + "The dimension of Input(H0) and Input(C0) should be " + "same, but received h0 dims is:[%s], c0 dims is:[%s]", + h_dims, + c_dims)); + } + + auto wx_dims = weight_x.dims(); + PADDLE_ENFORCE_EQ(wx_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(WeightX) should be 2, but received " + "WeightX's rank is:%d, WeightX dim is:[%s]", + wx_dims.size(), + wx_dims)); + PADDLE_ENFORCE_EQ(wx_dims[0], + x_dims[1], + phi::errors::InvalidArgument( + "The first dimension of Input(WeightX) " + "should equal to second dimension of Input(X), but " + "received WeightX first dim is:%d, X second dim is:%d", + wx_dims[0], + x_dims[1])); + + int frame_size = static_cast(wx_dims[1] / 4); + auto wh_dims = weight_h.dims(); + + PADDLE_ENFORCE_EQ(wh_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(WeightH) should be 2, but received " + "WeightH rank is:%d, WeightH dim is:[%s]", + wh_dims.size(), + wh_dims)); + PADDLE_ENFORCE_EQ(wh_dims[0], + frame_size, + phi::errors::InvalidArgument( + "The first 
dimension of Input(WeightH) " + "should equal to frame size, but received WeightH " + "first dim is:%d, frame size is:%d.", + wh_dims[0], + frame_size)); + + PADDLE_ENFORCE_EQ(wh_dims[1], + 4 * frame_size, + phi::errors::InvalidArgument( + "The second dimension of Input(WeightH) " + "should equal to 4 * frame_size, but received WeightH " + "second dimension is:%d, frame size is:%d.", + wh_dims[1], + frame_size)); + + auto b_dims = bias.dims(); + PADDLE_ENFORCE_EQ(b_dims.size(), + 2, + phi::errors::InvalidArgument( + "The rank of Input(Bias) should be 2, but received " + "Bias rank is:%d, Bias dim is:[%s]", + b_dims.size(), + b_dims)); + PADDLE_ENFORCE_EQ(b_dims[0], + 1, + phi::errors::InvalidArgument( + "The first dimension of Input(Bias) should be 1, but " + "received Bias's dimension is:[%s]", + b_dims)); + + if (use_peepholes) { + PADDLE_ENFORCE_EQ(b_dims[1], + 7 * frame_size, + phi::errors::InvalidArgument( + "The second dimension of Input(Bias) should be " + "7 * %d if enable peepholes connection, but received " + "Bias dim is:[%s]", + frame_size, + b_dims)); + checked_cell->set_dims(phi::make_ddim({2, frame_size})); + checked_cell->set_dtype(x.dtype()); + } else { + PADDLE_ENFORCE_EQ( + b_dims[1], + 4 * frame_size, + phi::errors::InvalidArgument( + "The second dimension of Input(Bias) should be " + "4 * %d if disable peepholes, but received Bias dim is:[%s]", + frame_size, + b_dims)); + } + + auto out_dims = phi::make_ddim({x_dims[0], frame_size}); + hidden->set_dims(out_dims); + cell->set_dims(out_dims); + hidden->share_lod(x); + cell->share_lod(x); + hidden->set_dtype(x.dtype()); + cell->set_dtype(x.dtype()); + + int xx_width = 0; + if (use_seq) { + xx_width = static_cast(wx_dims[1]); + } else { + xx_width = + static_cast(x_dims[1] > wx_dims[1] ? 
wx_dims[1] : x_dims[1]); + + batched_input->set_dims(phi::make_ddim({x_dims[0], wx_dims[1]})); + batched_hidden->set_dims(out_dims); + batched_cell->set_dims(out_dims); + batched_input->set_dtype(x.dtype()); + batched_hidden->set_dtype(x.dtype()); + batched_cell->set_dtype(x.dtype()); + } + xx->set_dims(phi::make_ddim({x_dims[0], xx_width})); + xx->set_dtype(x.dtype()); + xx->share_lod(x); +} + +void RoformerRelativePosXPUInferMeta(const MetaTensor& x, + const MetaTensor& sin_emb, + const MetaTensor& cos_emb, + int max_pos_len, + MetaTensor* out) { + auto x_dims = x.dims(); + auto x_dims_size = x_dims.size(); + auto sin_emb_dims = sin_emb.dims(); + auto sin_emb_dims_size = sin_emb_dims.size(); + auto cos_emb_dims = cos_emb.dims(); + auto cos_emb_dims_size = cos_emb_dims.size(); + PADDLE_ENFORCE_EQ( + x_dims_size, + 4, + phi::errors::InvalidArgument( + "x_dims_size should be 4, but received x_dims_size is %d", + x_dims_size)); + PADDLE_ENFORCE_EQ( + sin_emb_dims_size, + 4, + phi::errors::InvalidArgument( + "sin_emb_dims_size should be 4, but received sin_emb_dims_size is %d", + sin_emb_dims_size)); + PADDLE_ENFORCE_EQ( + cos_emb_dims_size, + 4, + phi::errors::InvalidArgument( + "cos_emb_dims_size should be 4, but received cos_emb_dims_size is %d", + cos_emb_dims_size)); + for (int i = 0; i < sin_emb_dims_size; i++) { + PADDLE_ENFORCE_EQ( + sin_emb_dims[i], + cos_emb_dims[i], + phi::errors::InvalidArgument( + "sin_emb_dims[i] should be equal to cos_emb_dims[i], index i is " + "%d, sin_emb_dims[i] is %d, cos_emb_dims[i] is %d", + i, + sin_emb_dims[i], + cos_emb_dims[i])); + } + PADDLE_ENFORCE_EQ( + x_dims[3], + cos_emb_dims[3], + phi::errors::InvalidArgument("x_dims[3] should be equal to cos_dims[3], " + "but sin_dims[3] is %d, cos_dims[3] is %d", + x_dims[3], + cos_emb_dims[3])); + out->set_dims(x_dims); + out->set_dtype(x.dtype()); +} + } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 767f22fd245f4..0a7224e39f73b 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -22,6 +22,38 @@ namespace phi { // Common InferMeta Functions for fusion operators. // NOTE: The InferMeta Functions in this file are arranged in alphabetic order. 
+void FusedMultiTransformerInferMeta( + const MetaTensor& x, + const std::vector& ln_scales, + const std::vector& ln_biases, + const std::vector& qkv_weights, + const paddle::optional>& qkv_biases, + const paddle::optional>& cache_kvs, + const paddle::optional>& pre_caches, + const MetaTensor& rotary_tensor, + const MetaTensor& time_step, + const MetaTensor& seq_lengths, + const MetaTensor& src_mask, + const std::vector& out_linear_weights, + const paddle::optional>& out_linear_biases, + const std::vector& ffn_ln_scales, + const std::vector& ffn_ln_biases, + const std::vector& ffn1_weights, + const paddle::optional>& ffn1_biases, + const std::vector& ffn2_weights, + const paddle::optional>& ffn2_biases, + bool pre_layer_norm, + float epsilon, + float dropout_rate, + int rotary_emb_dims, + bool is_test, + const std::string& dropout_implementation, + const std::string& act_method, + bool trans_qkvw, + int ring_id, + std::vector cache_kv_outs, + MetaTensor* out); + void AddActXPUInferMeta(const MetaTensor& x, const MetaTensor& x_max, const MetaTensor& y, @@ -151,6 +183,7 @@ void MultiEncoderXPUInferMeta( const std::vector& ln_scale, const std::vector& ln_bias, const std::vector& smooth_scale_weight, + const std::vector& roformer_embedding, const MetaTensor& mask, const MetaTensor& seq_lod, const MetaTensor& max_seq_len, @@ -164,6 +197,7 @@ void MultiEncoderXPUInferMeta( int relative_type, int slice_idx, bool is_per_channel, + int max_pos_len, const std::vector& softmax_max_value, const std::vector& quant_types, MetaTensor* out, @@ -753,7 +787,6 @@ void FusedBiasDropoutResidualLnInferMeta( MetaTensor* ln_variance); void FusedBiasDropoutResidualLnGradInferMeta( - const MetaTensor& y_grad, const MetaTensor& x, const MetaTensor& residual, const MetaTensor& bias, @@ -763,6 +796,7 @@ void FusedBiasDropoutResidualLnGradInferMeta( const MetaTensor& ln_variance, const MetaTensor& bias_dropout_residual_out, const MetaTensor& dropout_mask_out, + const MetaTensor& y_grad, const float dropout_rate, const bool is_test, const bool dropout_fix_seed, @@ -838,6 +872,11 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q, void SinePosXPUInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); +void RoformerRelativePosXPUInferMeta(const MetaTensor& x, + const MetaTensor& sin_emb, + const MetaTensor& cos_emb, + int max_pos_len, + MetaTensor* out); void MultiGruInferMeta( const MetaTensor& x, @@ -854,4 +893,31 @@ void MultiGruInferMeta( float shift_data, bool force_fp32_output, MetaTensor* hidden); + +void FusionLstmInferMeta(const MetaTensor& x, + const MetaTensor& weight_x, + const MetaTensor& weight_h, + const MetaTensor& bias, + const MetaTensor& h0, + const MetaTensor& c0, + const bool use_peepholes, + const bool is_reverse, + const bool use_seq, + const std::string& gate_activation, + const std::string& cell_activation, + const std::string& candidate_activation, + const float scale_data, + const float shift_data, + const std::vector& scale_weights, + const bool force_fp32_output, + MetaTensor* hidden, + MetaTensor* cell, + MetaTensor* xx, + MetaTensor* batched_input, + MetaTensor* batched_hidden, + MetaTensor* batched_cell, + MetaTensor* reordered_h0, + MetaTensor* reordered_c0, + MetaTensor* checked_cell); + } // namespace phi diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index bb57e5a813aa7..01b4f96580b4a 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -4273,6 +4273,15 @@ void WeightOnlyLinearInferMeta(const 
MetaTensor& x, "But received Input(X) dim[-1](%s) != Input(Weight) dim[1](%s)", x_dims[x_dims.size() - 1], w_dims[1])); + if (bias.initialized()) { + auto bias_dims = bias.dims(); + PADDLE_ENFORCE_EQ( + bias_dims.size(), + 1UL, + errors::InvalidArgument( + "The size of Input(Bias)'s dimension should equal to 1UL.", + bias_dims.size())); + } // per-channel dequantization if (group_size == -1) { @@ -4584,6 +4593,86 @@ void FusedRopeInferMeta(const MetaTensor& q, } } +void FusedTokenPruneInferMeta(const MetaTensor& attn, + const MetaTensor& x, + const MetaTensor& mask, + const MetaTensor& new_mask, + bool keep_first_token, + bool keep_order, + MetaTensor* slimmed_x, + MetaTensor* cls_inds) { + auto mask_dim = mask.dims(); + auto attn_dim = attn.dims(); + auto x_dim = x.dims(); + auto new_mask_dim = new_mask.dims(); + + PADDLE_ENFORCE_EQ( + mask_dim.size(), + 4, + phi::errors::InvalidArgument("The input mask must be 4-dimension")); + PADDLE_ENFORCE_EQ( + attn_dim.size(), + 4, + phi::errors::InvalidArgument("The input attn must be 4-dimension")); + PADDLE_ENFORCE_EQ( + x_dim.size(), + 3, + phi::errors::InvalidArgument("The input x must be 4-dimension")); + PADDLE_ENFORCE_EQ( + new_mask_dim.size(), + 4, + phi::errors::InvalidArgument("The input attn must be 4-dimension")); + PADDLE_ENFORCE_EQ(mask_dim[0], + attn_dim[0], + phi::errors::InvalidArgument( + "The first dim of mask and attn should be the same" + "which is batch size")); + PADDLE_ENFORCE_EQ(mask_dim[1], + attn_dim[1], + phi::errors::InvalidArgument( + "The second dim of mask and attn should be the same" + "which is nb_head")); + PADDLE_ENFORCE_EQ(mask_dim[0], + x_dim[0], + phi::errors::InvalidArgument( + "The first dim of mask and x should be the same" + "which is batch size")); + PADDLE_ENFORCE_EQ( + mask_dim[2], + mask_dim[3], + phi::errors::InvalidArgument( + "The third dim and the fourth dim of mask should be the same" + "which is max seq len")); + PADDLE_ENFORCE_EQ( + attn_dim[2], + attn_dim[3], + phi::errors::InvalidArgument( + "The third dim and the fourth dim of mask should be the same" + "which is max seq len")); + PADDLE_ENFORCE_EQ(attn_dim[2], + mask_dim[2], + phi::errors::InvalidArgument( + "The third dim of mask and attn should be the same" + "which is max seq len")); + PADDLE_ENFORCE_EQ(attn_dim[2], + x_dim[1], + phi::errors::InvalidArgument( + "The third dim of mask and the second dim of attn" + "should be the same which is max seq len")); + + auto bsz = mask_dim[0]; + auto c = x_dim[2]; + auto slim_seq_len = new_mask_dim[2]; + + std::vector slimmed_x_dims({bsz, slim_seq_len, c}); + slimmed_x->set_dims(common::make_ddim(slimmed_x_dims)); + slimmed_x->set_dtype(x.dtype()); + + std::vector cls_inds_dims({bsz, slim_seq_len}); + cls_inds->set_dims(common::make_ddim(cls_inds_dims)); + cls_inds->set_dtype(phi::DataType::INT64); +} + void MoeInferMeta(const MetaTensor& x, const MetaTensor& gate, const MetaTensor& bmm0, @@ -4706,8 +4795,8 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x, int v_num_head = k_num_head; int dim_head = static_cast(cache_kv.dims()[4]); // below's num_head is q's head actually. 
- int num_head = - x.dims()[x.dims().size() - 1] / dim_head - k_num_head - v_num_head; + int num_head = x.dims()[x.dims().size() - 1] / dim_head - k_num_head - + v_num_head; // NOLINT PADDLE_ENFORCE_EQ( num_head % k_num_head, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index e83ef2ed1825d..3722a0d5844ba 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -908,6 +908,15 @@ void FusedRopeInferMeta(const MetaTensor& q, MetaTensor* out_k, MetaTensor* out_v); +void FusedTokenPruneInferMeta(const MetaTensor& attn, + const MetaTensor& x, + const MetaTensor& mask, + const MetaTensor& new_mask, + bool keep_first_token, + bool keep_order, + MetaTensor* slimmed_x, + MetaTensor* cls_inds); + void MultiheadMatmulInferMeta(const MetaTensor& input, const MetaTensor& w, const MetaTensor& bias, diff --git a/paddle/phi/infermeta/nullary.cc b/paddle/phi/infermeta/nullary.cc index d1bd204a682d9..5917a7a46b5ca 100644 --- a/paddle/phi/infermeta/nullary.cc +++ b/paddle/phi/infermeta/nullary.cc @@ -123,6 +123,18 @@ void GaussianInferMeta(const IntArray& shape, out->set_layout(DataLayout::NCHW); } +void PartialRecvInferMeta(int ring_id, + int peer, + DataType dtype, + const std::vector& out_shape, + bool use_calc_stream, + int num, + int id, + MetaTensor* out) { + out->set_dims(common::make_ddim(out_shape)); + out->set_dtype(dtype); +} + void RandpermInferMeta(int n, DataType dtype, MetaTensor* out) { out->set_dims(common::make_ddim({n})); out->set_dtype(dtype); diff --git a/paddle/phi/infermeta/nullary.h b/paddle/phi/infermeta/nullary.h index 5eda8fc1a8461..b35b37acc7244 100644 --- a/paddle/phi/infermeta/nullary.h +++ b/paddle/phi/infermeta/nullary.h @@ -80,6 +80,15 @@ void RandpermInferMeta(int n, DataType dtype, MetaTensor* out); void RandintInferMeta( int low, int high, const IntArray& shape, DataType dtype, MetaTensor* out); +void PartialRecvInferMeta(int ring_id, + int peer, + DataType dtype, + const std::vector& out_shape, + bool use_calc_stream, + int num, + int id, + MetaTensor* out); + void PRecvInferMeta(int peer, DataType dtype, MetaTensor* out); void PRecvArrayInferMeta(int peer, diff --git a/paddle/phi/infermeta/sparse/unary.cc b/paddle/phi/infermeta/sparse/unary.cc index f80f18bbba857..01da3ae08eb74 100644 --- a/paddle/phi/infermeta/sparse/unary.cc +++ b/paddle/phi/infermeta/sparse/unary.cc @@ -36,5 +36,21 @@ void ValuesInferMeta(const MetaTensor& x, MetaTensor* out) { out->set_layout(x.layout()); } +void CastInferMeta(const MetaTensor& x, + DataType index_dtype, + DataType out_dtype, + MetaTensor* out) { + out->set_dims(x.dims()); + out->set_layout(x.layout()); + out->share_lod(x); + // In inplace case, setting the dtype of out will reset the dtype of x at the + // same time, which will cause bugs, so move the dtype setting of out to the + // kernel + + if (!(out->is_same_tensor(x))) { + out->set_dtype(out_dtype); + } +} + } // namespace sparse } // namespace phi diff --git a/paddle/phi/infermeta/sparse/unary.h b/paddle/phi/infermeta/sparse/unary.h index 880e90b7ae697..5ee7f054143c0 100644 --- a/paddle/phi/infermeta/sparse/unary.h +++ b/paddle/phi/infermeta/sparse/unary.h @@ -24,5 +24,10 @@ void IndicesInferMeta(const MetaTensor& x, MetaTensor* out); void ValuesInferMeta(const MetaTensor& x, MetaTensor* out); +void CastInferMeta(const MetaTensor& x, + DataType index_dtype, + DataType out_dtype, + MetaTensor* out); + } // namespace sparse } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/argmax.cc 
b/paddle/phi/infermeta/spmd_rules/argmax.cc new file mode 100644 index 0000000000000..baf8ec2276268 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/argmax.cc @@ -0,0 +1,119 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/argmax.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +SpmdInfo ArgMaxInferSpmdBase(const DistMetaTensor& x, + int axis, + bool keepdims, + bool flatten) { + EXTRACT_SHAPE_AND_DIST_ATTR(x); + axis = axis < 0 ? x_ndim + axis : axis; + + std::vector x_dims_mapping_dst(x_dims_mapping_src); + std::vector out_dims_mapping; + if (flatten) { + x_dims_mapping_dst.assign(x_ndim, -1); + if (keepdims) { + out_dims_mapping.assign(x_ndim, -1); + } else { + out_dims_mapping.push_back(-1); + } + } else { + x_dims_mapping_dst[axis] = -1; + out_dims_mapping.assign(x_dims_mapping_dst.begin(), + x_dims_mapping_dst.end()); + if (!keepdims) { + out_dims_mapping.erase(out_dims_mapping.begin() + axis); + } + } + + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst); + TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src); + out_dist_attr.set_dims_mapping(out_dims_mapping); + + VLOG(4) << "ArgMaxInferSpmd:"; + VLOG(4) << "x:"; + VLOG(4) << "src_dist_attr: [" << x_dist_attr_src.to_string() << "] " + << "dst_dist_attr: [" << x_dist_attr_dst.to_string() << "]"; + VLOG(4) << "out:"; + VLOG(4) << "dist_attr: [" << out_dist_attr.to_string() << "]" << std::endl; + return {{x_dist_attr_dst}, {out_dist_attr}}; +} + +SpmdInfo ArgMaxInferSpmdReverseBase(const DistMetaTensor& x, + const DistMetaTensor& out, + int axis, + bool keepdims, + bool flatten) { + EXTRACT_SHAPE_AND_DIST_ATTR(x); + EXTRACT_SHAPE_AND_DIST_ATTR(out); + axis = axis < 0 ? 
x_ndim + axis : axis; + std::vector x_dims_mapping_dst; + std::vector out_dims_mapping_dst(out_dims_mapping_src); + + if (flatten) { + if (keepdims) { + out_dims_mapping_dst.assign(x_ndim, -1); + } else { + out_dims_mapping_dst.push_back(-1); + } + x_dims_mapping_dst.assign(x_ndim, -1); + } else { + x_dims_mapping_dst.assign(out_dims_mapping_dst.begin(), + out_dims_mapping_dst.end()); + if (!keepdims) { + x_dims_mapping_dst.insert(x_dims_mapping_dst.begin() + axis, -1); + } + } + + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst); + TensorDistAttr out_dist_attr_dst = + CopyTensorDistAttrForOutput(out_dist_attr_src); + out_dist_attr_dst.set_dims_mapping(out_dims_mapping_dst); + + VLOG(4) << "ArgMaxInferSpmdReverse:"; + VLOG(4) << "out:"; + VLOG(4) << "src_dist_attr: [" << out_dist_attr_src.to_string() << "] " + << "dst_dist_attr: [" << out_dist_attr_dst.to_string() << "]"; + VLOG(4) << "x:"; + VLOG(4) << "src_dist_attr: [" << x_dist_attr_src.to_string() << "] " + << "dst_dist_attr: [" << x_dist_attr_dst.to_string() << "]" + << std::endl; + return {{x_dist_attr_dst}, {out_dist_attr_dst}}; +} + +SpmdInfo ArgMaxInferSpmdDynamic(const DistMetaTensor& x, + const Scalar& axis, + bool keepdims, + bool flatten, + DataType dtype) { + return ArgMaxInferSpmdBase(x, axis.to(), keepdims, flatten); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/argmax.h b/paddle/phi/infermeta/spmd_rules/argmax.h new file mode 100644 index 0000000000000..186e16c9f9998 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/argmax.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo ArgMaxInferSpmdBase(const DistMetaTensor& x, + int axis, + bool keepdims, + bool flatten); + +SpmdInfo ArgMaxInferSpmdReverseBase(const DistMetaTensor& x, + const DistMetaTensor& out, + int axis, + bool keepdims, + bool flatten); + +SpmdInfo ArgMaxInferSpmdDynamic(const DistMetaTensor& x, + const Scalar& axis, + bool keepdims, + bool flatten, + DataType dtype); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/cumsum.cc b/paddle/phi/infermeta/spmd_rules/cumsum.cc new file mode 100644 index 0000000000000..a93a617bb7780 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/cumsum.cc @@ -0,0 +1,124 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/cumsum.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +SpmdInfo CumSumInferSpmd(const DistMetaTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse) { + EXTRACT_SHAPE_AND_DIST_ATTR(x); + + std::vector x_dims_mapping_dst(x_dims_mapping_src); + std::vector out_dims_mapping; + if (flatten) { + x_dims_mapping_dst.assign(x_ndim, -1); + out_dims_mapping.assign(1, -1); + } else { + x_dims_mapping_dst[axis] = -1; + out_dims_mapping.assign(x_dims_mapping_dst.begin(), + x_dims_mapping_dst.end()); + } + + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst); + TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src); + out_dist_attr.set_dims_mapping(out_dims_mapping); + + VLOG(4) << "CumSumInferSpmd:"; + VLOG(4) << "axis: " << axis << "flatten: " << flatten; + VLOG(4) << "x shape: [" << str_join(x_shape) << "], " + << "src_dist_attr: [" << x_dist_attr_src.to_string() << "], " + << "dst_dist_attr: [" << x_dist_attr_dst.to_string() << "]"; + VLOG(4) << "out dist_attr: [" << out_dist_attr.to_string() << "]"; + VLOG(4) << std::endl; + + return {{x_dist_attr_dst}, {out_dist_attr}}; +} + +SpmdInfo CumSumInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + int axis, + bool flatten, + bool exclusive, + bool reverse) { + EXTRACT_SHAPE_AND_DIST_ATTR(x); + EXTRACT_SHAPE_AND_DIST_ATTR(out); + + std::vector out_dims_mapping_dst(out_dims_mapping_src); + std::vector x_dims_mapping_dst; + + if (flatten) { + out_dims_mapping_dst.assign(1, -1); + x_dims_mapping_dst.assign(x_ndim, -1); + } else { + out_dims_mapping_dst[axis] = -1; + x_dims_mapping_dst.assign(out_dims_mapping_dst.begin(), + out_dims_mapping_dst.end()); + } + + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst); + TensorDistAttr out_dist_attr_dst = + CopyTensorDistAttrForOutput(out_dist_attr_src); + out_dist_attr_dst.set_dims_mapping(out_dims_mapping_dst); + + VLOG(4) << "CumSumInferSpmdReverse:"; + VLOG(4) << "axis: " << axis << "flatten: " << flatten; + VLOG(4) << "out shape: [" << str_join(out_shape) << "], " + << "src_dist_attr: [" << out_dist_attr_src.to_string() << "], " + << "dst_dist_attr: [" << out_dist_attr_dst.to_string() << "]"; + VLOG(4) << "x shape: [" << str_join(x_shape) << "], " + << "src_dist_attr: [" << x_dist_attr_src.to_string() << "], " + << "dst_dist_attr: [" << x_dist_attr_dst.to_string() << "]"; + VLOG(4) << std::endl; + + return {{x_dist_attr_dst}, {out_dist_attr_dst}}; +} + +SpmdInfo CumSumInferSpmdDynamic(const DistMetaTensor& x, + const Scalar& axis, + bool flatten, + 
bool exclusive, + bool reverse) { + return CumSumInferSpmd(x, axis.to(), flatten, exclusive, reverse); +} + +SpmdInfo CumSumGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad, + const Scalar& axis, + bool flatten, + bool exclusive, + bool reverse) { + SpmdInfo info = CumSumInferSpmdReverse( + x, out_grad, axis.to(), flatten, exclusive, reverse); + return {{x.dist_attr(), info.second[0]}, {info.first[0]}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/cumsum.h b/paddle/phi/infermeta/spmd_rules/cumsum.h new file mode 100644 index 0000000000000..4de46bdf16c52 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/cumsum.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo CumSumInferSpmd(const DistMetaTensor& x, + int axis, + bool flatten, + bool exclusive, + bool reverse); + +SpmdInfo CumSumInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + int axis, + bool flatten, + bool exclusive, + bool reverse); + +SpmdInfo CumSumInferSpmdDynamic(const DistMetaTensor& x, + const Scalar& axis, + bool flatten, + bool exclusive, + bool reverse); + +SpmdInfo CumSumGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad, + const Scalar& axis, + bool flatten, + bool exclusive, + bool reverse); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.cc b/paddle/phi/infermeta/spmd_rules/elementwise.cc index 3db396de8b613..4e12c994b595b 100644 --- a/paddle/phi/infermeta/spmd_rules/elementwise.cc +++ b/paddle/phi/infermeta/spmd_rules/elementwise.cc @@ -31,7 +31,7 @@ std::string GetInputBroadcastNotation(const std::vector& shape, const int max_ndim, const std::string& alphabet, std::vector* broadcast_axis_count) { - int ndim = shape.size(); + int ndim = static_cast(shape.size()); int start_dim = max_ndim - ndim; std::string axes_notation = GetBroadcastAxes(ndim, max_ndim, alphabet); @@ -54,8 +54,8 @@ void GetBinaryNotations(const std::vector& x_shape, std::string* x_axes, std::string* y_axes, std::string* out_axes) { - int x_ndim = x_shape.size(); - int y_ndim = y_shape.size(); + int x_ndim = static_cast(x_shape.size()); + int y_ndim = static_cast(y_shape.size()); int max_ndim = std::max(x_ndim, y_ndim); int ninputs = 2; std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; @@ -82,7 +82,7 @@ void GetBinaryNotations(const std::vector& x_shape, SpmdInfo ElementwiseUnaryInferSpmd(const DistMetaTensor& x) { // Step0: Verify Input Args Based on Elementwise Logic auto x_shape = common::vectorize(x.dims()); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); TensorDistAttr x_dist_attr_src = x.dist_attr(); 
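The CumSum rule added above reduces to a single dims-mapping transform: the accumulated axis cannot stay sharded, and a flattened cumsum yields a 1-D replicated output. A minimal sketch of that transform follows, using only std types; the helper name and sample mappings are illustrative, not part of this change.

#include <cstdint>
#include <iostream>
#include <vector>

// Derive the output dims_mapping the same way the rule does.
std::vector<int64_t> CumSumOutDimsMapping(std::vector<int64_t> x_dims_mapping,
                                          int axis, bool flatten) {
  if (flatten) {
    return {-1};  // flattened cumsum: 1-D, replicated
  }
  x_dims_mapping[axis] = -1;  // the accumulated axis must be replicated
  return x_dims_mapping;
}

int main() {
  // x mapped as [0, -1, 1] on a 2-D mesh, cumsum over axis 2.
  for (int64_t d : CumSumOutDimsMapping({0, -1, 1}, /*axis=*/2, /*flatten=*/false)) {
    std::cout << d << ' ';  // prints: 0 -1 -1
  }
  std::cout << '\n';
  return 0;
}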
std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ(x_ndim, @@ -129,7 +129,7 @@ SpmdInfo ElementwiseUnaryInferSpmd(const DistMetaTensor& x) { SpmdInfo ElementwiseUnaryWithPartialInferSpmd(const DistMetaTensor& x) { // Step0: Verify Input Args Based on Elementwise Logic auto x_shape = common::vectorize(x.dims()); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); TensorDistAttr x_dist_attr_src = x.dist_attr(); std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ(x_ndim, @@ -177,9 +177,9 @@ SpmdInfo ElementwiseUnaryInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out) { // Step0: Verify Input Args Based on Elementwise Logic auto x_shape = common::vectorize(x.dims()); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); auto out_shape = common::vectorize(out.dims()); - int out_ndim = out_shape.size(); + int out_ndim = static_cast(out_shape.size()); TensorDistAttr out_dist_attr_src = out.dist_attr(); std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( @@ -233,9 +233,9 @@ SpmdInfo ElementwiseBinaryInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y) { // Step0: Verify Input Args Based on Elementwise Logic auto x_shape = common::vectorize(x.dims()); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); auto y_shape = common::vectorize(y.dims()); - int y_ndim = y_shape.size(); + int y_ndim = static_cast(y_shape.size()); TensorDistAttr x_dist_attr_src = x.dist_attr(); TensorDistAttr y_dist_attr_src = y.dist_attr(); std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); @@ -303,11 +303,11 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x, const DistMetaTensor& out) { // Step0: Verify Input Args Based on Elementwise Logic auto x_shape = common::vectorize(x.dims()); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); auto y_shape = common::vectorize(y.dims()); - int y_ndim = y_shape.size(); + int y_ndim = static_cast(y_shape.size()); auto out_shape = common::vectorize(out.dims()); - int out_ndim = out_shape.size(); + int out_ndim = static_cast(out_shape.size()); int max_ndim = std::max(x_ndim, y_ndim); TensorDistAttr out_dist_attr = out.dist_attr(); std::vector out_dims_mapping = out_dist_attr.dims_mapping(); @@ -365,14 +365,17 @@ SpmdInfo ElementwiseBinaryInferSpmdReverse(const DistMetaTensor& x, SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& out_grad) { - return {{out_grad.dist_attr(), out_grad.dist_attr()}, {out_grad.dist_attr()}}; + auto dist_attr = CopyTensorDistAttrForOutput(out_grad.dist_attr()); + dist_attr.set_dims_mapping(out_grad.dist_attr().dims_mapping()); + return {{dist_attr, dist_attr}, {dist_attr}}; } SpmdInfo ElementwiseUnaryGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& out, const DistMetaTensor& out_grad) { - return {{out_grad.dist_attr(), out_grad.dist_attr(), out_grad.dist_attr()}, - {out_grad.dist_attr()}}; + auto dist_attr = CopyTensorDistAttrForOutput(out_grad.dist_attr()); + dist_attr.set_dims_mapping(out_grad.dist_attr().dims_mapping()); + return {{dist_attr, dist_attr, dist_attr}, {dist_attr}}; } bool DimsNotEqualOrHasBroadcastDim(const DistMetaTensor& x, diff --git a/paddle/phi/infermeta/spmd_rules/elementwise.h b/paddle/phi/infermeta/spmd_rules/elementwise.h index a25de93679439..d93b8416f878a 100644 --- a/paddle/phi/infermeta/spmd_rules/elementwise.h +++ 
b/paddle/phi/infermeta/spmd_rules/elementwise.h @@ -54,5 +54,15 @@ SpmdInfo ElementwiseBinaryGradInferSpmd(const DistMetaTensor& x, const DistMetaTensor& out_grad, int64_t axis = -1); +SpmdInfo SwiGLUInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y); + +SpmdInfo SwiGLUInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out); + +SpmdInfo SwiGLUGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out_grad); + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/expand_as.cc b/paddle/phi/infermeta/spmd_rules/expand_as.cc new file mode 100644 index 0000000000000..6bd663c826664 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/expand_as.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/expand_as.h" + +#include "glog/logging.h" +#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +std::tuple AlignExpandAsDistAttrs( + const DistMetaTensor& x, const DistMetaTensor& y) { + EXTRACT_SHAPE_AND_DIST_ATTR(x); + EXTRACT_SHAPE_AND_DIST_ATTR(y); + auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + auto y_dist_attr_dst = CopyTensorDistAttrForOutput(y_dist_attr_src); + auto x_dims_mapping_dst = x_dims_mapping_src; + auto y_dims_mapping_dst = y_dims_mapping_src; + int dims_diff = y_ndim - x_ndim; + for (int i = 0; i < y_ndim; ++i) { + if (i >= dims_diff) { + if (x_shape[i - dims_diff] == y_shape[i]) { + x_dims_mapping_dst[i - dims_diff] = y_dims_mapping_src[i]; + } else { + x_dims_mapping_dst[i - dims_diff] = -1; + } + } + } + x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst); + y_dist_attr_dst.set_dims_mapping(y_dims_mapping_dst); + LOG_SPMD_INPUT(x); + LOG_SPMD_INPUT(y); + return {x_dist_attr_dst, y_dist_attr_dst}; +} + +SpmdInfo ExpandAsInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + const std::vector& target_shape) { + auto [x_dist_attr, y_dist_attr] = AlignExpandAsDistAttrs(x, y); + return {{x_dist_attr, y_dist_attr}, {y_dist_attr}}; +} + +SpmdInfo ExpandAsInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& output, + const std::vector& target_shape) { + auto [x_dist_attr, y_dist_attr] = AlignExpandAsDistAttrs(x, output); + return {{x_dist_attr, y_dist_attr}, {y_dist_attr}}; +} + +SpmdInfo ExpandAsGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad, + const std::vector& target_shape) { + auto [x_dist_attr, y_dist_attr] = AlignExpandAsDistAttrs(x, out_grad); + const auto& x_dims_mapping = x_dist_attr.dims_mapping(); + const auto& y_dims_mapping = y_dist_attr.dims_mapping(); + + // handle partial grad + auto x_grad_dist_attr = x_dist_attr; + int x_ndims = x_dims_mapping.size(); + int y_ndims = y_dims_mapping.size(); + int dims_diff = y_ndims - x_ndims; + std::vector 
partial; + for (int i = 0; i < y_ndims; ++i) { + if (i < dims_diff || x_dims_mapping[i - dims_diff] != y_dims_mapping[i]) { + if (y_dims_mapping[i] >= 0) { + partial.push_back(y_dims_mapping[i]); + } + } + } + x_grad_dist_attr.set_partial_status(partial); + return {{x_dist_attr, y_dist_attr}, {x_grad_dist_attr}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/expand_as.h b/paddle/phi/infermeta/spmd_rules/expand_as.h new file mode 100644 index 0000000000000..67cc6f3853dc1 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/expand_as.h @@ -0,0 +1,38 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { +SpmdInfo ExpandAsInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + const std::vector& target_shape); + +SpmdInfo ExpandAsInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& output, + const std::vector& target_shape); + +SpmdInfo ExpandAsGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& out_grad, + const std::vector& target_shape); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/flash_attention.cc b/paddle/phi/infermeta/spmd_rules/flash_attention.cc index edec1af106a39..737ad4eff03c9 100644 --- a/paddle/phi/infermeta/spmd_rules/flash_attention.cc +++ b/paddle/phi/infermeta/spmd_rules/flash_attention.cc @@ -21,6 +21,7 @@ limitations under the License. 
*/ namespace phi { namespace distributed { +const int kNumHeadsDimIndex = 2; #define LOG_SPMD_INPUT(name) \ do { \ @@ -109,10 +110,10 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, k_batch_size)); PADDLE_ENFORCE_EQ( - num_heads, - k_num_heads, + num_heads % k_num_heads == 0, + true, phi::errors::InvalidArgument( - "The Tensor q and k's num_heads [%d] vs [%d] are not matched.", + "The num_heads of q must be divisible by k's, but [%d] vs [%d].", num_heads, k_num_heads)); @@ -132,6 +133,14 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, k_ndim, k_dims_mapping_size)); + bool is_divisible = true; + int64_t num_head_mesh_dim = k_dist_attr.dims_mapping()[kNumHeadsDimIndex]; + if (num_head_mesh_dim != -1) { + int64_t num_head_split_size = + k_dist_attr.process_mesh().dim_size(num_head_mesh_dim); + is_divisible = k_num_heads % num_head_split_size == 0; + } + // v // [batch_size, seq_len_kv, num_heads, head_dim] auto v_shape = common::vectorize(v.dims()); @@ -157,13 +166,15 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, v_batch_size)); PADDLE_ENFORCE_EQ( - num_heads, - v_num_heads, + num_heads % v_num_heads == 0, + true, phi::errors::InvalidArgument( - "The Tensor q and v's num_heads [%d] vs [%d] are not matched.", + "The num_heads of q must be divisible by v's, but [%d] vs [%d].", num_heads, v_num_heads)); + bool is_same_num_heads = num_heads == v_num_heads; + PADDLE_ENFORCE_EQ( k_seq_len, v_seq_len, @@ -230,6 +241,12 @@ SpmdInfo FlashAttInferSpmd(const DistMetaTensor& q, auto k_dist_attr_dst = UnShardTensorDims(k_dist_attr, {1, 3}); auto v_dist_attr_dst = UnShardTensorDims(k_dist_attr, {1, 3}); + if (!is_same_num_heads && !is_divisible) { + q_dist_attr_dst = UnShardTensorDims(q_dist_attr, {2}); + k_dist_attr_dst = UnShardTensorDims(k_dist_attr, {2}); + v_dist_attr_dst = UnShardTensorDims(k_dist_attr, {2}); + } + std::vector>> axes_sharding_info; axes_sharding_info.emplace_back(q_axes, q_dist_attr_dst.dims_mapping()); @@ -454,6 +471,21 @@ SpmdInfo FlashAttInferSpmdReverse(const DistMetaTensor& q, auto softmax_lse_dist_attr_dst = UnShardTensorDims(softmax_lse_dist_attr, {2}); + bool is_same_num_heads = q_shape[2] == k_shape[2]; + bool is_divisible = true; + int64_t num_head_mesh_dim = k_dist_attr.dims_mapping()[kNumHeadsDimIndex]; + if (num_head_mesh_dim != -1) { + int64_t num_head_split_size = + k_dist_attr.process_mesh().dim_size(num_head_mesh_dim); + is_divisible = k_shape[2] % num_head_split_size == 0; + } + + if (!is_same_num_heads && !is_divisible) { + out_dist_attr_dst = UnShardTensorDims(out_dist_attr_dst, {2}); + softmax_lse_dist_attr_dst = + UnShardTensorDims(softmax_lse_dist_attr_dst, {1}); + } + std::vector>> axes_sharding_info; axes_sharding_info.emplace_back(out_axes, out_dist_attr_dst.dims_mapping()); @@ -566,10 +598,10 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, k_batch_size)); PADDLE_ENFORCE_EQ( - num_heads, - k_num_heads, + num_heads % k_num_heads == 0, + true, phi::errors::InvalidArgument( - "The Tensor q and k's num_heads [%d] vs [%d] are not matched.", + "The num_heads of q must be divisible by k's, but [%d] vs [%d].", num_heads, k_num_heads)); @@ -614,10 +646,10 @@ SpmdInfo FlashAttGradInferSpmd(const DistMetaTensor& q, v_batch_size)); PADDLE_ENFORCE_EQ( - num_heads, - v_num_heads, + num_heads % v_num_heads == 0, + true, phi::errors::InvalidArgument( - "The Tensor q and v's k_num_heads [%d] vs [%d] are not matched.", + "The num_head of q must be divisible by v's, but [%d] vs [%d].", num_heads, v_num_heads)); @@ -700,6 +732,24 @@ SpmdInfo 
FlashAttGradInferSpmd(const DistMetaTensor& q, auto softmax_lse_dist_attr_dst = UnShardTensorDims(softmax_lse_dist_attr, {2}); + bool is_same_num_heads = num_heads == v_num_heads; + bool is_divisible = true; + int64_t num_head_mesh_dim = k_dist_attr.dims_mapping()[kNumHeadsDimIndex]; + if (num_head_mesh_dim != -1) { + int64_t num_head_split_size = + k_dist_attr.process_mesh().dim_size(num_head_mesh_dim); + is_divisible = k_shape[2] % num_head_split_size == 0; + } + if (!is_same_num_heads && !is_divisible) { + q_dist_attr_dst = UnShardTensorDims(q_dist_attr_dst, {2}); + k_dist_attr_dst = UnShardTensorDims(k_dist_attr_dst, {2}); + v_dist_attr_dst = UnShardTensorDims(v_dist_attr_dst, {2}); + out_dist_attr_dst = UnShardTensorDims(out_dist_attr_dst, {2}); + out_grad_dist_attr_dst = UnShardTensorDims(out_grad_dist_attr_dst, {2}); + softmax_lse_dist_attr_dst = + UnShardTensorDims(softmax_lse_dist_attr_dst, {1}); + } + std::vector>> axes_sharding_info; axes_sharding_info.emplace_back(q_axes, q_dist_attr_dst.dims_mapping()); axes_sharding_info.emplace_back(k_axes, k_dist_attr_dst.dims_mapping()); diff --git a/paddle/phi/infermeta/spmd_rules/fused_rope.cc b/paddle/phi/infermeta/spmd_rules/fused_rope.cc index 138f0813be2c5..e58b987fb3499 100644 --- a/paddle/phi/infermeta/spmd_rules/fused_rope.cc +++ b/paddle/phi/infermeta/spmd_rules/fused_rope.cc @@ -68,20 +68,43 @@ void check_k_or_v(const DistMetaTensor& k_or_v, ndim, dims_mapping_size)); + int64_t k_num_head = shape[kNumHeadsDimIndex]; + int64_t q_num_head = q_shape[kNumHeadsDimIndex]; PADDLE_ENFORCE_EQ( - shape, - q_shape, - phi::errors::InvalidArgument( - "The shape of q and k/v's are not matched, [%d] vs [%d]", - str_join(q_shape), - str_join(shape))); + q_num_head % k_num_head == 0, + true, + phi::errors::InvalidArgument("The num_head of q must be divisible by k " + "and v, but got [%d] vs [%d]", + q_num_head, + k_num_head)); + + for (size_t i = 0; i <= kHeadDimIndex; ++i) { + if (i == kNumHeadsDimIndex) { + PADDLE_ENFORCE_EQ( + q_shape[i] % shape[i] == 0, + true, + phi::errors::InvalidArgument("The num_head of q must be divisible by " + "k and v, but got [%d] vs [%d]", + q_shape[i], + shape[i])); + } else { + PADDLE_ENFORCE_EQ(q_shape[i], + shape[i], + phi::errors::InvalidArgument( + "The shape except for num_head of q " + "must be same as k and v, but got [%d] vs [%d]", + str_join(q_shape), + str_join(shape))); + } + } } void check_sin_cos(const DistMetaTensor& sin, const DistMetaTensor& cos, const DistMetaTensor& position_ids, - const std::vector& q_shape, - bool time_major) { + const int64_t batch_size, + const int64_t seq_len, + const int64_t head_dim) { PADDLE_ENFORCE_EQ(sin.dims(), cos.dims(), phi::errors::InvalidArgument( @@ -98,13 +121,6 @@ void check_sin_cos(const DistMetaTensor& sin, phi::errors::InvalidArgument( "The Tensor sin/cos's ndim must be 2 or 4. but given [%d]", ndim)); - const int kBatchDimIndex = time_major ? 1 : 0; - const int kSeqlenDimIndex = time_major ? 0 : 1; - - int batch_size = q_shape[kBatchDimIndex]; - int seq_len = q_shape[kSeqlenDimIndex]; - int head_dim = q_shape[kHeadDimIndex]; - int seq_len_dim_index = ndim == 2 ? 0 : 1; int head_dim_index = ndim == 2 ? 1 : 3; if (ndim == 4) { @@ -143,9 +159,10 @@ void check_sin_cos(const DistMetaTensor& sin, phi::errors::InvalidArgument( "The batch_size and seq_len of position_ids must be the same as " "those of q. 
But received position_ids's " - "shape is {%s}, q's shape is {%s}.", + "shape is {%s}, q's batch_size is {%d}, q's seq_len is {%d}.", str_join(position_ids_shape), - str_join(q_shape))); + batch_size, + seq_len)); } else { PADDLE_ENFORCE_EQ( (shape[seq_len_dim_index] == seq_len && @@ -162,8 +179,10 @@ void check_sin_cos(const DistMetaTensor& sin, void infer_sin_cos(const DistMetaTensor& sin, const DistMetaTensor& cos, const DistMetaTensor& position_ids, + const TensorDistAttr& q_dist_attr_dst, const std::vector& q_shape, bool time_major, + bool enable_sequence_parallel, TensorDistAttr* sin_dist_attr_dst, TensorDistAttr* cos_dist_attr_dst) { const TensorDistAttr& sin_dist_attr_src = sin.dist_attr(); @@ -178,13 +197,39 @@ void infer_sin_cos(const DistMetaTensor& sin, // if one of sin cos is empty, they are all useless in kernel if (!IsEmpty(sin_shape) && !IsEmpty(cos_shape)) { // check sin, cos, position_ids's shape - check_sin_cos(sin, cos, position_ids, q_shape, time_major); - if (sin_shape.size() == 4) { - *sin_dist_attr_dst = UnShardTensorDims(sin_dist_attr_src, {1, 3}); - *cos_dist_attr_dst = UnShardTensorDims(cos_dist_attr_src, {1, 3}); - } else { - *sin_dist_attr_dst = UnShardTensorDims(sin_dist_attr_src, {0, 1}); - *cos_dist_attr_dst = UnShardTensorDims(cos_dist_attr_src, {0, 1}); + const int kBatchDimIndex = time_major ? 1 : 0; + const int kSeqlenDimIndex = time_major ? 0 : 1; + int batch_size = q_shape[kBatchDimIndex]; + int seq_len = q_shape[kSeqlenDimIndex]; + int head_dim = q_shape[kHeadDimIndex]; + + int seq_len_dim_index = sin_shape.size() == 4 ? 1 : 0; + int head_dim_index = sin_shape.size() == 4 ? 3 : 1; + + check_sin_cos(sin, cos, position_ids, batch_size, seq_len, head_dim); + + *sin_dist_attr_dst = + enable_sequence_parallel + ? UnShardTensorDims(sin_dist_attr_src, {head_dim_index}) + : UnShardTensorDims(sin_dist_attr_src, + {seq_len_dim_index, head_dim_index}); + *cos_dist_attr_dst = + enable_sequence_parallel + ? UnShardTensorDims(sin_dist_attr_src, {head_dim_index}) + : UnShardTensorDims(cos_dist_attr_src, + {seq_len_dim_index, head_dim_index}); + + if (enable_sequence_parallel) { + // shard on seq_len dimension + std::vector sin_dims_mapping = sin_dist_attr_dst->dims_mapping(); + sin_dims_mapping[seq_len_dim_index] = + q_dist_attr_dst.dims_mapping()[kSeqlenDimIndex]; + sin_dist_attr_dst->set_dims_mapping(sin_dims_mapping); + + std::vector cos_dims_mapping = cos_dist_attr_dst->dims_mapping(); + cos_dims_mapping[seq_len_dim_index] = + q_dist_attr_dst.dims_mapping()[kSeqlenDimIndex]; + cos_dist_attr_dst->set_dims_mapping(cos_dims_mapping); } } } @@ -209,11 +254,25 @@ SpmdInfo FusedRopeInferSpmd(const DistMetaTensor& q, // q_shape equals [bs, seq_len, num_heads, head_dim] if time_major is False, // otherwise [seq_len, bs, num_heads, head_dim] std::vector q_shape = common::vectorize(q.dims()); + std::vector k_shape = common::vectorize(k.dims()); + std::vector v_shape = common::vectorize(v.dims()); bool is_k_none = IsEmpty(common::vectorize(k.dims())); // except for q, all other inputs are optional. 
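The is_same_num_heads / is_divisible handling being added here (and in the FlashAttention rules earlier) amounts to one small decision: replicate the num_heads dimension when q and k/v head counts differ and the k/v heads would not split evenly across the sharded mesh dimension. A standalone sketch under those assumptions, with illustrative names only (mesh_dim_size stands in for process_mesh().dim_size(...)):

#include <cstdint>
#include <iostream>

// Returns true when the num_heads dimension has to be unsharded: q and k/v
// disagree on num_heads (grouped-query attention) and the k/v head count does
// not divide evenly across the mesh dimension it is sharded on.
// kv_heads_mesh_dim == -1 means the num_heads dim is not sharded at all.
bool MustUnshardNumHeads(int64_t q_num_heads, int64_t kv_num_heads,
                         int64_t kv_heads_mesh_dim, int64_t mesh_dim_size) {
  bool is_same_num_heads = (q_num_heads == kv_num_heads);
  bool is_divisible = true;
  if (kv_heads_mesh_dim != -1) {
    is_divisible = (kv_num_heads % mesh_dim_size == 0);
  }
  return !is_same_num_heads && !is_divisible;
}

int main() {
  // 32 query heads, 6 kv heads, kv heads sharded over a mesh dim of size 4.
  std::cout << std::boolalpha << MustUnshardNumHeads(32, 6, 0, 4) << '\n';  // true
  return 0;
}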
+ bool is_same_num_heads = true; + bool is_divisible = true; if (!is_k_none) { check_k_or_v(k, q_shape); inputs_sharding_info.emplace_back(qkv_axes, k_dist_attr_src.dims_mapping()); + is_same_num_heads = + q_shape[kNumHeadsDimIndex] == k_shape[kNumHeadsDimIndex]; + int64_t num_head_shape = k_shape[kNumHeadsDimIndex]; + int64_t num_head_mesh_dim = + k_dist_attr_src.dims_mapping()[kNumHeadsDimIndex]; + if (num_head_mesh_dim != -1) { + int64_t num_head_split_size = + k_dist_attr_src.process_mesh().dim_size(num_head_mesh_dim); + is_divisible = num_head_shape % num_head_split_size == 0; + } } const TensorDistAttr& v_dist_attr_src = v.dist_attr(); @@ -221,6 +280,26 @@ SpmdInfo FusedRopeInferSpmd(const DistMetaTensor& q, if (!is_v_none) { check_k_or_v(v, q_shape); inputs_sharding_info.emplace_back(qkv_axes, v_dist_attr_src.dims_mapping()); + is_same_num_heads = + q_shape[kNumHeadsDimIndex] == v_shape[kNumHeadsDimIndex]; + int64_t num_head_shape = v_shape[kNumHeadsDimIndex]; + int64_t num_head_mesh_dim = + v_dist_attr_src.dims_mapping()[kNumHeadsDimIndex]; + if (num_head_mesh_dim != -1) { + int64_t num_head_split_size = + v_dist_attr_src.process_mesh().dim_size(num_head_mesh_dim); + is_divisible = num_head_shape % num_head_split_size == 0; + } + } + + if (!is_k_none && !is_v_none) { + PADDLE_ENFORCE_EQ( + k_shape, + v_shape, + phi::errors::InvalidArgument("The shape of k and v must be same, " + "but [%d] vs [%d]", + str_join(k_shape), + str_join(v_shape))); } const TensorDistAttr& position_ids_dist_attr_src = position_ids.dist_attr(); @@ -237,9 +316,28 @@ SpmdInfo FusedRopeInferSpmd(const DistMetaTensor& q, GetDimsMappingForAxes(qkv_axes, axis_to_dim_map); TensorDistAttr q_dist_attr_dst = CopyTensorDistAttrForOutput(q_dist_attr_src); q_dist_attr_dst.set_dims_mapping(out_dims_mapping); + const int kSeqlenDimIndex = time_major ? 
0 : 1; - q_dist_attr_dst = - UnShardTensorDims(q_dist_attr_dst, {kSeqlenDimIndex, kHeadDimIndex}); + // if one of sin cos is empty, they are all useless in kernel + bool is_sin_cos_none = IsEmpty(common::vectorize(sin.dims())) || + IsEmpty(common::vectorize(cos.dims())); + + // Enable sharding on seq_len dimension only if sin/cos is not None and + // position_ids is None + bool enable_sequence_parallel = + !is_sin_cos_none && is_ids_none && + IsDimSharded(q_dist_attr_dst, kSeqlenDimIndex); + if (enable_sequence_parallel) { + // Sharded along seq_len dimension + q_dist_attr_dst = UnShardTensorDims(q_dist_attr_dst, {kHeadDimIndex}); + } else { + q_dist_attr_dst = + UnShardTensorDims(q_dist_attr_dst, {kSeqlenDimIndex, kHeadDimIndex}); + } + + if (!is_same_num_heads && !is_divisible) { + q_dist_attr_dst = UnShardTensorDims(q_dist_attr_dst, {kNumHeadsDimIndex}); + } TensorDistAttr k_dist_attr_dst = CopyTensorDistAttrForOutput(k_dist_attr_src); k_dist_attr_dst.set_process_mesh(q_dist_attr_dst.process_mesh()); @@ -258,8 +356,10 @@ SpmdInfo FusedRopeInferSpmd(const DistMetaTensor& q, infer_sin_cos(sin, cos, position_ids, + q_dist_attr_dst, q_shape, time_major, + enable_sequence_parallel, &sin_dist_attr_dst, &cos_dist_attr_dst); @@ -304,12 +404,28 @@ SpmdInfo FusedRopeInferSpmdReverse(const DistMetaTensor& q, const TensorDistAttr& out_k_dist_attr_src = out_k.dist_attr(); // out_q shape = [bs, seq_len, num_heads, head_dim] std::vector out_q_shape = common::vectorize(out_q.dims()); + std::vector out_k_shape = common::vectorize(out_k.dims()); + std::vector out_v_shape = common::vectorize(out_v.dims()); bool is_k_none = IsEmpty(common::vectorize(out_k.dims())); // except for q, all other inputs are optional. + bool is_same_num_heads = true; + bool is_divisible = true; + if (!is_k_none) { check_k_or_v(out_k, out_q_shape); outputs_sharding_info.emplace_back(qkv_axes, out_k_dist_attr_src.dims_mapping()); + is_same_num_heads = + out_q_shape[kHeadDimIndex] == out_k_shape[kHeadDimIndex]; + + int64_t num_head_shape = out_k_shape[kNumHeadsDimIndex]; + int64_t num_head_mesh_dim = + out_k_dist_attr_src.dims_mapping()[kNumHeadsDimIndex]; + if (num_head_mesh_dim != -1) { + int64_t num_head_split_size = + out_k_dist_attr_src.process_mesh().dim_size(num_head_mesh_dim); + is_divisible = num_head_shape % num_head_split_size == 0; + } } const TensorDistAttr& out_v_dist_attr_src = out_v.dist_attr(); @@ -318,6 +434,27 @@ SpmdInfo FusedRopeInferSpmdReverse(const DistMetaTensor& q, check_k_or_v(out_v, out_q_shape); outputs_sharding_info.emplace_back(qkv_axes, out_v_dist_attr_src.dims_mapping()); + is_same_num_heads = + out_q_shape[kHeadDimIndex] == out_v_shape[kHeadDimIndex]; + + int64_t num_head_shape = out_v_shape[kNumHeadsDimIndex]; + int64_t num_head_mesh_dim = + out_v_dist_attr_src.dims_mapping()[kNumHeadsDimIndex]; + if (num_head_mesh_dim != -1) { + int64_t num_head_split_size = + out_v_dist_attr_src.process_mesh().dim_size(num_head_mesh_dim); + is_divisible = num_head_shape % num_head_split_size == 0; + } + } + + if (!is_k_none && !is_v_none) { + PADDLE_ENFORCE_EQ( + out_k_shape, + out_v_shape, + phi::errors::InvalidArgument("The shape of k and v must be same, " + "but [%d] vs [%d]", + str_join(out_k_shape), + str_join(out_v_shape))); } std::unordered_map axis_to_dim_map = @@ -331,8 +468,28 @@ SpmdInfo FusedRopeInferSpmdReverse(const DistMetaTensor& q, q_dist_attr_dst.set_dims_mapping(dims_mapping); const int kSeqlenDimIndex = time_major ? 
0 : 1; - q_dist_attr_dst = - UnShardTensorDims(q_dist_attr_dst, {kSeqlenDimIndex, kHeadDimIndex}); + // if one of sin cos is empty, they are all useless in kernel + bool is_sin_cos_none = IsEmpty(common::vectorize(sin.dims())) || + IsEmpty(common::vectorize(cos.dims())); + bool is_ids_none = IsEmpty(common::vectorize(position_ids.dims())); + + // Enable sharding on seq_len dimension only if sin/cos is not None and + // position_ids is None + bool enable_sequence_parallel = + !is_sin_cos_none && is_ids_none && + IsDimSharded(q_dist_attr_dst, kSeqlenDimIndex); + if (enable_sequence_parallel) { + // Sharded along seq_len dimension + q_dist_attr_dst = UnShardTensorDims(q_dist_attr_dst, {kHeadDimIndex}); + } else { + q_dist_attr_dst = + UnShardTensorDims(q_dist_attr_dst, {kSeqlenDimIndex, kHeadDimIndex}); + } + + if (!is_same_num_heads && !is_divisible) { + q_dist_attr_dst = UnShardTensorDims(q_dist_attr_dst, {kNumHeadsDimIndex}); + } + TensorDistAttr out_q_dist_attr_dst = q_dist_attr_dst; TensorDistAttr k_dist_attr_dst = CopyTensorDistAttrForOutput(k.dist_attr()); @@ -356,8 +513,10 @@ SpmdInfo FusedRopeInferSpmdReverse(const DistMetaTensor& q, infer_sin_cos(sin, cos, position_ids, + out_q_dist_attr_dst, out_q_shape, time_major, + enable_sequence_parallel, &sin_dist_attr_dst, &cos_dist_attr_dst); @@ -367,7 +526,6 @@ SpmdInfo FusedRopeInferSpmdReverse(const DistMetaTensor& q, TensorDistAttr position_ids_dist_attr_dst = CopyTensorDistAttrForOutput(position_ids.dist_attr()); - bool is_ids_none = IsEmpty(common::vectorize(position_ids.dims())); if (!is_ids_none) { position_ids_dist_attr_dst.set_dims_mapping(position_ids_dims_mapping); position_ids_dist_attr_dst = diff --git a/paddle/phi/infermeta/spmd_rules/fused_rope.h b/paddle/phi/infermeta/spmd_rules/fused_rope.h index fdd9ae27500b0..3a5c331098ad1 100644 --- a/paddle/phi/infermeta/spmd_rules/fused_rope.h +++ b/paddle/phi/infermeta/spmd_rules/fused_rope.h @@ -29,8 +29,8 @@ SpmdInfo FusedRopeInferSpmd(const DistMetaTensor& q, const DistMetaTensor& sin, const DistMetaTensor& cos, const DistMetaTensor& position_ids, - bool use_neox_rotary_style, - bool time_major); + bool use_neox_rotary_style = true, + bool time_major = false); SpmdInfo FusedRopeInferSpmdReverse(const DistMetaTensor& q, const DistMetaTensor& k, @@ -41,8 +41,8 @@ SpmdInfo FusedRopeInferSpmdReverse(const DistMetaTensor& q, const DistMetaTensor& out_q, const DistMetaTensor& out_k, const DistMetaTensor& out_v, - bool use_neox_rotary_style, - bool time_major); + bool use_neox_rotary_style = true, + bool time_major = false); SpmdInfo FusedRopeGradInferSpmd(const DistMetaTensor& sin, const DistMetaTensor& cos, @@ -50,8 +50,8 @@ SpmdInfo FusedRopeGradInferSpmd(const DistMetaTensor& sin, const DistMetaTensor& out_q_grad, const DistMetaTensor& out_k_grad, const DistMetaTensor& out_v_grad, - bool use_neox_rotary_style, - bool time_major); + bool use_neox_rotary_style = true, + bool time_major = false); } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/gather.cc b/paddle/phi/infermeta/spmd_rules/gather.cc new file mode 100644 index 0000000000000..014c5f358dd73 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/gather.cc @@ -0,0 +1,219 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/gather.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +SpmdInfo GatherInferSpmdBase(const DistMetaTensor& x, + const DistMetaTensor& index, + int axis) { + // Step0: Verify Input Args Based on Gather Logic + // extract and check x_ndim, x_shape, x_dist_attr_src and + // x_dims_mapping_src with the macro + EXTRACT_SHAPE_AND_DIST_ATTR(x); + // index may be 0-d tensor, verify it specifically + auto index_shape = common::vectorize(index.dims()); + int index_ndim = index_shape.size(); + TensorDistAttr index_dist_attr_src = index.dist_attr(); + std::vector index_dims_mapping_src = + index_dist_attr_src.dims_mapping(); + if (index_ndim == 0) { + PADDLE_ENFORCE_EQ(index_dims_mapping_src.size(), + 1, + phi::errors::InvalidArgument( + "index is 0-d tensor, it's dims_mapping size " + "must be 1, but received [%d]", + index_dims_mapping_src.size())); + } else { + PADDLE_ENFORCE_EQ( + index_ndim, + index_dims_mapping_src.size(), + phi::errors::InvalidArgument("Tensor index's rank [%d] and " + "dims_mapping size [%d] are not matched.", + index_ndim, + index_dims_mapping_src.size())); + } + + // Step1: Build Einsum Notation + std::string alphabet = "abcdefghijlmnopqrstuvwxyz"; + std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet); + std::string index_axes = "k"; + std::string out_axes = x_axes; + if (index_ndim == 0) { + if (axis < x_ndim) { + out_axes.erase(axis, 1); + } + index_axes = ""; + } else { + out_axes[axis] = 'k'; + } + + // Step2: Sharding Propogation + // Step2.1: Merge input shardings + std::vector x_dims_mapping(x_dims_mapping_src); + if (axis < x_ndim) { + x_dims_mapping[axis] = -1; + } + std::vector index_dims_mapping(index_dims_mapping_src); + if (index_ndim == 0) { + index_dims_mapping[0] = -1; + } + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors( + {{x_axes, x_dims_mapping}, {index_axes, index_dims_mapping}}); + + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping); + + TensorDistAttr index_dist_attr_dst = + CopyTensorDistAttrForOutput(index_dist_attr_src); + index_dist_attr_dst.set_dims_mapping(index_dims_mapping); + + // Step2.2: Infer output dims mapping + std::vector out_dims_mapping = + GetDimsMappingForAxes(out_axes, axis_to_dim_map); + TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src); + out_dist_attr.set_dims_mapping(out_dims_mapping); + + VLOG(4) << "x_axes: " << x_axes << " index_axes: " << index_axes + << " out_axes: " << out_axes; + LOG_SPMD_INPUT(x); + LOG_SPMD_INPUT(index); + VLOG(4) << "out"; + VLOG(4) << "dist_attr: [" << out_dist_attr.to_string() << "]"; + return {{x_dist_attr_dst, 
index_dist_attr_dst}, {out_dist_attr}}; +} + +SpmdInfo GatherInferSpmdReverseBase(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& out, + int axis) { + // Step0: Verify Input Args Based on Gather Logic + // extract and check out_ndim, out_shape, out_dist_attr_src and + // out_dims_mapping_src with the macro + EXTRACT_SHAPE_AND_DIST_ATTR(x); + EXTRACT_SHAPE_AND_DIST_ATTR(index); + EXTRACT_SHAPE_AND_DIST_ATTR(out); + + // Step1: Build Einsum Notation + std::string alphabet = "abcdefghijlmnopqrstuvwxyz"; + // x should be replicated on 0th axis + std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet); + std::string index_axes = "k"; + std::string out_axes = x_axes; + if (index_ndim == 0) { + index_axes = ""; + if (axis < x_ndim) { + out_axes.erase(axis, 1); + } + } else { + out_axes[axis] = 'k'; + } + + // Step2: Sharding Propogation + // Step2.1: Merge output shardings + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{out_axes, out_dims_mapping_src}}); + + // Step2.2: Infer input dims mapping + std::vector x_dims_mapping = + GetDimsMappingForAxes(x_axes, axis_to_dim_map, true); + if (axis < x_ndim) { + x_dims_mapping[axis] = -1; + } + std::vector index_dims_mapping = + GetDimsMappingForAxes(index_axes, axis_to_dim_map, true); + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping); + TensorDistAttr index_dist_attr_dst = + CopyTensorDistAttrForOutput(index_dist_attr_src); + index_dist_attr_dst.set_dims_mapping(index_dims_mapping); + + VLOG(4) << "out_axes: " << out_axes << " x_axes: " << x_axes + << " index_axes: " << index_axes; + VLOG(4) << "out dist_attr: [" << out_dist_attr_src.to_string() << "]"; + LOG_SPMD_INPUT(x); + LOG_SPMD_INPUT(index); + VLOG(4) << std::endl; + return {{x_dist_attr_dst, index_dist_attr_dst}, {out_dist_attr_src}}; +} + +SpmdInfo GatherInferSpmdDynamic(const DistMetaTensor& x, + const DistMetaTensor& index, + const Scalar& axis) { + return GatherInferSpmdBase(x, index, axis.to()); +} + +SpmdInfo GatherInferSpmdReverseDynamic(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& out, + const Scalar& axis) { + return GatherInferSpmdReverseBase(x, index, out, axis.to()); +} + +SpmdInfo GatherGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& out_grad, + const Scalar& axis) { + EXTRACT_SHAPE_AND_DIST_ATTR(x); + EXTRACT_SHAPE_AND_DIST_ATTR(out_grad); + auto index_shape = common::vectorize(index.dims()); + int index_ndim = index_shape.size(); + TensorDistAttr index_dist_attr_src = index.dist_attr(); + std::vector index_dims_mapping_src = + index_dist_attr_src.dims_mapping(); + int axis_ = axis.to(); + + // TODO(zhangyichen): support shard on index and out_grad[axis] + std::vector out_grad_dims_mapping_dst(out_grad_dims_mapping_src); + TensorDistAttr out_grad_dist_attr_dst(out_grad_dist_attr_src); + if (index_ndim == 0) { + out_grad_dims_mapping_dst.insert(out_grad_dims_mapping_dst.begin() + axis_, + -1); + } else { + out_grad_dims_mapping_dst[axis_] = -1; + out_grad_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping_dst); + } + + std::vector index_dims_mapping_dst(index_dims_mapping_src); + TensorDistAttr index_dist_attr_dst(index_dims_mapping_src); + index_dims_mapping_dst[axis_] = -1; + index_dist_attr_dst.set_dims_mapping(index_dims_mapping_dst); + + std::vector x_grad_dims_mapping(x_dims_mapping_src); + for (int i = 0; i < x_ndim; ++i) { + x_grad_dims_mapping[i] = 
out_grad_dims_mapping_dst[i]; + } + + TensorDistAttr x_grad_dist_attr(x_dist_attr_src); + x_grad_dist_attr.set_dims_mapping(x_grad_dims_mapping); + + return {{x_dist_attr_src, index_dist_attr_dst, out_grad_dist_attr_dst}, + {x_grad_dist_attr}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/gather.h b/paddle/phi/infermeta/spmd_rules/gather.h new file mode 100644 index 0000000000000..7dd829094ca57 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/gather.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { +SpmdInfo GatherInferSpmdBase(const DistMetaTensor& x, + const DistMetaTensor& index, + int axis); + +SpmdInfo GatherInferSpmdReverseBase(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& out, + int axis); + +SpmdInfo GatherInferSpmdDynamic(const DistMetaTensor& x, + const DistMetaTensor& index, + const Scalar& axis); + +SpmdInfo GatherInferSpmdReverseDynamic(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& out, + const Scalar& axis); + +SpmdInfo GatherGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& out_grad, + const Scalar& axis); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc index 35c2e56af3de0..6ea65d106bc71 100644 --- a/paddle/phi/infermeta/spmd_rules/layer_norm.cc +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -26,6 +26,26 @@ namespace distributed { using phi::distributed::auto_parallel::str_join; +void LogInputDistAttr(const std::string& name, + const std::vector& shape, + const TensorDistAttr& src_dist_attr, + const TensorDistAttr& dst_dist_attr) { + VLOG(4) << name << " shape: [" << str_join(shape) << "] " + << "src_dims_mapping: [" << str_join(src_dist_attr.dims_mapping()) + << "] " + << "dst_dims_mapping: [" << str_join(dst_dist_attr.dims_mapping()) + << "] " + << "src_partial: " << src_dist_attr.partial_status_string() + << " dst_partial: " << dst_dist_attr.partial_status_string(); +} + +void LogOutputDistAttr(const std::string& name, + const TensorDistAttr& dst_dist_attr) { + VLOG(4) << name << " dims mapping: [" + << str_join(dst_dist_attr.dims_mapping()) << "] " + << "partial: " << dst_dist_attr.partial_status_string(); +} + SpmdInfo LayerNormInferSpmd(const DistMetaTensor& x, const DistMetaTensor& scale, const DistMetaTensor& bias, @@ -347,12 +367,16 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x, TensorDistAttr x_dist_attr; TensorDistAttr mean_dist_attr; TensorDistAttr variance_dist_attr; - TensorDistAttr grad_dist_attr; + TensorDistAttr out_grad_dist_attr; + std::vector dist_attrs; 
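For the Gather forward rule above, when index is 1-D the output dims_mapping inherits x's mapping everywhere except along axis, which is forced to -1 on x and then takes the index's mapping instead. A compact sketch with std types and an illustrative helper name:

#include <cstdint>
#include <iostream>
#include <vector>

// Output mapping = x's mapping with position `axis` replaced by the (1-D)
// index's mapping; x itself is unsharded along `axis` before the merge.
std::vector<int64_t> GatherOutDimsMapping(std::vector<int64_t> x_dims_mapping,
                                          int64_t index_dim_mapping, int axis) {
  x_dims_mapping[axis] = index_dim_mapping;
  return x_dims_mapping;
}

int main() {
  // x mapped as [0, -1, 1], index replicated (-1), gather along axis 2.
  for (int64_t d : GatherOutDimsMapping({0, -1, 1}, -1, /*axis=*/2)) {
    std::cout << d << ' ';  // prints: 0 -1 -1
  }
  std::cout << '\n';
  return 0;
}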
dist_attrs.push_back(x.dist_attr()); dist_attrs.push_back(mean.dist_attr()); dist_attrs.push_back(variance.dist_attr()); - dist_attrs.push_back(out_grad.dist_attr()); + out_grad_dist_attr = out_grad.dist_attr(); + out_grad_dist_attr.clean_partial_status(); + dist_attrs.push_back(out_grad_dist_attr); + if (begin_norm_axis > 0) { std::vector> shapes = { x_shape, mean_shape, variance_shape, x_shape}; @@ -365,16 +389,17 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x, x_dist_attr = std::move(dist_attrs[0]); mean_dist_attr = std::move(dist_attrs[1]); variance_dist_attr = std::move(dist_attrs[2]); - grad_dist_attr = std::move(dist_attrs[3]); + out_grad_dist_attr = std::move(dist_attrs[3]); } else { x_dist_attr = GetReplicatedDistAttr(dist_attrs[0]); mean_dist_attr = GetReplicatedDistAttr(dist_attrs[1]); variance_dist_attr = GetReplicatedDistAttr(dist_attrs[2]); - grad_dist_attr = GetReplicatedDistAttr(dist_attrs[3]); + out_grad_dist_attr = GetReplicatedDistAttr(dist_attrs[3]); } // TODO(liuzhenhai): support sharded scale and bias TensorDistAttr scale_dist_attr = GetReplicatedDistAttr(scale.dist_attr()); TensorDistAttr bias_dist_attr = GetReplicatedDistAttr(bias.dist_attr()); + TensorDistAttr x_grad_dist_attr = out_grad_dist_attr; TensorDistAttr scale_grad_dist_attr = GetReplicatedDistAttr(scale.dist_attr()); TensorDistAttr bias_grad_dist_attr = GetReplicatedDistAttr(bias.dist_attr()); @@ -390,13 +415,29 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x, scale_grad_dist_attr.set_partial_status(partial_on_dims); bias_grad_dist_attr.set_partial_status(partial_on_dims); - return SpmdInfo({x_dist_attr, - scale_dist_attr, - bias_dist_attr, - mean_dist_attr, - variance_dist_attr, - grad_dist_attr}, - {grad_dist_attr, scale_grad_dist_attr, bias_grad_dist_attr}); + VLOG(4) << "LayerNormGradInferSpmd:"; + VLOG(4) << "begin_norm_axis: " << begin_norm_axis; + LogInputDistAttr("X", x_shape, x.dist_attr(), x_dist_attr); + LogInputDistAttr("Scale", scale_shape, scale.dist_attr(), scale_dist_attr); + LogInputDistAttr("Bias", bias_shape, bias.dist_attr(), bias_dist_attr); + LogInputDistAttr("Mean", mean_shape, mean.dist_attr(), mean_dist_attr); + LogInputDistAttr( + "Variance", variance_shape, variance.dist_attr(), variance_dist_attr); + LogInputDistAttr( + "OutGrad", out_grad_shape, out_grad.dist_attr(), out_grad_dist_attr); + LogOutputDistAttr("XGrad", x_grad_dist_attr); + LogOutputDistAttr("ScaleGrad", scale_grad_dist_attr); + LogOutputDistAttr("BiasGrad", bias_grad_dist_attr); + VLOG(4) << std::endl; + + return SpmdInfo( + {x_dist_attr, + scale_dist_attr, + bias_dist_attr, + mean_dist_attr, + variance_dist_attr, + out_grad_dist_attr}, + {x_grad_dist_attr, scale_grad_dist_attr, bias_grad_dist_attr}); } } // namespace distributed diff --git a/paddle/phi/infermeta/spmd_rules/one_hot.cc b/paddle/phi/infermeta/spmd_rules/one_hot.cc new file mode 100644 index 0000000000000..dc90684dde1ef --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/one_hot.cc @@ -0,0 +1,94 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/one_hot.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +SpmdInfo OneHotInferSpmd(const DistMetaTensor& x, int num_classes) { + // Step0: Verify input args based on split logic + auto x_shape = common::vectorize(x.dims()); + int x_ndim = static_cast(x_shape.size()); + auto x_dist_attr_src = x.dist_attr(); + std::vector x_dims_mapping_src = x_dist_attr_src.dims_mapping(); + PADDLE_ENFORCE_EQ( + x_ndim, + x_dims_mapping_src.size(), + phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's " + "dims_mapping size [%d] are not matched.", + x_ndim, + x_dims_mapping_src.size())); + + std::vector out_dims_mapping(x_dims_mapping_src); + out_dims_mapping.emplace_back(-1); + TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src); + out_dist_attr.set_dims_mapping(out_dims_mapping); + + // Step3 Handle input tensor partial (TODO) + VLOG(4) << "OneHotInferSpmd:"; + VLOG(4) << "x shape: [" << str_join(x_shape) << "] " + << "src_dims_mapping: [" << str_join(x_dims_mapping_src) << "] " + << "dst_dims_mapping: [" << str_join(x_dims_mapping_src) << "]"; + VLOG(4) << "Out dims_mapping: [" << str_join(out_dims_mapping) << "]"; + VLOG(4) << std::endl; + return {{x_dist_attr_src}, {out_dist_attr}}; +} + +SpmdInfo OneHotInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + int num_classes) { + // Step0: Verify input args based on split logic + EXTRACT_SHAPE_AND_DIST_ATTR(x); + EXTRACT_SHAPE_AND_DIST_ATTR(out); + + std::vector out_dims_mapping_dst(out_dims_mapping_src); + out_dims_mapping_dst[out_ndim - 1] = -1; + TensorDistAttr out_dist_attr_dst = + CopyTensorDistAttrForOutput(out_dist_attr_src); + out_dist_attr_dst.set_dims_mapping(out_dims_mapping_dst); + + std::vector x_dims_mapping_dst(out_dims_mapping_dst.begin(), + out_dims_mapping_dst.end() - 1); + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst); + + VLOG(4) << "OneHotInferSpmdReverse:"; + VLOG(4) << "out shape: [" << str_join(out_shape) << "] " + << "src_dims_mapping: [" << str_join(out_dims_mapping_src) << "] " + << "dst_dims_mapping: [" << str_join(out_dims_mapping_dst) << "]"; + VLOG(4) << "x shape: [" << str_join(x_shape) << "] " + << "src_dims_mapping: [" << str_join(x_dims_mapping_src) << "] " + << "dst_dims_mapping: [" << str_join(x_dims_mapping_dst) << "]"; + VLOG(4) << std::endl; + return {{x_dist_attr_dst}, {out_dist_attr_dst}}; +} + +SpmdInfo OneHotInferSpmdDynamic(const DistMetaTensor& x, + const Scalar& num_classes) { + return OneHotInferSpmd(x, num_classes.to()); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/one_hot.h b/paddle/phi/infermeta/spmd_rules/one_hot.h new file mode 100644 index 0000000000000..66b900a2881d9 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/one_hot.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo OneHotInferSpmd(const DistMetaTensor& x, int num_classes); + +SpmdInfo OneHotInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& out, + int num_classes); + +SpmdInfo OneHotInferSpmdDynamic(const DistMetaTensor& x, + const Scalar& num_classes); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/reduction.cc b/paddle/phi/infermeta/spmd_rules/reduction.cc index 608794d348541..96e9230fb9182 100644 --- a/paddle/phi/infermeta/spmd_rules/reduction.cc +++ b/paddle/phi/infermeta/spmd_rules/reduction.cc @@ -71,7 +71,7 @@ SpmdInfo ReductionInferSpmdBase(const DistMetaTensor& x, int reduce_type) { // Step0: Verify input args based on reduction logic auto x_shape = common::vectorize(x.dims()); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); auto x_dist_attr_src = x.dist_attr(); std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( @@ -175,8 +175,8 @@ SpmdInfo ReductionInferSpmdReverse(const DistMetaTensor& x, // Step0: Verify input args based on reduction logic auto x_shape = common::vectorize(x.dims()); auto out_shape = common::vectorize(out.dims()); - int x_ndim = x_shape.size(); - int out_ndim = out_shape.size(); + int x_ndim = static_cast(x_shape.size()); + int out_ndim = static_cast(out_shape.size()); auto out_dist_attr_src = out.dist_attr(); std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( @@ -238,9 +238,9 @@ SpmdInfo ReductionGradInferSpmd(const DistMetaTensor& x, auto dims_mapping = x_dist_attr.dims_mapping(); auto axis_value = axis.GetData(); - for (size_t i = 0; i < axis_value.size(); ++i) { - if (axis_value[i] < 0) { - axis_value[i] += x_dim.size(); + for (auto& i : axis_value) { + if (i < 0) { + i += x_dim.size(); } } std::sort(axis_value.begin(), axis_value.end()); diff --git a/paddle/phi/infermeta/spmd_rules/replicated.cc b/paddle/phi/infermeta/spmd_rules/replicated.cc index 8d9c6d0d5be6c..390117862e04e 100644 --- a/paddle/phi/infermeta/spmd_rules/replicated.cc +++ b/paddle/phi/infermeta/spmd_rules/replicated.cc @@ -35,8 +35,8 @@ std::vector GetReplicatedDimsMapping(const int ndim) { SpmdInfo ReplicatedInferSpmd(const std::vector& ins, const std::vector& outs) { // step1: Build Einsum Notation for input tensor's batch axis - int64_t ninputs = ins.size(); - int64_t noutputs = outs.size(); + int64_t ninputs = static_cast(ins.size()); + int64_t noutputs = static_cast(outs.size()); // Step2: Unshard Output's Dims Mapping. 
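// [Editorial aside — not part of the diff] A minimal sketch of the one_hot
// rule added earlier in this diff (one_hot.cc / one_hot.h). The shape, mesh
// dim, and helper name OneHotSpmdSketch are illustrative assumptions; per the
// implementation above, the forward rule keeps x's dims_mapping and appends a
// replicated (-1) entry for the new num_classes axis.
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/infermeta/spmd_rules/one_hot.h"

namespace {

void OneHotSpmdSketch() {
  using phi::distributed::DistMetaTensor;
  using phi::distributed::TensorDistAttr;

  // x: shape [64] with dims_mapping {0}, i.e. sharded along mesh dim 0.
  TensorDistAttr x_attr;
  x_attr.set_dims_mapping({0});
  DistMetaTensor x(common::make_ddim({64}), x_attr);

  // Forward inference for one_hot(x, num_classes=10): the class axis is
  // appended as replicated, so the expected output dims_mapping is {0, -1}.
  auto fwd = phi::distributed::OneHotInferSpmd(x, /*num_classes=*/10);
  auto out_attr = PADDLE_GET_CONST(TensorDistAttr, fwd.second[0]);
  (void)out_attr;  // out_attr.dims_mapping() is expected to equal {0, -1}.
}

}  // namespace
// [End editorial aside]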
std::vector output_dist_attrs; @@ -94,8 +94,8 @@ SpmdInfo ReplicatedInferSpmdReverse( const std::vector& ins, const std::vector& outs) { // step1: Build Einsum Notation for input tensor's batch axis - int64_t ninputs = ins.size(); - int64_t noutputs = outs.size(); + int64_t ninputs = static_cast(ins.size()); + int64_t noutputs = static_cast(outs.size()); // Step2: Unshard Output's Dims Mapping. std::vector output_dist_attrs; @@ -145,7 +145,7 @@ SpmdInfo ReplicatedInferDynamic( const std::vector*>>& inputs) { std::vector nonnull_inputs; - int64_t ninputs = inputs.size(); + int64_t ninputs = static_cast(inputs.size()); SpmdInfo spmd_info; auto build_tensor_dist_attr = diff --git a/paddle/phi/infermeta/spmd_rules/reshape.cc b/paddle/phi/infermeta/spmd_rules/reshape.cc index 2e8d79e14bf49..9ca886f0dc637 100644 --- a/paddle/phi/infermeta/spmd_rules/reshape.cc +++ b/paddle/phi/infermeta/spmd_rules/reshape.cc @@ -122,8 +122,7 @@ std::vector> MakeReshapeDimTrans( if (!tgt_splitted_shape.empty()) { std::vector> input_dims; - for (int i = 0, n = static_cast(src_dims.size()); i < n; i++) { - int64_t in_dim = src_dims[i]; + for (auto in_dim : src_dims) { if (src_shape[in_dim] > 1) { input_dims.emplace_back(std::make_shared(in_dim)); } diff --git a/paddle/phi/infermeta/spmd_rules/rules.cc b/paddle/phi/infermeta/spmd_rules/rules.cc index 0921763df1229..9c6492ee75913 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.cc +++ b/paddle/phi/infermeta/spmd_rules/rules.cc @@ -435,12 +435,13 @@ PD_REGISTER_SPMD_RULE( logical_xor, PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); - PD_REGISTER_SPMD_RULE( not_equal, PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmd), PD_INFER_SPMD(phi::distributed::ElementwiseBinaryInferSpmdReverse)); - +PD_REGISTER_SPMD_RULE(swiglu, + PD_INFER_SPMD(phi::distributed::SwiGLUInferSpmd), + PD_INFER_SPMD(phi::distributed::SwiGLUInferSpmdReverse)); // TODO(pkuzyc): add multiary elementwise rule // reduction rule @@ -605,5 +606,46 @@ PD_REGISTER_SPMD_RULE( PD_INFER_SPMD( phi::distributed::FusedLinearParamGradAddInferSpmdFakeReverse)); +PD_REGISTER_SPMD_RULE( + expand_as, + PD_INFER_SPMD(phi::distributed::ExpandAsInferSpmd), + PD_INFER_SPMD(phi::distributed::ExpandAsInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE( + expand_as_v2, + PD_INFER_SPMD(phi::distributed::ExpandAsInferSpmd), + PD_INFER_SPMD(phi::distributed::ExpandAsInferSpmdReverse)); + +// scatter +PD_REGISTER_SPMD_RULE(scatter, + PD_INFER_SPMD(phi::distributed::ScatterInferSpmd), + PD_INFER_SPMD(phi::distributed::ScatterInferSpmdReverse)); + +// gather +PD_REGISTER_SPMD_RULE( + gather, + PD_INFER_SPMD(phi::distributed::GatherInferSpmdBase), + PD_INFER_SPMD(phi::distributed::GatherInferSpmdReverseBase)); + +// one_hot +PD_REGISTER_SPMD_RULE(one_hot, + PD_INFER_SPMD(phi::distributed::OneHotInferSpmd), + PD_INFER_SPMD(phi::distributed::OneHotInferSpmdReverse)); + +PD_REGISTER_SPMD_RULE(cumsum, + PD_INFER_SPMD(phi::distributed::CumSumInferSpmd), + PD_INFER_SPMD(phi::distributed::CumSumInferSpmdReverse)); + +// argmax +PD_REGISTER_SPMD_RULE( + argmax, + PD_INFER_SPMD(phi::distributed::ArgMaxInferSpmdBase), + PD_INFER_SPMD(phi::distributed::ArgMaxInferSpmdReverseBase)); + +// unbind +PD_REGISTER_SPMD_RULE(unbind, + PD_INFER_SPMD(phi::distributed::UnbindInferSpmd), + PD_INFER_SPMD(phi::distributed::UnbindInferSpmdReverse)); + } // namespace distributed } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/rules.h 
b/paddle/phi/infermeta/spmd_rules/rules.h index 03446ca5d2789..01ec6687a463d 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -14,20 +14,25 @@ limitations under the License. */ #pragma once #include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/infermeta/spmd_rules/argmax.h" #include "paddle/phi/infermeta/spmd_rules/cast.h" #include "paddle/phi/infermeta/spmd_rules/concat.h" #include "paddle/phi/infermeta/spmd_rules/cross_entropy_with_softmax.h" +#include "paddle/phi/infermeta/spmd_rules/cumsum.h" #include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h" #include "paddle/phi/infermeta/spmd_rules/elementwise.h" #include "paddle/phi/infermeta/spmd_rules/embedding.h" +#include "paddle/phi/infermeta/spmd_rules/expand_as.h" #include "paddle/phi/infermeta/spmd_rules/flash_attention.h" #include "paddle/phi/infermeta/spmd_rules/flatten.h" #include "paddle/phi/infermeta/spmd_rules/full_like.h" #include "paddle/phi/infermeta/spmd_rules/fused_linear_param_grad_add.h" #include "paddle/phi/infermeta/spmd_rules/fused_rope.h" +#include "paddle/phi/infermeta/spmd_rules/gather.h" #include "paddle/phi/infermeta/spmd_rules/layer_norm.h" #include "paddle/phi/infermeta/spmd_rules/matmul.h" #include "paddle/phi/infermeta/spmd_rules/numel.h" +#include "paddle/phi/infermeta/spmd_rules/one_hot.h" #include "paddle/phi/infermeta/spmd_rules/optimizer.h" #include "paddle/phi/infermeta/spmd_rules/pow.h" #include "paddle/phi/infermeta/spmd_rules/reduction.h" @@ -35,6 +40,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/reshape.h" #include "paddle/phi/infermeta/spmd_rules/rms_norm.h" #include "paddle/phi/infermeta/spmd_rules/scale.h" +#include "paddle/phi/infermeta/spmd_rules/scatter.h" #include "paddle/phi/infermeta/spmd_rules/slice.h" #include "paddle/phi/infermeta/spmd_rules/softmax.h" #include "paddle/phi/infermeta/spmd_rules/split.h" @@ -43,5 +49,6 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/tile.h" #include "paddle/phi/infermeta/spmd_rules/transpose.h" #include "paddle/phi/infermeta/spmd_rules/triu.h" +#include "paddle/phi/infermeta/spmd_rules/unbind.h" #include "paddle/phi/infermeta/spmd_rules/unsqueeze.h" #include "paddle/phi/infermeta/spmd_rules/where.h" diff --git a/paddle/phi/infermeta/spmd_rules/scale.cc b/paddle/phi/infermeta/spmd_rules/scale.cc index b6e8aaef754b7..040e7979ddcfa 100644 --- a/paddle/phi/infermeta/spmd_rules/scale.cc +++ b/paddle/phi/infermeta/spmd_rules/scale.cc @@ -16,7 +16,7 @@ namespace phi { namespace distributed { SpmdInfo ScaleInferSpmd(const DistMetaTensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale) { return ElementwiseUnaryInferSpmd(x); } diff --git a/paddle/phi/infermeta/spmd_rules/scale.h b/paddle/phi/infermeta/spmd_rules/scale.h index c020337ec3710..8e4e20a4c435b 100644 --- a/paddle/phi/infermeta/spmd_rules/scale.h +++ b/paddle/phi/infermeta/spmd_rules/scale.h @@ -24,7 +24,7 @@ namespace phi { namespace distributed { SpmdInfo ScaleInferSpmd(const DistMetaTensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale); } } // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/scatter.cc b/paddle/phi/infermeta/spmd_rules/scatter.cc new file mode 100644 index 0000000000000..6a31318045e16 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/scatter.cc @@ -0,0 +1,207 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/infermeta/spmd_rules/scatter.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/gather.h" +#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +////////////////// Utils Functions ////////////////// + +SpmdInfo ScatterInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& updates, + bool overwrite) { + // Step0: Verify Input Args Based on Scatter Logic + // extract and check x_ndim, x_shape, x_dist_attr_src and + // x_dims_mapping_src with the macro + EXTRACT_SHAPE_AND_DIST_ATTR(x); + EXTRACT_SHAPE_AND_DIST_ATTR(index); + EXTRACT_SHAPE_AND_DIST_ATTR(updates); + PADDLE_ENFORCE_LE( + index_ndim, + updates_ndim, + phi::errors::InvalidArgument( + "%s (%d): The Index's rank [%d] should be less than or equal " + "to Updates' rank [%d].", + __FILE__, + __LINE__, + index_ndim, + updates_ndim)); + + // Step1: Build Einsum Notation + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + // x should be replicated on 0th axis + std::string index_axes = GetBroadcastAxes(index_ndim, index_ndim, alphabet); + std::string updates_axes = + GetBroadcastAxes(updates_ndim, updates_ndim, alphabet); + std::string out_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet); + out_axes[0] = '1'; + + // Step2: Sharding Propagation + // Step2.1: Merge input shardings + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{index_axes, index_dims_mapping_src}, + {updates_axes, updates_dims_mapping_src}}); + + std::vector index_dims_mapping = + GetDimsMappingForAxes(index_axes, axis_to_dim_map); + TensorDistAttr index_dist_attr_dst = + CopyTensorDistAttrForOutput(index_dist_attr_src); + index_dist_attr_dst.set_dims_mapping(index_dims_mapping); + + std::vector updates_dims_mapping = + GetDimsMappingForAxes(updates_axes, axis_to_dim_map); + TensorDistAttr updates_dist_attr_dst = + CopyTensorDistAttrForOutput(updates_dist_attr_src); + updates_dist_attr_dst.set_dims_mapping(updates_dims_mapping); + + // Step2.2: Infer output dims mapping + std::vector out_dims_mapping = + GetDimsMappingForAxes(out_axes, axis_to_dim_map); + // the batch axis of output must be replicated + out_dims_mapping[0] = -1; + TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src); + out_dist_attr.set_dims_mapping(out_dims_mapping); + + // the dims mapping of x should be the same as output + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(out_dims_mapping); + + // Step3: Handle partial + // output partial status + // output is partial if the batch axis of index and updates are sharded + if 
(updates_dims_mapping[0] != -1) { + std::vector partial_dims(1, updates_dims_mapping[0]); + out_dist_attr.set_partial_status(partial_dims); + } + + VLOG(4) << "index_axes: " << index_axes << " updates_axes: " << updates_axes + << " out_axes: " << out_axes; + LOG_SPMD_INPUT(x); + LOG_SPMD_INPUT(index); + LOG_SPMD_INPUT(updates); + VLOG(4) << "Out dist_attr: [" << out_dist_attr.to_string() << "]\n\n"; + return {{x_dist_attr_dst, index_dist_attr_dst, updates_dist_attr_dst}, + {out_dist_attr}}; +} + +SpmdInfo ScatterInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& updates, + const DistMetaTensor& out, + bool overwrite) { + // Step0: Verify Input Args Based on Scatter Logic + // extract and check out_ndim, out_shape, out_dist_attr_src and + // out_dims_mapping_src with the macro + EXTRACT_SHAPE_AND_DIST_ATTR(x); + EXTRACT_SHAPE_AND_DIST_ATTR(index); + EXTRACT_SHAPE_AND_DIST_ATTR(updates); + EXTRACT_SHAPE_AND_DIST_ATTR(out); + + // Step1: Build Einsum Notation + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + // x should be replicated on 0th axis + std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet); + std::string index_axes = GetBroadcastAxes(index_ndim, index_ndim, alphabet); + std::string updates_axes = + GetBroadcastAxes(updates_ndim, updates_ndim, alphabet); + std::string out_axes = GetBroadcastAxes(out_ndim, out_ndim, alphabet); + + // Step2: Sharding Propagation + // Step2.1: Merge output shardings + // the batch axis of output must be replicated + // TODO(zhangyichen): consider the case when the output is partial + std::vector out_dims_mapping(out_dims_mapping_src); + out_dims_mapping[0] = -1; + TensorDistAttr out_dist_attr_dst = + CopyTensorDistAttrForOutput(out_dist_attr_src); + out_dist_attr_dst.set_dims_mapping(out_dims_mapping); + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{out_axes, out_dims_mapping}}); + + // Step2.2: Infer input dims mapping + std::vector x_dims_mapping = + GetDimsMappingForAxes(x_axes, axis_to_dim_map); + std::vector index_dims_mapping = + GetDimsMappingForAxes(index_axes, axis_to_dim_map); + std::vector updates_dims_mapping = + GetDimsMappingForAxes(updates_axes, axis_to_dim_map); + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping); + TensorDistAttr index_dist_attr_dst = + CopyTensorDistAttrForOutput(index_dist_attr_src); + index_dist_attr_dst.set_dims_mapping(index_dims_mapping); + TensorDistAttr updates_dist_attr_dst = + CopyTensorDistAttrForOutput(updates_dist_attr_src); + updates_dist_attr_dst.set_dims_mapping(updates_dims_mapping); + + LOG_SPMD_INPUT(out); + LOG_SPMD_INPUT(x); + LOG_SPMD_INPUT(index); + LOG_SPMD_INPUT(updates); + VLOG(4) << std::endl; + return {{x_dist_attr_dst, index_dist_attr_dst, updates_dist_attr_dst}, + {out_dist_attr_dst}}; +} + +SpmdInfo ScatterGradInferSpmd(const DistMetaTensor& index, + const DistMetaTensor& updates, + const DistMetaTensor& out_grad, + bool overwrite) { + EXTRACT_SHAPE_AND_DIST_ATTR(index); + EXTRACT_SHAPE_AND_DIST_ATTR(updates); + EXTRACT_SHAPE_AND_DIST_ATTR(out_grad); + + // the batch axis of index, updates, out_grad must be replicated + std::vector index_dims_mapping(index_dims_mapping_src); + index_dims_mapping[0] = -1; + std::vector out_grad_dims_mapping(out_grad_dims_mapping_src); + out_grad_dims_mapping[0] = -1; + + TensorDistAttr index_dist_attr_dst = + CopyTensorDistAttrForOutput(index_dist_attr_src); + 
index_dist_attr_dst.set_dims_mapping(index_dims_mapping); + TensorDistAttr out_grad_dist_attr_dst = + CopyTensorDistAttrForOutput(out_grad_dist_attr_src); + out_grad_dist_attr_dst.set_dims_mapping(out_grad_dims_mapping); + + TensorDistAttr x_grad_dist_attr(out_grad_dist_attr_src); + std::vector x_dims_mapping(out_grad_dims_mapping); + x_grad_dist_attr.set_dims_mapping(x_dims_mapping); + + DistMetaTensor out_grad_dst(out_grad.dims(), out_grad_dist_attr_dst); + DistMetaTensor index_dst(index.dims(), index_dist_attr_dst); + + SpmdInfo spmd_info = GatherInferSpmdBase(out_grad_dst, index_dst, 0); + TensorDistAttr updates_grad_dist_attr = + PADDLE_GET_CONST(TensorDistAttr, spmd_info.second[0]); + + return {{index_dist_attr_dst, updates_dist_attr_src, out_grad_dist_attr_dst}, + {x_grad_dist_attr, updates_grad_dist_attr}}; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/scatter.h b/paddle/phi/infermeta/spmd_rules/scatter.h new file mode 100644 index 0000000000000..f074ba998bdac --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/scatter.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo ScatterInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& updates, + bool overwrite); + +SpmdInfo ScatterInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& index, + const DistMetaTensor& updates, + const DistMetaTensor& out, + bool overwrite); + +SpmdInfo ScatterGradInferSpmd(const DistMetaTensor& index, + const DistMetaTensor& updates, + const DistMetaTensor& out_grad, + bool overwrite); +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/slice.cc b/paddle/phi/infermeta/spmd_rules/slice.cc index 3615e57340a0d..9daed3ce8c764 100644 --- a/paddle/phi/infermeta/spmd_rules/slice.cc +++ b/paddle/phi/infermeta/spmd_rules/slice.cc @@ -77,8 +77,8 @@ SpmdInfo SliceInferSpmdBase(const DistMetaTensor& input, // cannot be sharded, if it is sharded, set it to replicated. TensorDistAttr input_dist_attr_dst = CopyTensorDistAttrForOutput(input_dist_attr_src); - for (int i = 0; i < static_cast(axes.size()); i++) { - int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; // NOLINT + for (auto axe : axes) { + int axis = axe < 0 ? axe + input_ndim : axe; input_dims_mapping[axis] = -1; } input_dist_attr_dst.set_dims_mapping(input_dims_mapping); @@ -164,8 +164,8 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input, out_axes[i] = input_axes[input_axis]; } - for (int i = 0; i < static_cast(axes.size()); i++) { - int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; // NOLINT + for (auto axe : axes) { + int axis = axe < 0 ? 
axe + input_ndim : axe; // the sliced axis cannot be sharded, set its notation // with the special '1' to set its dim mapping to -1. input_axes[axis] = '1'; @@ -190,8 +190,8 @@ SpmdInfo SliceInferSpmdReverseBase(const DistMetaTensor& input, // step2.3 get new dist attribute for output. the sliced // cannot be sharded, if it is sharded, set it to replicated. out_dims_mapping = GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); - for (int i = 0; i < static_cast(axes.size()); i++) { - int axis = axes[i] < 0 ? axes[i] + input_ndim : axes[i]; + for (auto axe : axes) { + int axis = axe < 0 ? axe + input_ndim : axe; out_dims_mapping[axis] = -1; } auto out_dist_attr_dst = CopyTensorDistAttrForOutput(out_dist_attr); diff --git a/paddle/phi/infermeta/spmd_rules/softmax.cc b/paddle/phi/infermeta/spmd_rules/softmax.cc index d86db4d41ae23..b6f886a49468a 100644 --- a/paddle/phi/infermeta/spmd_rules/softmax.cc +++ b/paddle/phi/infermeta/spmd_rules/softmax.cc @@ -31,7 +31,7 @@ using phi::distributed::auto_parallel::str_join; SpmdInfo SoftmaxInferSpmd(const DistMetaTensor& x, int axis) { // Step0: Verify input args based on softmax logic auto x_shape = common::vectorize(x.dims()); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); auto x_dist_attr_src = x.dist_attr(); std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( @@ -100,8 +100,8 @@ SpmdInfo SoftmaxInferSpmdReverse(const DistMetaTensor& x, // Step0: verify input args based on softmax logic auto x_shape = common::vectorize(x.dims()); auto out_shape = common::vectorize(out.dims()); - int x_ndim = x_shape.size(); - int out_ndim = out_shape.size(); + int x_ndim = static_cast(x_shape.size()); + int out_ndim = static_cast(out_shape.size()); auto out_dist_attr_src = out.dist_attr(); std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( diff --git a/paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h b/paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h index a9d49f3718171..43147db5b6194 100644 --- a/paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h +++ b/paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h @@ -16,33 +16,33 @@ limitations under the License. 
*/ using phi::distributed::auto_parallel::str_join; -#define EXTRACT_SHAPE_AND_DIST_ATTR(x) \ - auto x##_shape = phi::vectorize(x.dims()); \ - int x##_ndim = x##_shape.size(); \ - auto x##_dist_attr_src = x.dist_attr(); \ - const auto& x##_dims_mapping_src = x##_dist_attr_src.dims_mapping(); \ - PADDLE_ENFORCE_EQ(x##_ndim, \ - x##_dims_mapping_src.size(), \ - phi::errors::InvalidArgument( \ - "[%d] [%d] The Tensor [%d]'s rank [%d] and Loss's " \ - "dims_mapping size [%d] are not matched.", \ - __FILE__, \ - __LINE__, \ - #x, \ - x##_ndim, \ +#define EXTRACT_SHAPE_AND_DIST_ATTR(x) \ + auto x##_shape = phi::vectorize(x.dims()); \ + int x##_ndim = x##_shape.size(); \ + auto x##_dist_attr_src = x.dist_attr(); \ + const auto& x##_dims_mapping_src = x##_dist_attr_src.dims_mapping(); \ + PADDLE_ENFORCE_EQ(x##_ndim, \ + x##_dims_mapping_src.size(), \ + phi::errors::InvalidArgument( \ + "[%d] [%d] The Tensor [%d]'s rank [%d] and " \ + "dims_mapping size [%d] are not matched.", \ + __FILE__, \ + __LINE__, \ + #x, \ + x##_ndim, \ x##_dims_mapping_src.size())) -#define EXTRACT_SHAPE_AND_DIST_ATTR_WITH_DIM_CK(x) \ - EXTRACT_SHAPE_AND_DIST_ATTR(x); \ - PADDLE_ENFORCE_EQ(x##_ndim, \ - x##_dims_mapping_src.size(), \ - phi::errors::InvalidArgument( \ - "[%d] [%d] The Tensor [%d]'s rank [%d] and Loss's " \ - "dims_mapping size [%d] are not matched.", \ - __FILE__, \ - __LINE__, \ - #x, \ - x##_ndim, \ +#define EXTRACT_SHAPE_AND_DIST_ATTR_WITH_DIM_CK(x) \ + EXTRACT_SHAPE_AND_DIST_ATTR(x); \ + PADDLE_ENFORCE_EQ(x##_ndim, \ + x##_dims_mapping_src.size(), \ + phi::errors::InvalidArgument( \ + "[%d] [%d] The Tensor [%d]'s rank [%d] and " \ + "dims_mapping size [%d] are not matched.", \ + __FILE__, \ + __LINE__, \ + #x, \ + x##_ndim, \ x##_dims_mapping_src.size())) #define LOG_SPMD_INPUT(name) \ @@ -50,7 +50,7 @@ using phi::distributed::auto_parallel::str_join; VLOG(4) << #name; \ VLOG(4) << "shape: [" << str_join(name##_shape) << "] " \ << "src_dist_attr: [" << name##_dist_attr_src.to_string() << "] " \ - << "src_dist_attr: [" << name##_dist_attr_dst.to_string() << "]"; \ + << "dst_dist_attr: [" << name##_dist_attr_dst.to_string() << "]"; \ } while (0) #define LOG_SPMD_OUTPUT(name) \ diff --git a/paddle/phi/infermeta/spmd_rules/swiglu.cc b/paddle/phi/infermeta/spmd_rules/swiglu.cc new file mode 100644 index 0000000000000..924a80c2e39a0 --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/swiglu.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/infermeta/spmd_rules/elementwise.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +SpmdInfo SwiGLUInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y) { + // y.dist_attr() is empty means y is None + if (y.dist_attr() == TensorDistAttr()) { + PADDLE_THROW( + phi::errors::Unimplemented("The input y is not allowed to be None")); + } else { + return ElementwiseBinaryInferSpmd(x, y); + } +} + +SpmdInfo SwiGLUInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out) { + if (y.dist_attr() == TensorDistAttr()) { + PADDLE_THROW( + phi::errors::Unimplemented("The input y is not allowed to be None")); + } else { + return ElementwiseBinaryInferSpmdReverse(x, y, out); + } +} + +SpmdInfo SwiGLUGradInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& y, + const DistMetaTensor& out_grad) { + if (y.dist_attr() == TensorDistAttr()) { + PADDLE_THROW( + phi::errors::Unimplemented("The input y is not allowed to be None")); + } else { + return ElementwiseBinaryGradInferSpmd(x, y, out_grad); + } +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/unbind.cc b/paddle/phi/infermeta/spmd_rules/unbind.cc new file mode 100644 index 0000000000000..0e869aad2674d --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/unbind.cc @@ -0,0 +1,182 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/infermeta/spmd_rules/unbind.h" + +#include "glog/logging.h" + +#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h" +#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h" +#include "paddle/phi/core/distributed/auto_parallel/utils.h" +#include "paddle/phi/infermeta/spmd_rules/spmd_rule_macro_define.h" +#include "paddle/phi/infermeta/spmd_rules/utils.h" + +namespace phi { +namespace distributed { + +using phi::distributed::auto_parallel::str_join; + +SpmdInfo UnbindInferSpmd(const DistMetaTensor& x, int axis) { + EXTRACT_SHAPE_AND_DIST_ATTR(x); + if (axis < 0) { + axis += x_ndim; + } + PADDLE_ENFORCE_LT( + axis, + x_ndim, + phi::errors::InvalidArgument("[%d] [%d] The axis [%d] should be less " + "than the rank of input tensor [%d].", + __FILE__, + __LINE__, + axis, + x_ndim)); + + // Step1: Build Einsum Notation + std::string alphabet = "abcdefghijklmnopqrstuvwxyz"; + // get einsum notation for input + std::string x_axes = alphabet.substr(0, x_ndim); + // get einsum notation for output + std::string out_axes(x_axes); + out_axes.erase(axis, 1); + + // Step2: Sharding Propagation + // Step2.1: merge input shardings + std::vector x_dims_mapping_dst(x_dims_mapping_src); + x_dims_mapping_dst[axis] = -1; + TensorDistAttr x_dist_attr_dst(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst); + + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors({{x_axes, x_dims_mapping_dst}}); + + // Step2.2: infer output dims mapping from merged input dims mapping + std::vector out_dims_mapping = + GetDimsMappingForAxes(out_axes, axis_to_dim_map); + + // get the dist attributes for all outputs, the + // dist attributes are same for all outputs. + int noutputs = x_shape[axis]; + std::vector out_dist_attrs; + for (int i = 0; i < noutputs; i++) { + out_dist_attrs.emplace_back(CopyTensorDistAttrForOutput(x_dist_attr_src)); + out_dist_attrs[i].set_dims_mapping(out_dims_mapping); + } + + // Step3 Handle input tensor partial (TODO) + VLOG(4) << "UnbindInferSpmd:"; + VLOG(4) << "Einsum Notation: " << x_axes << "-->" << out_axes; + VLOG(4) << "x:"; + VLOG(4) << "\tshape: [" << str_join(x_shape) << "] "; + VLOG(4) << "\tsrc_dist_attr: [" << x_dist_attr_src.to_string() << "]"; + VLOG(4) << "\tdst_dist_attr: [" << x_dist_attr_dst.to_string() << "]"; + for (int64_t i = 0; i < noutputs; i++) { + VLOG(4) << "out" << std::to_string(i); + VLOG(4) << "\tdist_attr: [" << out_dist_attrs[i].to_string() << "]"; + } + VLOG(4) << std::endl; + // TODO(liuzhenhai): remedy this + // should return list in list [] + // return {{x_dist_attr_dst}, {out_dist_attrs}}; + return {{x_dist_attr_dst}, ToArgDistAttr(out_dist_attrs)}; +} + +SpmdInfo UnbindInferSpmdReverse(const DistMetaTensor& x, + const std::vector& outs, + int axis) { + // Step0: Verify input args based on split logic + EXTRACT_SHAPE_AND_DIST_ATTR(x); + int nouts = static_cast(outs.size()); + + for (int i = 0; i < nouts; i++) { + auto shape = common::vectorize(outs[i]->dims()); + int ndim = static_cast(shape.size()); + auto dist_attr = outs[i]->dist_attr(); + int dims_mapping_size = static_cast(dist_attr.dims_mapping().size()); + PADDLE_ENFORCE_EQ( + ndim, + dims_mapping_size, + phi::errors::InvalidArgument("The Tensor Out[%d]'s rank [%d] and Its " + "dims_mapping size [%d] are not matched.", + i, + ndim, + dims_mapping_size)); + } + + // Step1: Build Einsum Notation + if (axis < 0) { + axis += x_ndim; + } + std::string alphabet = "abcdefghijlmnopqrstuvwxyz"; + std::string x_axes = 
alphabet.substr(0, x_ndim); + std::string out_axes(x_axes); + out_axes.erase(axis, 1); + + // Step2: Sharding Propagation + // Step2.1: merge output shardings + std::vector>> axes_sharding_info; + for (int i = 0; i < nouts; i++) { + std::vector out_dims_mapping = outs[i]->dist_attr().dims_mapping(); + axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping)); + } + std::unordered_map axis_to_dim_map = + ShardingMergeForTensors(axes_sharding_info); + + // Step2.2: infer input dims mapping from output dims mapping + std::vector x_dims_mapping = + GetDimsMappingForAxes(x_axes, axis_to_dim_map, true); + + TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); + x_dist_attr_dst.set_dims_mapping(x_dims_mapping); + + // step2.3 get new dist attribute for output. the splitted + // cannot be sharded, if it is sharded, set it to replicated. + std::vector out_dist_attrs_dst; + for (int i = 0; i < nouts; i++) { + out_dist_attrs_dst.emplace_back( + CopyTensorDistAttrForOutput(outs[i]->dist_attr())); + std::vector out_dims_mapping = + GetDimsMappingForAxes(out_axes, axis_to_dim_map, true); + out_dist_attrs_dst[i].set_dims_mapping(out_dims_mapping); + } + + // step3 Handle input tensor partial (TODO) + + VLOG(4) << "UnbindInferSpmdReverse:"; + for (int i = 0; i < nouts; i++) { + VLOG(4) << "out" << std::to_string(i) << ":"; + VLOG(4) << "\tsrc_dist_attr: [" << outs[i]->dist_attr().to_string() << "]"; + VLOG(4) << "\tdst_dist_attr: [" << out_dist_attrs_dst[i].to_string() << "]"; + } + VLOG(4) << "x:"; + VLOG(4) << "\tsrc_dist_attr: [" << x_dist_attr_src.to_string() << "]"; + VLOG(4) << "\tdst_dist_attr: [" << x_dist_attr_dst.to_string() << "]"; + return {{x_dist_attr_dst}, ToArgDistAttr(out_dist_attrs_dst)}; +} + +SpmdInfo UnbindInferSpmdDynamic(const DistMetaTensor& x, int axis) { + auto tmp = UnbindInferSpmd(x, axis); + // bridge the diff concerning vector output between static and dynamic auto + // parallel ToDo(liuzhenhai): unify the difference between static and dynamic + SpmdInfo ret; + ret.first = tmp.first; + std::vector out_dist_attrs; + for (const auto& out : tmp.second) { + out_dist_attrs.push_back(PADDLE_GET_CONST(TensorDistAttr, out)); + } + ret.second = {out_dist_attrs}; + return ret; +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/unbind.h b/paddle/phi/infermeta/spmd_rules/unbind.h new file mode 100644 index 0000000000000..2daac013e8c0e --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/unbind.h @@ -0,0 +1,38 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include +#include +#include +#include + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo UnbindInferSpmd(const DistMetaTensor& x, int axis); + +SpmdInfo UnbindInferSpmdReverse(const DistMetaTensor& x, + const std::vector& outs, + int axis); + +SpmdInfo UnbindInferSpmdDynamic(const DistMetaTensor& x, int axis); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/unsqueeze.cc b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc index cbb010fe6c6bf..f7e16d4bb33da 100644 --- a/paddle/phi/infermeta/spmd_rules/unsqueeze.cc +++ b/paddle/phi/infermeta/spmd_rules/unsqueeze.cc @@ -74,7 +74,7 @@ std::vector> MakeUnsqueezeDimTransReverse( ret.resize(x_ndim); fill(ret.begin(), ret.end(), std::make_shared()); - for (int64_t i = 0, j = 0; i < out_ndim; i++) { + for (int64_t i = 0, j = 0; i < out_ndim; i++) { // NOLINT auto it = find(axis.begin(), axis.end(), i); if (it == axis.end()) { @@ -93,7 +93,7 @@ SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x, const std::vector& axis) { // Step0: Verify input args based on unsqueeze logic auto x_shape = common::vectorize(x.dims()); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); auto x_dist_attr_src = x.dist_attr(); std::vector x_dims_mapping = x_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( @@ -110,9 +110,9 @@ SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x, std::vector out_shape; std::vector axis_copy(axis); - for (int64_t i = 0; i < static_cast(axis_copy.size()); i++) { - if (axis_copy[i] < 0) { - axis_copy[i] += x_ndim + 1; + for (auto& i : axis_copy) { + if (i < 0) { + i += x_ndim + 1; } } @@ -162,9 +162,9 @@ SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x, const std::vector& axis) { // Step0: Verify input args based on unsqueeze logic auto x_shape = common::vectorize(x.dims()); - int x_ndim = x_shape.size(); + int x_ndim = static_cast(x_shape.size()); auto out_shape = common::vectorize(out.dims()); - int out_ndim = out_shape.size(); + int out_ndim = static_cast(out_shape.size()); auto out_dist_attr_src = out.dist_attr(); std::vector out_dims_mapping = out_dist_attr_src.dims_mapping(); PADDLE_ENFORCE_EQ( @@ -183,9 +183,9 @@ SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x, std::vector axis_copy(axis); - for (int64_t i = 0; i < static_cast(axis_copy.size()); i++) { - if (axis_copy[i] < 0) { - axis_copy[i] += x_ndim + 1; + for (auto& i : axis_copy) { + if (i < 0) { + i += x_ndim + 1; } } @@ -217,7 +217,7 @@ SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x, VLOG(4) << "UnsqueezeInferSpmdReverse: Out shape: [" << str_join(out_shape) << "] X shape: [" << str_join(x_shape) << "]"; VLOG(4) << "Transformation from output to input:"; - for (int64_t i = 0, n = trans.size(); i < n; i++) { + for (int64_t i = 0, n = static_cast(trans.size()); i < n; i++) { std::shared_ptr t = trans[i]; VLOG(4) << "\tX axis[" << i << "]: " << t->to_string(); } diff --git a/paddle/phi/infermeta/spmd_rules/utils.cc b/paddle/phi/infermeta/spmd_rules/utils.cc index b67d7bd251b1b..336924dd5e951 100644 --- a/paddle/phi/infermeta/spmd_rules/utils.cc +++ b/paddle/phi/infermeta/spmd_rules/utils.cc @@ -423,13 +423,14 @@ TensorDistAttr FromPlacements( auto& placement = placements[mesh_dim]; if (placement->is_shard()) { auto shard_placement = std::dynamic_pointer_cast(placement); - 
dims_mapping[shard_placement->get_axis()] = mesh_dim; + dims_mapping[shard_placement->get_axis()] = + static_cast(mesh_dim); } if (placement->is_partial()) { auto partial_placement = std::dynamic_pointer_cast(placement); auto reduce_type = partial_placement->get_reduce_type(); - partial_status[mesh_dim] = reduce_type; + partial_status[mesh_dim] = reduce_type; // NOLINT } } dst_dist_attr.set_dims_mapping(dims_mapping); @@ -470,7 +471,7 @@ std::vector GetLocalShape( for (size_t i = 0; i < n_placement; i++) { auto& placement = placements.at(i); if (placement->is_shard()) { - auto mesh_dim_size = mesh.dim_size(i); + auto mesh_dim_size = mesh.dim_size(i); // NOLINT auto shard_dim = std::dynamic_pointer_cast(placement)->get_axis(); auto split_size = diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index edd03e6b07513..f10a86b33836a 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -146,6 +146,47 @@ void AddmmInferMeta(const MetaTensor& input, out->set_dtype(input.dtype()); } +void BatchFCInferMeta(const MetaTensor& input, + const MetaTensor& w, + const MetaTensor& bias, + MetaTensor* out) { + auto input_dims = input.dims(); + auto w_dims = w.dims(); + + PADDLE_ENFORCE_EQ( + input_dims.size(), + 3, + phi::errors::InvalidArgument("Input of BatchFCOp should have 3D.")); + PADDLE_ENFORCE_EQ( + w_dims.size(), + 3, + phi::errors::InvalidArgument("W of BatchFCOp should have 3D.")); + PADDLE_ENFORCE_EQ( + input_dims[0], + w_dims[0], + phi::errors::InvalidArgument( + "Input.dim[0] and W.dim[0] of BatchFCOp should be same.")); + PADDLE_ENFORCE_EQ( + input_dims[2], + w_dims[1], + phi::errors::InvalidArgument( + "Input.dim[2] and W.dim[1] of BatchFCOp should be same.")); + + auto bias_dims = bias.dims(); + PADDLE_ENFORCE_EQ(bias_dims[0], + input_dims[0], + phi::errors::InvalidArgument( + "Bias.dim[0] should be same as input.dim[0].")); + PADDLE_ENFORCE_EQ(bias_dims[1], + w_dims[2], + phi::errors::InvalidArgument( + "Bias.dim[1] should be same as input.dim[2].")); + + out->set_dims({input_dims[0], input_dims[1], w_dims[2]}); + out->share_lod(input); + out->set_dtype(input.dtype()); +} + void BoxCoderInferMeta(const MetaTensor& prior_box, const MetaTensor& prior_box_var, const MetaTensor& target_box, @@ -255,6 +296,37 @@ void BoxCoderInferMeta(const MetaTensor& prior_box, output_box->set_dtype(target_box.dtype()); } +void DistributedPushSparseInferMeta( + const std::vector& ids, + const std::vector& shows, + const std::vector& clicks, + int table_id, + int size, + bool is_distributed, + const std::string& push_sparse_version, + int64_t padding_idx, + DataType dtype, + bool is_test, + bool use_cvm_op, + std::vector output) { + auto ids_size = ids.size(); + std::vector ids_dims; + ids_dims.reserve(ids.size()); + for (size_t i = 1; i < ids_size; ++i) { + PADDLE_ENFORCE_EQ(ids_dims[i].size(), + 2, + phi::errors::InvalidArgument( + "The dimension of the 'Ids' tensor must be 2.")); + } + + for (auto& out : output) { + if (out == nullptr) { + continue; + } + out->set_dtype(ids[0]->dtype()); + } +} + void DpsgdInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& learning_rate, @@ -430,6 +502,33 @@ void InstanceNormInferMeta(const MetaTensor& x, } } +void GlobalScatterInferMeta(const MetaTensor& x, + const MetaTensor& local_count, + const MetaTensor& global_count, + int ring_id, + bool use_calc_stream, + MetaTensor* out) { + PADDLE_ENFORCE_GE( + ring_id, + 0, + phi::errors::InvalidArgument( + "The ring_id (%d) for global scatter 
op must be non-negative.", + ring_id)); + auto input_dims = x.dims(); + auto ndim_input = input_dims.size(); + // dim check + PADDLE_ENFORCE_EQ( + ndim_input, + 2, + phi::errors::InvalidArgument("The input tensor's dimension must be 2. " + "But received input's dimension = %d.", + ndim_input)); + + phi::DDim out_dims = common::make_ddim({-1, -1}); + out->set_dims(out_dims); + out->set_dtype(x.dtype()); +} + void GroupNormInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, @@ -1006,6 +1105,74 @@ void PutAlongAxisInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void RandomRoutingInferMeta(const MetaTensor& prob, + const MetaTensor& topk_value, + const MetaTensor& topk_idx, + MetaTensor* out) { + // check dims + auto topk_val_dims = topk_value.dims(); + auto prob_dims = prob.dims(); + auto topk_idx_dims = topk_idx.dims(); + + PADDLE_ENFORCE_EQ(prob_dims[0], + topk_val_dims[0], + phi::errors::InvalidArgument( + "Output(Out) of ScatterNdAddOp should not be null.")); + + PADDLE_ENFORCE_EQ(topk_idx_dims[1], + topk_val_dims[1], + phi::errors::InvalidArgument( + "Output(Out) of ScatterNdAddOp should not be null.")); + + PADDLE_ENFORCE_EQ(topk_idx_dims[0], + topk_val_dims[0], + phi::errors::InvalidArgument( + "Output(Out) of ScatterNdAddOp should not be null.")); + + out->set_dims(topk_idx_dims); + out->set_dtype(topk_idx.dtype()); + out->share_lod(topk_idx); +} + +void RankAttentionInferMeta(const MetaTensor& x, + const MetaTensor& rank_offset, + const MetaTensor& rank_param, + int max_rank, + int max_size, + MetaTensor* input_help, + MetaTensor* out, + MetaTensor* ins_rank) { + auto x_dims = x.dims(); + auto ins_num = x_dims[0]; + auto param_dims = rank_param.dims(); + auto para_col = param_dims[1]; + auto rank_offset_dims = rank_offset.dims(); + auto x_fea_dim = x_dims[1]; + auto block_matrix_row = max_rank * x_fea_dim; + + PADDLE_ENFORCE_EQ( + (rank_offset_dims[1] - 1) / 2, + max_rank, + phi::errors::InvalidArgument("Input(RankOffset) has wrong columns, " + "except columns to be %d, but got %d", + max_rank, + (rank_offset_dims[1] - 1) / 2)); + + std::vector out_dims({ins_num, para_col}); + out->set_dims(common::make_ddim(out_dims)); + out->set_dtype(x.dtype()); + + std::vector input_help_dims({ins_num, block_matrix_row}); + input_help->set_dims(common::make_ddim(input_help_dims)); + input_help->set_dtype(x.dtype()); + + std::vector ins_rank_dims({ins_num, 1}); + ins_rank->set_dims(common::make_ddim(ins_rank_dims)); + ins_rank->set_dtype(x.dtype()); + + out->share_lod(x); +} + void RoiAlignInferMeta(const MetaTensor& x, const MetaTensor& boxes, const MetaTensor& boxes_num, diff --git a/paddle/phi/infermeta/ternary.h b/paddle/phi/infermeta/ternary.h index d12378fe3a92c..c1c1af6f08218 100644 --- a/paddle/phi/infermeta/ternary.h +++ b/paddle/phi/infermeta/ternary.h @@ -53,6 +53,11 @@ void ArangeTensorInferMeta(const MetaTensor& start, const MetaTensor& step, MetaTensor* out); +void BatchFCInferMeta(const MetaTensor& input, + const MetaTensor& w, + const MetaTensor& bias, + MetaTensor* out); + void BoxCoderInferMeta(const MetaTensor& prior_box, const MetaTensor& prior_box_var, const MetaTensor& target_box, @@ -63,6 +68,20 @@ void BoxCoderInferMeta(const MetaTensor& prior_box, MetaTensor* output_box, MetaConfig config = MetaConfig()); +void DistributedPushSparseInferMeta( + const std::vector& ids, + const std::vector& shows, + const std::vector& clicks, + int table_id, + int size, + bool is_distributed, + const std::string& push_sparse_version, + int64_t 
padding_idx, + DataType dtype, + bool is_test, + bool use_cvm_op, + std::vector output); + void DpsgdInferMeta(const MetaTensor& param, const MetaTensor& grad, const MetaTensor& learning_rate, @@ -89,6 +108,13 @@ void InstanceNormInferMeta(const MetaTensor& x, MetaTensor* saved_variance, MetaConfig config = MetaConfig()); +void GlobalScatterInferMeta(const MetaTensor& x, + const MetaTensor& local_count, + const MetaTensor& global_count, + int ring_id, + bool use_calc_stream, + MetaTensor* out); + void GroupNormInferMeta(const MetaTensor& x, const MetaTensor& scale, const MetaTensor& bias, @@ -179,6 +205,20 @@ void PutAlongAxisInferMeta(const MetaTensor& x, const std::string& reduce, MetaTensor* out); +void RandomRoutingInferMeta(const MetaTensor& prob, + const MetaTensor& topk_value, + const MetaTensor& topk_idx, + MetaTensor* out); + +void RankAttentionInferMeta(const MetaTensor& x, + const MetaTensor& rank_offset, + const MetaTensor& rank_param, + int max_rank, + int max_size, + MetaTensor* input_help, + MetaTensor* out, + MetaTensor* ins_rank); + void RoiAlignInferMeta(const MetaTensor& x, const MetaTensor& boxes, const MetaTensor& boxes_num, diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 5648ff0d469a3..a152bc152ae6b 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -236,7 +236,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, if (!config.is_runtime && axis.FromTensor()) { std::vector vec; if (flatten) { - if (keepdims) { + if (keepdims) { // NOLINT vec = std::vector(x.dims().size(), -1); } else { vec = {}; @@ -307,7 +307,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x, std::vector vec; if (flatten) { - if (keepdims) { + if (keepdims) { // NOLINT vec = std::vector(x.dims().size(), 1); } else { vec = {}; @@ -738,6 +738,23 @@ void CropInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void CScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) { + auto dim = x.dims(); + dim[0] = dim[0] / nranks; + if (dim[0] < 0) dim[0] = -1; + out->set_dims(dim); + out->set_dtype(x.dtype()); +} + +void CSplitInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) { + phi::DDim dim = x.dims(); + dim[dim.size() - 1] = dim[dim.size() - 1] / nranks; + if (dim[0] < 0) dim[0] = -1; + out->set_dims(dim); + out->set_layout(x.layout()); + out->set_dtype(x.dtype()); +} + void DecodeJpegInferMeta(const MetaTensor& x, const std::string& mode, MetaTensor* out) { @@ -1202,7 +1219,7 @@ void EinsumRawInferMeta(const std::vector& inputs, void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, MetaTensor* out) { -#define MAX_RANK_SUPPORTED 6 +#define EXPAND_MAX_RANK_SUPPORTED 8 auto x_dims = x.dims(); auto expand_shape = shape.GetData(); @@ -1221,11 +1238,11 @@ void ExpandInferMeta(const MetaTensor& x, static_cast(x_dims.size()))); PADDLE_ENFORCE_LE( expand_shape.size(), - MAX_RANK_SUPPORTED, + EXPAND_MAX_RANK_SUPPORTED, phi::errors::InvalidArgument("The number of elements (%d) of 'shape' for " "must not be greater than %d.", expand_shape.size(), - MAX_RANK_SUPPORTED)); + EXPAND_MAX_RANK_SUPPORTED)); PADDLE_ENFORCE_GE( expand_shape.size(), 0, @@ -1266,6 +1283,7 @@ void ExpandInferMeta(const MetaTensor& x, if (out_rank > 0 && out_shape[0] == x_dims[0]) { out->share_lod(x); } +#undef EXPAND_MAX_RANK_SUPPORTED } void FillAnyLikeInferMeta(const MetaTensor& x, @@ -2185,7 +2203,7 @@ void KthvalueInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); indices->set_dims(dims); indices->share_lod(x); - 
indices->set_dtype(x.dtype()); + indices->set_dtype(DataType::INT64); } void LogicalNotInferMeta(const MetaTensor& x, MetaTensor* out) { @@ -2567,14 +2585,12 @@ void MultinomialInferMeta(const MetaTensor& x, void NanmedianInferMeta(const MetaTensor& x, const IntArray& axes, bool keep_dim, + const std::string& mode, MetaTensor* out, MetaTensor* median_index) { std::vector axis_list = axes.GetData(); auto x_dim = x.dims(); int64_t x_rank = x_dim.size(); - out->set_dtype(x.dtype()); - median_index->set_dtype(DataType::INT64); - median_index->set_dims(common::make_ddim({x.numel() * 2})); std::vector out_dim; if (axis_list.empty()) { @@ -2584,7 +2600,7 @@ void NanmedianInferMeta(const MetaTensor& x, } } } else { - std::vector formated_axis; + std::vector formatted_axis; for (auto& axis : axis_list) { if (x_rank == 0) { PADDLE_ENFORCE_EQ(axis == 0 || axis == -1, @@ -2612,25 +2628,32 @@ void NanmedianInferMeta(const MetaTensor& x, } if (axis < 0) axis += x_rank; PADDLE_ENFORCE_EQ( - std::find(formated_axis.begin(), formated_axis.end(), axis), - formated_axis.end(), + std::find(formatted_axis.begin(), formatted_axis.end(), axis), + formatted_axis.end(), errors::InvalidArgument("Attr(axes) has duplicated elements: %d.", static_cast(axis))); - formated_axis.push_back(axis); + formatted_axis.push_back(axis); } for (int64_t i = 0; i < x_rank; i++) { - if (std::find(formated_axis.begin(), formated_axis.end(), i) == - formated_axis.end()) { + if (std::find(formatted_axis.begin(), formatted_axis.end(), i) == + formatted_axis.end()) { out_dim.push_back(x_dim[i]); // NOLINT } else if (keep_dim) { out_dim.push_back(1); } } } + out->set_dtype(x.dtype()); + out->set_dims(make_ddim(out_dim)); - out->set_dims(common::make_ddim(out_dim)); + auto median_dim = out_dim; + if (mode == "avg") { + median_dim.push_back(2); + } + median_index->set_dtype(DataType::INT64); + median_index->set_dims(make_ddim(median_dim)); } void NMSInferMeta(const MetaTensor& x, float threshold, MetaTensor* out) { @@ -2915,6 +2938,29 @@ void Pad3dInferMeta(const MetaTensor& x, out->share_lod(x); } +void PartialAllgatherInferMeta(const MetaTensor& x, + int nranks, + int rank, + int ring_id, + bool use_calc_stream, + MetaTensor* out) { + PADDLE_ENFORCE_GE( + nranks, + 2, + phi::errors::InvalidArgument("The value of nranks should be >=2.")); + PADDLE_ENFORCE_EQ( + (rank >= 0 && rank < nranks), + true, + phi::errors::InvalidArgument( + "The rank (%d) for partial_allgather op must >=0 and set_dims(x_dims); + out->set_dtype(x.dtype()); +} + void PartialSendInferMeta(const MetaTensor& x, int ring_id, int peer, @@ -3159,7 +3205,7 @@ void Pool2DInferMeta(const MetaTensor& x, (data_format == "NHWC" || data_format == "NDHWC"); if (!config.is_runtime && kernel_size.FromTensor()) { auto x_dims = x.dims(); - std::vector output_shape = std::move(common::vectorize(x_dims)); + std::vector output_shape = common::vectorize(x_dims); // set dims of HW -1 output_shape[x_dims.size() - 2] = -1; if (channel_last) { // for NHWC, NDHWC @@ -3332,6 +3378,17 @@ void PoolInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } +void PushDenseInferMeta(const std::vector& ids, + int table_id, + float scale_data_norm, + const std::vector& input_names) { + auto ids_num = ids.size(); + PADDLE_ENFORCE_GE(ids_num, + 1UL, + phi::errors::InvalidArgument( + "Input(Ids) of PushDenseOp can not be null.")); +} + void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out) { out->set_dims(x.dims()); out->set_dtype(dtype::ToReal(x.dtype())); @@ -3382,7 +3439,7 @@ DDim 
ReduceInferDim(const MetaTensor& x, bool reduce_all) { int x_rank = x.dims().size(); - std::vector formated_axis = axis; + std::vector formatted_axis = axis; for (size_t i = 0; i < axis.size(); ++i) { if (x_rank == 0) { PADDLE_ENFORCE_EQ( @@ -3414,12 +3471,12 @@ DDim ReduceInferDim(const MetaTensor& x, } if (axis[i] < 0) { - formated_axis[i] = axis[i] + x_rank; + formatted_axis[i] = axis[i] + x_rank; } } bool full_dim = true; - std::set dims_set(formated_axis.begin(), formated_axis.end()); + std::set dims_set(formatted_axis.begin(), formatted_axis.end()); for (int64_t i = 0; i < x_rank; ++i) { if (dims_set.find(i) == dims_set.end()) { full_dim = false; @@ -3848,7 +3905,6 @@ void SliceArrayDenseInferMeta(const MetaTensor& input, if (config.is_runtime) { return; } - // out->set_dims(input.dims()); out->set_dtype(input.dtype()); out->set_dims(input.dims()); } @@ -4034,7 +4090,8 @@ void SplitInferMeta(const MetaTensor& x, if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1 || (axis_value >= 0 && x.dims().at(axis_value) <= 0)) { std::vector out_dims; - if ((sections.FromTensor() && !config.is_runtime) || axis_value == -1) { + if ((sections.FromTensor() && !config.is_runtime) || + axis_value == -1) { // NOLINT out_dims = std::vector( sections_data.size(), common::make_ddim(std::vector(x.dims().size(), -1))); @@ -4126,7 +4183,7 @@ void SplitWithNumInferMeta(const MetaTensor& x, // fill out dims with -1 if (axis_value == -1 || (axis_value >= 0 && x.dims().at(axis_value) <= 0)) { std::vector out_dims; - if (axis_value == -1) { + if (axis_value == -1) { // NOLINT out_dims = std::vector( num, common::make_ddim(std::vector(x.dims().size(), -1))); } else { @@ -4147,7 +4204,7 @@ void SplitWithNumInferMeta(const MetaTensor& x, } } else { auto input_axis_dim = x.dims().at(axis_value); - // step1: get formated sections + // step1: get formatted sections std::vector sections_vec; PADDLE_ENFORCE_NE( num, @@ -4435,6 +4492,140 @@ void SumInferMeta(const MetaTensor& x, SumRawInferMeta(x, axis, keep_dim, reduce_all, dtype, out, config); } +void PartialSumInferMeta(const std::vector& xs, + int start_index, + int length, + MetaTensor* out, + MetaConfig config) { + int64_t batch_size = -1; + int64_t input_len = -1; + + auto inputs_num = xs.size(); + PADDLE_ENFORCE_GT(inputs_num, + 0, + phi::errors::InvalidArgument( + "ShapeError: Input tensors count should > 0. 
But " + "received inputs' length is 0.")); + + if (inputs_num == 1) { + VLOG(3) << "Warning: partial_sum op have only one input, may be useless"; + } + + // Only support two dimensions now, should be extended later + // when length is -1, need make sure all dimensions to be added are the same + for (size_t i = 0; i < inputs_num; i++) { + auto x_dim = xs[i]->dims(); + + PADDLE_ENFORCE_EQ( + x_dim.size(), + 2, + phi::errors::InvalidArgument("Only support two dimensions input now.")); + + if (i == 0) { + batch_size = x_dim[0]; + input_len = x_dim[1]; + } else { + // each tensor's dim must eq + PADDLE_ENFORCE_EQ(x_dim[0], + batch_size, + phi::errors::InvalidArgument( + "The batch size of all inputs must be same")); + PADDLE_ENFORCE_EQ(x_dim[1], + input_len, + phi::errors::InvalidArgument( + "The input len of all inputs must be same")); + } + } + PADDLE_ENFORCE_GT( + input_len, + start_index, + phi::errors::OutOfRange("start_index must be less than input len")); + if (length > 0) { + PADDLE_ENFORCE_GE(input_len, + start_index + length, + phi::errors::OutOfRange( + "start_index + length is larger than input length")); + } + + std::vector out_dims(2); + out_dims[0] = batch_size; + out_dims[1] = (length == -1) ? input_len - start_index : length; + DDim out_dim = common::make_ddim(out_dims); + out->set_dims(out_dim); + out->set_dtype(xs[0]->dtype()); +} + +void PartialConcatInferMeta(const std::vector& xs, + int start_index, + int length, + MetaTensor* out, + MetaConfig config) { + int64_t batch_size = -1; + int64_t input_len = -1; + + auto inputs_num = xs.size(); + PADDLE_ENFORCE_GT(inputs_num, + 0, + phi::errors::InvalidArgument( + "ShapeError: Input tensors count should > 0. But " + "received inputs' length is 0.")); + + // Only support two dimensions now, should be extended later + // when length is -1, need make sure all dimensions to be added are the same + for (size_t i = 0; i < inputs_num; i++) { + auto x_dim = xs[i]->dims(); + + PADDLE_ENFORCE_EQ( + x_dim.size(), + 2, + phi::errors::InvalidArgument("Only support two dimensions input now.")); + + if (i == 0) { + batch_size = x_dim[0]; + input_len = x_dim[1]; + } else { + // each tensor's dim must eq + PADDLE_ENFORCE_EQ(x_dim[0], + batch_size, + phi::errors::InvalidArgument( + "The batch size of all inputs must be same")); + PADDLE_ENFORCE_EQ(x_dim[1], + input_len, + phi::errors::InvalidArgument( + "The input len of all inputs must be same")); + } + } + + PADDLE_ENFORCE_EQ( + start_index >= -input_len && start_index < input_len, + true, + phi::errors::InvalidArgument( + "The start_index is expected to be in range of [%d, %d), but got %d", + -input_len, + input_len, + start_index)); + + if (start_index < 0) { + start_index += input_len; + } + + if (length > 0) { + PADDLE_ENFORCE_GE(input_len, + start_index + length, + phi::errors::OutOfRange( + "start_index + length is larger than input length")); + } + + std::vector out_dims(2); + out_dims[0] = batch_size; + // colnum = input_num * length + out_dims[1] = (length < 0) ? 
input_len - start_index : length; + out_dims[1] *= inputs_num; + DDim out_dim = common::make_ddim(out_dims); + out->set_dims(out_dim); + out->set_dtype(xs[0]->dtype()); +} + void SvdInferMeta(const MetaTensor& x, bool full_matrices, MetaTensor* u, @@ -4532,7 +4723,7 @@ void TileInferMeta(const MetaTensor& x, const IntArray& repeat_times, MetaTensor* out, MetaConfig config) { -#define MAX_RANK_SUPPORTED 6 +#define TILE_MAX_RANK_SUPPORTED 6 auto repeat_times_data = repeat_times.GetData(); auto x_dims = x.dims(); @@ -4542,19 +4733,19 @@ void TileInferMeta(const MetaTensor& x, PADDLE_ENFORCE_LE( x_dims.size(), - MAX_RANK_SUPPORTED, + TILE_MAX_RANK_SUPPORTED, errors::InvalidArgument( "The rank of the input 'x' for tile op " "must not be greater than %d, but the value received is %d.", - MAX_RANK_SUPPORTED, + TILE_MAX_RANK_SUPPORTED, x_dims.size())); PADDLE_ENFORCE_LE( repeat_times_data.size(), - MAX_RANK_SUPPORTED, + TILE_MAX_RANK_SUPPORTED, errors::InvalidArgument( "The size of the shape of input 'repeat_times' for tile op " "must not be greater than %d, but the value received is %d.", - MAX_RANK_SUPPORTED, + TILE_MAX_RANK_SUPPORTED, repeat_times_data.size())); PADDLE_ENFORCE_GE( repeat_times_data.size(), @@ -4595,6 +4786,7 @@ void TileInferMeta(const MetaTensor& x, out->share_lod(x); } out->set_dtype(x.dtype()); +#undef TILE_MAX_RANK_SUPPORTED } void TopKInferMeta(const MetaTensor& x, @@ -4756,7 +4948,7 @@ void TransposeInferMeta(const MetaTensor& x, x_rank, axis_size)); - std::vector formated_axis = axis; + std::vector formatted_axis = axis; std::vector count(axis_size, 0); for (int i = 0; i < axis_size; i++) { PADDLE_ENFORCE_LT(axis[i], @@ -4779,10 +4971,10 @@ void TransposeInferMeta(const MetaTensor& x, axis[i])); if (axis[i] < 0) { - formated_axis[i] = axis[i] + x_rank; + formatted_axis[i] = axis[i] + x_rank; } PADDLE_ENFORCE_EQ( - ++count[formated_axis[i]], + ++count[formatted_axis[i]], 1, errors::InvalidArgument("Each element of axis should be unique. but " "axis[%d] is %d appear not only once", @@ -4792,7 +4984,7 @@ void TransposeInferMeta(const MetaTensor& x, phi::DDim out_dims(x_dims); for (int i = 0; i < axis_size; ++i) { - out_dims[i] = x_dims[formated_axis[i]]; + out_dims[i] = x_dims[formatted_axis[i]]; } out->set_dims(out_dims); @@ -4875,6 +5067,14 @@ void UnchangedArrayInferMeta(const MetaTensor& x, MetaTensor* out) { out->set_layout(x.layout()); } +void UnchangedVectorInferMeta(const std::vector& xs, + std::vector outs) { + for (size_t i = 0; i < xs.size(); ++i) { + outs[i]->set_dtype(xs[i]->dtype()); + outs[i]->set_layout(xs[i]->layout()); + } +} + // meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1] void UnchangedInferMetaCheckAxis(const MetaTensor& x, int axis, @@ -5415,7 +5615,7 @@ void WeightQuantizeInferMeta(const MetaTensor& x, } std::vector dim_out; - if (algo == "weight_only_int8" || algo == "llm.int8") { + if (algo == "weight_only_int8" || algo == "llm.int8") { // NOLINT dim_out = std::vector({x_dims[1], x_dims[0]}); } else if (algo == "weight_only_int4") { dim_out = std::vector({x_dims[1] / 2, x_dims[0]}); diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index d62789bd5183c..29fc97955e87a 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -20,7 +20,7 @@ limitations under the License. 
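// [Illustrative sketch, not part of this diff] The PartialConcatInferMeta
// added above sets the output width to the selected slice width multiplied
// by the number of inputs, while PartialSumInferMeta keeps a single slice
// width; a negative length means "from start_index to the end of the row".
// A tiny standalone check of that shape rule (the helper name
// PartialConcatCols is hypothetical):

#include <cassert>
#include <cstdint>

int64_t PartialConcatCols(int64_t num_inputs,
                          int64_t input_len,
                          int64_t start_index,
                          int64_t length) {
  int64_t slice = (length < 0) ? input_len - start_index : length;
  return slice * num_inputs;
}

int main() {
  assert(PartialConcatCols(3, 10, 4, -1) == 18);  // 3 inputs * 6 columns
  assert(PartialConcatCols(2, 10, 0, 5) == 10);   // 2 inputs * 5 columns
  return 0;
}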
*/ namespace phi { -class MetaConfig; +struct MetaConfig; // Common InferMeta Functions for unary operators, The format like: // @@ -137,6 +137,10 @@ void CropInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void CScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out); + +void CSplitInferMeta(const MetaTensor& x, int nranks, MetaTensor* out); + void CumInferMeta(const MetaTensor& x, int axis, bool flatten, @@ -392,6 +396,7 @@ void MultinomialInferMeta(const MetaTensor& x, void NanmedianInferMeta(const MetaTensor& x, const IntArray& axes, bool keep_dim, + const std::string& mode, MetaTensor* out, MetaTensor* median_index); @@ -434,6 +439,13 @@ void Pad3dInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void PartialAllgatherInferMeta(const MetaTensor& x, + int nranks, + int rank, + int ring_id, + bool use_calc_stream, + MetaTensor* out); + void PartialSendInferMeta(const MetaTensor& x, int ring_id, int peer, @@ -496,6 +508,11 @@ void PSendInferMeta(const MetaTensor& x, int peer); void PSendArrayInferMeta(const MetaTensor& x, int peer); +void PushDenseInferMeta(const std::vector& ids, + int table_id, + float scale_data_norm, + const std::vector& input_names); + void SendV2InferMeta(const int peer, const int ring_id); void QrInferMeta(const MetaTensor& x, @@ -693,6 +710,18 @@ void SumRawInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); +void PartialConcatInferMeta(const std::vector& xs, + int start_index, + int length, + MetaTensor* out, + MetaConfig config = MetaConfig()); + +void PartialSumInferMeta(const std::vector& xs, + int start_index, + int length, + MetaTensor* out, + MetaConfig config = MetaConfig()); + void SvdInferMeta(const MetaTensor& x, bool full_matrices, MetaTensor* u, @@ -753,6 +782,8 @@ void UnchangedExceptLayoutInferMeta(const MetaTensor& x, MetaTensor* out); void UnchangedExceptDtypeInferMeta(const MetaTensor& x, MetaTensor* out); void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out); void UnchangedArrayInferMeta(const MetaTensor& x, MetaTensor* out); +void UnchangedVectorInferMeta(const std::vector& xs, + std::vector outs); // meta x -> out without change, check if axis in range [-Rank(x), Rank(x)-1] void UnchangedInferMetaCheckAxis(const MetaTensor& x, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 80d61ebc9a9a6..304fd3cef793a 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -213,6 +213,7 @@ if(WITH_ROCM) "gpu/put_along_axis_grad_kernel.cu" "gpu/put_along_axis_kernel.cu" "gpu/qr_kernel.cu" + "gpu/rms_norm_grad_kernel.cu" "gpu/svd_kernel.cu" "gpudnn/mha_cudnn_frontend.cu" "fusion/gpu/block_multi_head_attention_kernel.cu" diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h index a992d1ab3312b..b2fae7b0406e0 100644 --- a/paddle/phi/kernels/activation_grad_kernel.h +++ b/paddle/phi/kernels/activation_grad_kernel.h @@ -89,7 +89,7 @@ void ReluDoubleGradKernel(const Context& dev_ctx, template void SinDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const paddle::optional& dout, + const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout); @@ -97,7 +97,7 @@ void SinDoubleGradKernel(const Context& dev_ctx, template void CosDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const paddle::optional& dout, + const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dx, 
DenseTensor* ddout); diff --git a/paddle/phi/kernels/autotune/gpu_timer.h b/paddle/phi/kernels/autotune/gpu_timer.h index b04c46351c2cf..1bdb6de30cf26 100644 --- a/paddle/phi/kernels/autotune/gpu_timer.h +++ b/paddle/phi/kernels/autotune/gpu_timer.h @@ -16,10 +16,10 @@ #include "paddle/common/errors.h" #include "paddle/phi/backends/context_pool.h" -#include "paddle/phi/backends/dynload/port.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_decls.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/common/port.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/enforce.h" diff --git a/paddle/phi/kernels/bmm_kernel.h b/paddle/phi/kernels/bmm_kernel.h index 09e7f9647b68e..6d3733bf750d3 100644 --- a/paddle/phi/kernels/bmm_kernel.h +++ b/paddle/phi/kernels/bmm_kernel.h @@ -22,7 +22,7 @@ namespace phi { * @brief Bmm Kernel. * Applies batched matrix multiplication to two tensors. * - * Both of the two input tensors must be three-dementional + * Both of the two input tensors must be three-dimensional * and share the same batch size. * if x is a (b, m, k) tensor, y is a (b, k, n) tensor, * the output will be a (b, m, n) tensor. diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc index cb821233004f8..3f26f8c388e66 100644 --- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc @@ -438,11 +438,12 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_triple_grad, PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, HardSigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad, LogSigmoidGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel) -PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL(log_double_grad, LogDoubleGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log_grad, LogGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log2_grad, Log2GradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log10_grad, Log10GradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log1p_grad, Log1pGradKernel) +PD_REGISTER_ACTIVATION_DOUBLE_GRAD_KERNEL_WITH_COMPLEX(log_double_grad, + LogDoubleGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc index 11312aa3a7972..92acf104fedcf 100644 --- a/paddle/phi/kernels/cpu/activation_kernel.cc +++ b/paddle/phi/kernels/cpu/activation_kernel.cc @@ -254,7 +254,9 @@ PD_REGISTER_KERNEL(log, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(log2, CPU, ALL_LAYOUT, @@ -264,7 +266,9 @@ PD_REGISTER_KERNEL(log2, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(log10, CPU, ALL_LAYOUT, @@ -274,7 +278,9 @@ PD_REGISTER_KERNEL(log10, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(log1p, CPU, ALL_LAYOUT, @@ -284,7 +290,9 @@ 
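// [Illustrative snippet, not part of this diff] The registrations in this
// hunk and the ones just above extend log/log2/log10/log1p (and their grad
// kernels) to the complex dtypes. The forward functors apply std::log to a
// std::complex value, and the grad functors added later in this diff compute
// dx = dout / conj(x) (and dout / conj(x * log(2)), etc.). A small standalone
// check of those two formulas:

#include <complex>
#include <iostream>

int main() {
  std::complex<double> z(3.0, 4.0);

  // Forward: log of a complex value, as in the complex Log functor.
  std::complex<double> y = std::log(z);

  // Backward: with upstream gradient dout, the complex LogGradFunctor in
  // this diff returns dout / conj(x).
  std::complex<double> dout(1.0, 0.0);
  std::complex<double> dx = dout / std::conj(z);

  std::cout << "log(z) = " << y << ", dlog/dz (conj convention) = " << dx
            << std::endl;
  return 0;
}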
PD_REGISTER_KERNEL(log1p, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_ACTIVATION_KERNEL_WITH_COMPLEX(hardswish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(round, RoundKernel) diff --git a/paddle/phi/kernels/cpu/all_gather_kernel.cc b/paddle/phi/kernels/cpu/all_gather_kernel.cc index 96433694ffb2b..f16dbe06e9c18 100644 --- a/paddle/phi/kernels/cpu/all_gather_kernel.cc +++ b/paddle/phi/kernels/cpu/all_gather_kernel.cc @@ -88,7 +88,9 @@ PD_REGISTER_KERNEL(all_gather, uint8_t, int16_t, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} #ifdef PADDLE_WITH_CUSTOM_DEVICE PD_REGISTER_KERNEL(all_gather, @@ -103,5 +105,7 @@ PD_REGISTER_KERNEL(all_gather, uint8_t, int16_t, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} #endif diff --git a/paddle/phi/kernels/cpu/all_to_all_kernel.cc b/paddle/phi/kernels/cpu/all_to_all_kernel.cc index 3407a1828e208..5df84c5360de7 100644 --- a/paddle/phi/kernels/cpu/all_to_all_kernel.cc +++ b/paddle/phi/kernels/cpu/all_to_all_kernel.cc @@ -45,8 +45,7 @@ void AllToAllKernel(const phi::CustomContext& dev_ctx, std::vector sendbuf, recvbuf; std::vector sendsize(send_numel, nranks); - std::vector sendtype( - phi::ccl::ToCCLDataType(x.dtype()), nranks); + std::vector sendtype(x.dtype(), nranks); for (auto i = 0; i < nranks; ++i) { sendbuf.push_back(x.data() + i * send_numel); recvbuf.push_back(out->data() + i * send_numel); diff --git a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc index 1bdf25dd4eb82..e9c5ae6a39e4a 100644 --- a/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_grad_kernel.cc @@ -611,7 +611,7 @@ void BatchNormDoubleGradKernel( EigenArrayMap ddy_arr( ctx.template Alloc(&transformed_ddy), C, sample_size); ddy_arr.setZero(); - if (use_global_stats) { + if (use_global_stats) { // NOLINT // math: ddy = r * ddx * inv_var + ddbias + // ddscale * (x - mean) * inv_var if (ddX) { diff --git a/paddle/phi/kernels/cpu/batch_norm_kernel.cc b/paddle/phi/kernels/cpu/batch_norm_kernel.cc index 39d53fec10a9f..f6d5e97dc7245 100644 --- a/paddle/phi/kernels/cpu/batch_norm_kernel.cc +++ b/paddle/phi/kernels/cpu/batch_norm_kernel.cc @@ -159,7 +159,7 @@ void BatchNormKernel(const Context& ctx, // use SavedMean and SavedVariance to do normalize Eigen::Array inv_std(C); - if (global_stats) { + if (global_stats) { // NOLINT ConstEigenVectorArrayMap var_arr(variance.data(), C); inv_std = (var_arr + epsilon).sqrt().inverse(); } else { @@ -178,7 +178,7 @@ void BatchNormKernel(const Context& ctx, auto* Bias = bias.get_ptr(); Eigen::Array new_scale(C); Eigen::Array new_bias(C); - if (Scale && Bias) { + if (Scale && Bias) { // NOLINT ConstEigenVectorArrayMap scale_arr(Scale->data(), C); ConstEigenVectorArrayMap bias_arr(Bias->data(), C); new_scale = inv_std * scale_arr; diff --git a/paddle/phi/kernels/cpu/c_embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/c_embedding_grad_kernel.cc index 1644f99850347..5c661b2304056 100644 --- a/paddle/phi/kernels/cpu/c_embedding_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/c_embedding_grad_kernel.cc @@ -96,4 +96,6 @@ PD_REGISTER_KERNEL(c_embedding_grad, phi::CEmbeddingGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/c_embedding_kernel.cc 
b/paddle/phi/kernels/cpu/c_embedding_kernel.cc index 67e4ffbe263ec..1343d8d22dcf8 100644 --- a/paddle/phi/kernels/cpu/c_embedding_kernel.cc +++ b/paddle/phi/kernels/cpu/c_embedding_kernel.cc @@ -85,4 +85,6 @@ PD_REGISTER_KERNEL(c_embedding, phi::CEmbeddingKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/data_kernel.cc b/paddle/phi/kernels/cpu/data_kernel.cc index 4ab0a01cb7172..2081b0bd8e748 100644 --- a/paddle/phi/kernels/cpu/data_kernel.cc +++ b/paddle/phi/kernels/cpu/data_kernel.cc @@ -70,6 +70,23 @@ PD_REGISTER_KERNEL(shadow_feed, phi::complex64, phi::complex128) {} +PD_REGISTER_KERNEL(shadow_feed_tensors, + CPU, + ALL_LAYOUT, + phi::ShadowFeedTensorsKernel, + bool, + uint8_t, + float, + int8_t, + int16_t, + int32_t, + int64_t, + double, + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} + PD_REGISTER_KERNEL(print_kernel, CPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/cpu/diag_grad_kernel.cc b/paddle/phi/kernels/cpu/diag_grad_kernel.cc index 5a2f15d11428a..7922029fa4fec 100644 --- a/paddle/phi/kernels/cpu/diag_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/diag_grad_kernel.cc @@ -70,4 +70,6 @@ PD_REGISTER_KERNEL(diag_grad, int, int64_t, float, - double) {} + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/diag_kernel.cc b/paddle/phi/kernels/cpu/diag_kernel.cc index fb15fcbe61f7e..3104a15dee552 100644 --- a/paddle/phi/kernels/cpu/diag_kernel.cc +++ b/paddle/phi/kernels/cpu/diag_kernel.cc @@ -70,4 +70,6 @@ PD_REGISTER_KERNEL(diag, int, float, double, - int64_t) {} + int64_t, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/dropout_grad_kernel.cc b/paddle/phi/kernels/cpu/dropout_grad_kernel.cc index 9a48fb3994adb..305d734e51dd2 100644 --- a/paddle/phi/kernels/cpu/dropout_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/dropout_grad_kernel.cc @@ -89,6 +89,7 @@ PD_REGISTER_KERNEL(dropout_grad, phi::DropoutGradRawKernel, float, double, + phi::dtype::float16, phi::dtype::bfloat16) {} PD_REGISTER_KERNEL( diff --git a/paddle/phi/kernels/cpu/dropout_kernel.cc b/paddle/phi/kernels/cpu/dropout_kernel.cc index 322ce0110d2bc..60c02e96d58c0 100644 --- a/paddle/phi/kernels/cpu/dropout_kernel.cc +++ b/paddle/phi/kernels/cpu/dropout_kernel.cc @@ -209,6 +209,7 @@ PD_REGISTER_KERNEL(dropout, phi::DropoutRawKernel, float, double, + phi::dtype::float16, phi::dtype::bfloat16) { kernel->OutputAt(1).SetDataType(phi::DataType::UINT8); } diff --git a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc index b7fdefe023e73..ed80148344e1f 100644 --- a/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_divide_kernel.cc @@ -35,7 +35,7 @@ void DivideKernel(const Context& dev_ctx, } else { auto x_dims = x.dims(); auto y_dims = y.dims(); - if (x_dims.size() >= y_dims.size()) { + if (x_dims.size() >= y_dims.size()) { // NOLINT funcs::ElementwiseCompute, T>( dev_ctx, x, y, funcs::DivideFunctor(), out, -1); } else { diff --git a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc index db833d93b1a60..87f90e4e94161 100644 --- a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc @@ -209,7 +209,9 @@ PD_REGISTER_KERNEL(embedding_grad, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + 
phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(embedding_sparse_grad, CPU, @@ -217,4 +219,6 @@ PD_REGISTER_KERNEL(embedding_sparse_grad, phi::EmbeddingSparseGradKernel, float, double, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/embedding_kernel.cc b/paddle/phi/kernels/cpu/embedding_kernel.cc index 6ddccf509d588..0b4d5be40eb27 100644 --- a/paddle/phi/kernels/cpu/embedding_kernel.cc +++ b/paddle/phi/kernels/cpu/embedding_kernel.cc @@ -123,4 +123,6 @@ PD_REGISTER_KERNEL(embedding, double, int8_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/eye_kernel.cc b/paddle/phi/kernels/cpu/eye_kernel.cc index ef3489d3fae0d..f2e277d94250e 100644 --- a/paddle/phi/kernels/cpu/eye_kernel.cc +++ b/paddle/phi/kernels/cpu/eye_kernel.cc @@ -26,4 +26,6 @@ PD_REGISTER_KERNEL(eye, double, int64_t, int, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/gather_grad_kernel.cc b/paddle/phi/kernels/cpu/gather_grad_kernel.cc index 456c7ea633cde..29ed2612adda7 100644 --- a/paddle/phi/kernels/cpu/gather_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/gather_grad_kernel.cc @@ -72,4 +72,6 @@ PD_REGISTER_KERNEL(gather_grad, int, uint8_t, int64_t, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/gather_kernel.cc b/paddle/phi/kernels/cpu/gather_kernel.cc index 9f6e7d2291a1b..361063548e880 100644 --- a/paddle/phi/kernels/cpu/gather_kernel.cc +++ b/paddle/phi/kernels/cpu/gather_kernel.cc @@ -67,4 +67,6 @@ PD_REGISTER_KERNEL(gather, int, uint8_t, int64_t, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/cpu/gather_tree_kernel.cc b/paddle/phi/kernels/cpu/gather_tree_kernel.cc index dac1441cb5006..3d403cf7327f2 100644 --- a/paddle/phi/kernels/cpu/gather_tree_kernel.cc +++ b/paddle/phi/kernels/cpu/gather_tree_kernel.cc @@ -54,11 +54,19 @@ void GatherTreeKernel(const Context &dev_ctx, parent, beam_size, phi::errors::InvalidArgument( - "The parents must be less than beam size, but received" + "The parents must be less than beam size, but received " "parents %d is greater than or equal to beam size %d. ", parent, beam_size)); + PADDLE_ENFORCE_GE( + parent, + 0, + phi::errors::InvalidArgument( + "The parents must be greater than or equal to 0, but received " + "parents %d is less than 0. 
", + parent)); + idx = step * batch_size * beam_size + batch * beam_size; out_data[idx + beam] = ids_data[idx + parent]; parent = parents_data[idx + parent]; diff --git a/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc b/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc index 0fc6ae271460d..366f1d65cc8f0 100644 --- a/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc +++ b/paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc @@ -74,7 +74,7 @@ void Array2Poly(const T* box, template void PointVec2Poly(const std::vector>& vec, phi::funcs::gpc_polygon* poly) { - int pts_num = vec.size(); + size_t pts_num = vec.size(); (*poly).num_contours = 1; (*poly).hole = reinterpret_cast(malloc(sizeof(int))); // NOLINT (*poly).hole[0] = 0; diff --git a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc index 73ba727c3cb91..37f92ef526f28 100644 --- a/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_grad_kernel.cc @@ -21,11 +21,50 @@ namespace phi { +template +void CalcMedianMeanGrad(int64_t pre_dim, + int64_t stride, + const int64_t* m_data, + T* dx_data, + const T* dout_data) { + int64_t i = 0; + int64_t offset = 0; + for (i = 0; i < pre_dim; i++) { + if (m_data[2 * i] >= 0) { + if (m_data[2 * i] == m_data[2 * i + 1]) { + dx_data[offset + m_data[2 * i]] = dout_data[i]; + } else { + dx_data[offset + m_data[2 * i]] = dout_data[i] / static_cast(2.0); + dx_data[offset + m_data[2 * i + 1]] = + dout_data[i] / static_cast(2.0); + } + } + offset += stride; + } +} + +template +void CalcMedianMinGrad(int64_t pre_dim, + int64_t stride, + const int64_t* m_data, + T* dx_data, + const T* dout_data) { + int64_t i = 0; + int64_t offset = 0; + for (i = 0; i < pre_dim; i++) { + if (m_data[i] >= 0) { + dx_data[offset + m_data[i]] = dout_data[i]; + } + offset += stride; + } +} + template void CalcMedianGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& median_index, const DenseTensor& out_grad, + const std::string& mode, DenseTensor* x_grad) { T* dx_data = dev_ctx.template Alloc(x_grad); if (!dx_data) return; @@ -41,19 +80,10 @@ void CalcMedianGradKernel(const Context& dev_ctx, int64_t stride = x_dim[static_cast(rank - 1)]; int64_t pre_dim = numel / stride; - int64_t i = 0; - int64_t offset = 0; - for (i = 0; i < pre_dim; i++) { - if (m_data[2 * i] >= 0) { - if (m_data[2 * i] == m_data[2 * i + 1]) { - dx_data[offset + m_data[2 * i]] = dout_data[i]; - } else { - dx_data[offset + m_data[2 * i]] = dout_data[i] / static_cast(2.0); - dx_data[offset + m_data[2 * i + 1]] = - dout_data[i] / static_cast(2.0); - } - } - offset += stride; + if (mode == "avg") { + CalcMedianMeanGrad(pre_dim, stride, m_data, dx_data, dout_data); + } else { + CalcMedianMinGrad(pre_dim, stride, m_data, dx_data, dout_data); } } @@ -64,6 +94,7 @@ void NanmedianGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, const IntArray& axes, bool keepdim UNUSED, + const std::string& mode, DenseTensor* x_grad) { DenseTensor tmp_x; auto rank = x.dims().size(); @@ -71,14 +102,14 @@ void NanmedianGradKernel(const Context& dev_ctx, tmp_x = x; tmp_x.Resize({x.numel()}); CalcMedianGradKernel( - dev_ctx, tmp_x, median_index, out_grad, x_grad); + dev_ctx, tmp_x, median_index, out_grad, mode, x_grad); } else { funcs::PreprocessMedianKernel(dev_ctx, x, axes, &tmp_x); DenseTensor tmp_x_grad; tmp_x_grad.Resize(x_grad->dims()); CalcMedianGradKernel( - dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad); + dev_ctx, tmp_x, median_index, out_grad, mode, &tmp_x_grad); 
dev_ctx.template Alloc(x_grad); funcs::PostprocessMedianGradKernel( diff --git a/paddle/phi/kernels/cpu/nanmedian_kernel.cc b/paddle/phi/kernels/cpu/nanmedian_kernel.cc index a44a800c74123..2911d5c0fcec5 100644 --- a/paddle/phi/kernels/cpu/nanmedian_kernel.cc +++ b/paddle/phi/kernels/cpu/nanmedian_kernel.cc @@ -30,7 +30,8 @@ void CalcMedianFunc(const Context& dev_ctx, int64_t stride, int64_t pre_dim, T* o_ptr, - int64_t* m_ptr) { + int64_t* m_ptr, + const std::string& mode) { DenseTensor sort_out; DenseTensor sort_indices; auto sort_dim = x.dims(); @@ -51,12 +52,16 @@ void CalcMedianFunc(const Context& dev_ctx, int64_t offset = 0; int64_t i = 0; bool is_ori_odd = stride & 1; - if (ignore_nan) { + if (ignore_nan) { // ignore_nan - has nan value; sort_k = max_valid_num for (i = 0; i < pre_dim; i++) { offset = i * sort_k; if (nan_counts[i] == stride) { - m_ptr[i * 2] = -1; - m_ptr[i * 2 + 1] = -1; + if (mode == "avg") { + m_ptr[i * 2] = -1; + m_ptr[i * 2 + 1] = -1; // index is -1 + } else { + m_ptr[i] = -1; + } o_ptr[i] = sort_out_ptr[offset]; } else { int64_t nan_k = nan_counts[i] > 0 @@ -65,21 +70,34 @@ void CalcMedianFunc(const Context& dev_ctx, int64_t row_pos = static_cast(nan_k >> 1); int64_t pos = offset + row_pos; if (nan_k & 1) { - m_ptr[2 * i] = sort_indices_ptr[pos]; - m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + if (mode == "avg") { + m_ptr[2 * i] = sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + } else { + m_ptr[i] = sort_indices_ptr[pos]; + } o_ptr[i] = sort_out_ptr[pos]; } else { - m_ptr[2 * i] = - row_pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; - m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + // nan_k is even T m_val_left = row_pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; T m_val_right = sort_out_ptr[pos]; - o_ptr[i] = (m_val_left + m_val_right) / div_factor; + if (mode == "avg") { + m_ptr[2 * i] = + row_pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + o_ptr[i] = (m_val_left + m_val_right) / div_factor; + } else { + // mode == "min": output median value should be the left val since + // the sort_out is in ascending order + m_ptr[i] = + row_pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + o_ptr[i] = m_val_left; + } } } } - } else { + } else { // not ignore_nan - no nan value; sort_k = stride/2 + 1 if (is_ori_odd) { for (i = 0; i < pre_dim; i++) { offset = i * sort_k; @@ -92,12 +110,20 @@ void CalcMedianFunc(const Context& dev_ctx, for (i = 0; i < pre_dim; i++) { offset = i * sort_k; int64_t pos = offset + sort_k - 1; - m_ptr[2 * i] = - sort_k > 1 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; - m_ptr[2 * i + 1] = sort_indices_ptr[pos]; T m_val_left = sort_k > 1 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; T m_val_right = sort_out_ptr[pos]; - o_ptr[i] = (m_val_left + m_val_right) / div_factor; + if (mode == "avg") { + m_ptr[2 * i] = + sort_k > 1 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + m_ptr[2 * i + 1] = sort_indices_ptr[pos]; + o_ptr[i] = (m_val_left + m_val_right) / div_factor; + } else { + // mode == "min": output median value should be the left val since the + // sort_out is in ascending order + m_ptr[i] = + sort_k > 1 ? 
sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + o_ptr[i] = m_val_left; + } } } } @@ -106,6 +132,7 @@ void CalcMedianFunc(const Context& dev_ctx, template void ProcessMedianKernel(const Context& dev_ctx, const DenseTensor& x, + const std::string& mode, DenseTensor* out, DenseTensor* median_index) { const T* x_data = x.data(); @@ -154,8 +181,12 @@ void ProcessMedianKernel(const Context& dev_ctx, if (total_nan_num == numel) { for (i = 0; i < pre_dim; i++) { out_data[i] = std::numeric_limits::quiet_NaN(); - m_data[2 * i] = -1; - m_data[2 * i + 1] = -1; + if (mode == "avg") { + m_data[2 * i] = -1; + m_data[2 * i + 1] = -1; // indices are all -1 + } else { + m_data[i] = -1; + } } return; } @@ -171,7 +202,8 @@ void ProcessMedianKernel(const Context& dev_ctx, stride, pre_dim, out_data, - m_data); + m_data, + mode); } template @@ -179,18 +211,23 @@ void NanmedianKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, bool keepdim UNUSED, + const std::string& mode, DenseTensor* out, DenseTensor* median_index) { DenseTensor tmp_x; auto rank = x.dims().size(); if ((axes.size() == 0) || rank <= 1) { tmp_x = x; - tmp_x.Resize({x.numel()}); + tmp_x.Resize({x.numel()}); // flatten } else { - funcs::PreprocessMedianKernel(dev_ctx, x, axes, &tmp_x); + funcs::PreprocessMedianKernel( + dev_ctx, + x, + axes, + &tmp_x); // resize to 2D so as to compute median on last axis } - ProcessMedianKernel(dev_ctx, tmp_x, out, median_index); + ProcessMedianKernel(dev_ctx, tmp_x, mode, out, median_index); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/rnn_grad_kernel.cc b/paddle/phi/kernels/cpu/rnn_grad_kernel.cc index a48d05b8d783e..8b26bf31de9bb 100644 --- a/paddle/phi/kernels/cpu/rnn_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/rnn_grad_kernel.cc @@ -1311,7 +1311,7 @@ void RnnGradKernel(const Context& dev_ctx, pre_state_grad, weight_grad_list); // run gru - } else if (is_rnn_relu(mode)) { + } else if (is_rnn_relu(mode)) { // NOLINT gate_num = 1; RnnGradFunc, SingleGradLayer, diff --git a/paddle/phi/kernels/cpu/rnn_kernel.cc b/paddle/phi/kernels/cpu/rnn_kernel.cc index a0035c6db4a75..5b594089793c8 100644 --- a/paddle/phi/kernels/cpu/rnn_kernel.cc +++ b/paddle/phi/kernels/cpu/rnn_kernel.cc @@ -868,7 +868,7 @@ void RnnKernel(const Context& dev_ctx, is_test, seed, reserve); - } else if (is_rnn_relu(mode)) { + } else if (is_rnn_relu(mode)) { // NOLINT gate_num = 1; RnnFunc void ScaleKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale, DenseTensor* out) { // calc @@ -44,12 +44,7 @@ void ScaleKernel(const Context& dev_ctx, return; } phi::funcs::EigenScale, T>::Eval( - dev, - eigen_out, - eigen_x, - scale.to(), - static_cast(bias), - bias_after_scale); + dev, eigen_out, eigen_x, scale.to(), bias.to(), bias_after_scale); } } // namespace phi diff --git a/paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc b/paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc index 0d0210ac661c0..6097a3d1be679 100644 --- a/paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/send_ue_recv_grad_kernel.cc @@ -378,10 +378,8 @@ void GraphSendUERecvGradOpKernelLaunchHelper( const auto& x_dims = x.dims(); const auto& y_dims = y.dims(); int64_t memset_size_x = 1, memset_size_y = 1; - int64_t slice_size = 1; for (int i = 0; i < x_dims.size(); i++) { memset_size_x *= x_dims[i]; - if (i > 0) slice_size *= x_dims[i]; } for (int i = 0; i < y_dims.size(); i++) { memset_size_y *= y_dims[i]; diff --git 
a/paddle/phi/kernels/cpu/send_uv_kernel.cc b/paddle/phi/kernels/cpu/send_uv_kernel.cc index 301611d13d7be..726acbf404107 100644 --- a/paddle/phi/kernels/cpu/send_uv_kernel.cc +++ b/paddle/phi/kernels/cpu/send_uv_kernel.cc @@ -65,11 +65,6 @@ void GraphSendUVOpKernelLaunchHelper(const Context& ctx, "should be greater than 0, but received %d.", index_size)); - auto out_dims = out->dims(); - int64_t memset_size = 1; - for (int i = 0; i < out_dims.size(); i++) { - memset_size *= out_dims[i]; - } ctx.template Alloc(out); T* out_data = out->data(); diff --git a/paddle/phi/kernels/cpu/top_k_kernel.cc b/paddle/phi/kernels/cpu/top_k_kernel.cc index 36956f243d656..0551b72ea4c13 100644 --- a/paddle/phi/kernels/cpu/top_k_kernel.cc +++ b/paddle/phi/kernels/cpu/top_k_kernel.cc @@ -89,14 +89,14 @@ static void FullTopK(Type input_height, }); // the nth-element will get the unorder elements, sort the element if (sorted) { - std::sort(col_vec.begin(), - col_vec.begin() + k - 1, - [&largest](const std::pair& l, - const std::pair& r) { - return (std::isnan(static_cast(l.first)) && - !std::isnan(static_cast(r.first))) || - (l.first > r.first); - }); + std::sort( + col_vec.begin(), + col_vec.begin() + k - 1, + [](const std::pair& l, const std::pair& r) { + return (std::isnan(static_cast(l.first)) && + !std::isnan(static_cast(r.first))) || + (l.first > r.first); + }); } } else { std::nth_element( diff --git a/paddle/phi/kernels/cpu/transpose_kernel.cc b/paddle/phi/kernels/cpu/transpose_kernel.cc index bab9d47caa9aa..67f2b2ce9b403 100644 --- a/paddle/phi/kernels/cpu/transpose_kernel.cc +++ b/paddle/phi/kernels/cpu/transpose_kernel.cc @@ -29,10 +29,10 @@ void TransposeKernel(const Context& ctx, const std::vector& axis, DenseTensor* out) { size_t x_rank = x.dims().size(); - std::vector formated_axis = axis; + std::vector formatted_axis = axis; for (size_t i = 0; i < axis.size(); i++) { if (axis[i] < 0) { - formated_axis[i] = static_cast(axis[i] + x_rank); + formatted_axis[i] = static_cast(axis[i] + x_rank); } } @@ -40,39 +40,39 @@ void TransposeKernel(const Context& ctx, if (out->numel() == 0) { return; } - int rank = static_cast(formated_axis.size()); + int rank = static_cast(formatted_axis.size()); switch (rank) { case 0: phi::Copy(ctx, x, ctx.GetPlace(), false, out); break; case 1: funcs::Transpose trans1; - trans1(ctx, x, out, formated_axis); + trans1(ctx, x, out, formatted_axis); break; case 2: funcs::Transpose trans2; - trans2(ctx, x, out, formated_axis); + trans2(ctx, x, out, formatted_axis); break; case 3: funcs::Transpose trans3; - trans3(ctx, x, out, formated_axis); + trans3(ctx, x, out, formatted_axis); break; case 4: funcs::Transpose trans4; - trans4(ctx, x, out, formated_axis); + trans4(ctx, x, out, formatted_axis); break; case 5: funcs::Transpose trans5; - trans5(ctx, x, out, formated_axis); + trans5(ctx, x, out, formatted_axis); break; case 6: funcs::Transpose trans6; - trans6(ctx, x, out, formated_axis); + trans6(ctx, x, out, formatted_axis); break; default: // for rank >= 7 situation funcs::TransposeNormal trans_normal; - trans_normal(ctx, x, out, formated_axis); + trans_normal(ctx, x, out, formatted_axis); } } diff --git a/paddle/phi/kernels/cpu/uniform_kernel.cc b/paddle/phi/kernels/cpu/uniform_kernel.cc index 5a85675bdeffa..900cf2f26a875 100644 --- a/paddle/phi/kernels/cpu/uniform_kernel.cc +++ b/paddle/phi/kernels/cpu/uniform_kernel.cc @@ -49,4 +49,5 @@ PD_REGISTER_KERNEL(uniform, phi::UniformKernel, float, double, + phi::dtype::float16, phi::dtype::bfloat16) {} diff --git 
a/paddle/phi/kernels/cpu/weighted_sample_neighbors_kernel.cc b/paddle/phi/kernels/cpu/weighted_sample_neighbors_kernel.cc index e137e37a6bd19..d59960a79377a 100644 --- a/paddle/phi/kernels/cpu/weighted_sample_neighbors_kernel.cc +++ b/paddle/phi/kernels/cpu/weighted_sample_neighbors_kernel.cc @@ -36,6 +36,14 @@ struct GraphWeightedNode { GraphWeightedNode(T node_id, float weight_key, T eid = 0) : node_id(node_id), weight_key(weight_key), eid(eid) {} + GraphWeightedNode(const GraphWeightedNode& other) { + if (this != &other) { + this->node_id = other.node_id; + this->weight_key = other.weight_key; + this->eid = other.eid; + } + } + GraphWeightedNode& operator=(const GraphWeightedNode& other) { if (this != &other) { this->node_id = other.node_id; diff --git a/paddle/phi/kernels/data_kernel.h b/paddle/phi/kernels/data_kernel.h index 6a90834baae2e..94d33f7e7ca98 100644 --- a/paddle/phi/kernels/data_kernel.h +++ b/paddle/phi/kernels/data_kernel.h @@ -36,6 +36,11 @@ void ShadowFeedKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out); +template +void ShadowFeedTensorsKernel(const Context& ctx, + const std::vector& xs, + std::vector outs); + template void PrintKernel(const Context& ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/elementwise_divide_grad_kernel.h b/paddle/phi/kernels/elementwise_divide_grad_kernel.h index c764f05c3983f..15b1e65a9cfdf 100644 --- a/paddle/phi/kernels/elementwise_divide_grad_kernel.h +++ b/paddle/phi/kernels/elementwise_divide_grad_kernel.h @@ -33,7 +33,8 @@ template void DivideDoubleGradKernel(const Context& dev_ctx, const DenseTensor& y, const DenseTensor& out, - const DenseTensor& dx, + const DenseTensor& grad_out, + const paddle::optional& dx, const paddle::optional& ddx, const paddle::optional& ddy, int axis, diff --git a/paddle/phi/kernels/empty_kernel.cc b/paddle/phi/kernels/empty_kernel.cc index 0250fdd3b1f69..eb818ae120f66 100644 --- a/paddle/phi/kernels/empty_kernel.cc +++ b/paddle/phi/kernels/empty_kernel.cc @@ -158,7 +158,8 @@ PD_REGISTER_KERNEL(empty, int, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(empty_like, Custom, ALL_LAYOUT, @@ -171,7 +172,8 @@ PD_REGISTER_KERNEL(empty_like, int, int64_t, bool, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); } #endif diff --git a/paddle/phi/kernels/flash_attn_grad_kernel.h b/paddle/phi/kernels/flash_attn_grad_kernel.h index ef5458f4708eb..ac331df406c33 100644 --- a/paddle/phi/kernels/flash_attn_grad_kernel.h +++ b/paddle/phi/kernels/flash_attn_grad_kernel.h @@ -56,4 +56,22 @@ void FlashAttnGradKernel(const Context& ctx, DenseTensor* dk, DenseTensor* dv); +template +void FlashAttnWithSparseMaskGradKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& attn_mask_start_row_indices, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const DenseTensor& dout, + float dropout, + bool causal, + int attn_mask_start_row, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv); + } // namespace phi diff --git a/paddle/phi/kernels/flash_attn_kernel.h b/paddle/phi/kernels/flash_attn_kernel.h index ec72d85a0babb..1550c48b5bf27 100644 --- a/paddle/phi/kernels/flash_attn_kernel.h +++ b/paddle/phi/kernels/flash_attn_kernel.h @@ -59,4 +59,23 @@ void FlashAttnKernel(const Context& ctx, DenseTensor* softmax_lse, DenseTensor* seed_offset); +template 
+void FlashAttnWithSparseMaskKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& attn_mask_start_row_indices, + const paddle::optional& fixed_seed_offset, + float dropout, + bool causal, + int attn_mask_start_row, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset); + } // namespace phi diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 8b83fcb0d10c1..ba1d9873ec2a4 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -2445,6 +2445,13 @@ struct Log { HOSTDEVICE T operator()(const T& val) const { return std::log(val); } }; +template +struct Log> { + HOSTDEVICE ComplexType operator()(const ComplexType& val) const { + return ComplexType(std::log(std::complex(val))); + } +}; + template <> struct Log { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { @@ -2484,11 +2491,35 @@ struct LogGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct LogGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = + dout * (static_cast>(1) / x).unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct Log2 { HOSTDEVICE T operator()(const T& val) const { return std::log2(val); } }; +template +struct Log2> { + HOSTDEVICE ComplexType operator()(const ComplexType& val) const { + return ComplexType(std::log(std::complex(val)) / + std::log(std::complex(2))); + } +}; + template <> struct Log2 { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { @@ -2529,11 +2560,35 @@ struct Log2GradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct Log2GradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast>(1) / + (x * static_cast>(log(2)))) + .unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct Log10 { HOSTDEVICE T operator()(const T& val) const { return std::log10(val); } }; +template +struct Log10> { + HOSTDEVICE ComplexType operator()(const ComplexType& val) const { + return ComplexType(std::log10(std::complex(val))); + } +}; + template <> struct Log10 { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { @@ -2574,11 +2629,35 @@ struct Log10GradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct Log10GradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast>(1) / + (x * static_cast>(log(10)))) + .unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct Log1p { HOSTDEVICE T operator()(const T& val) const { return std::log1p(val); } }; +template +struct Log1p> { + HOSTDEVICE ComplexType operator()(const ComplexType& val) const { + return ComplexType(std::log(std::complex(1) 
+ std::complex(val))); + } +}; + template <> struct Log1p { HOSTDEVICE dtype::float16 operator()(const dtype::float16& val) const { @@ -2618,6 +2697,23 @@ struct Log1pGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct Log1pGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast>(1) / + (x + static_cast>(1))) + .unaryExpr(Conj()); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct LogGradGradFunctor : public BaseActivationFunctor { template @@ -2651,6 +2747,42 @@ struct LogGradGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct LogGradGradFunctor> + : public BaseActivationFunctor> { + template + void operator()(const Device& dev, + const DenseTensor* X, + const DenseTensor* ddX, + DenseTensor* ddOut, + const DenseTensor* dOut, + DenseTensor* dX) const { + auto* d = dev.eigen_device(); + auto ddx = EigenVector>::Flatten( + GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad")); + auto x = EigenVector>::Flatten( + GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad")); + // ddout = ddx / x; dx = -(dout / x) * (ddx / x) + // calculate dx first, so ddout can inplace ddx + if (dX) { + auto dout = EigenVector>::Flatten( + GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad")); + auto dx = EigenVector>::Flatten( + GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad")); + dx.device(*d) = dout * static_cast>(-1) * ddx / + (x * x).unaryExpr(Conj()); + } + if (ddOut) { + auto ddout = EigenVector>::Flatten( + GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad")); + ddout.device(*d) = + ddx * static_cast>(1) / x.unaryExpr(Conj()); + } + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + // HardSwish = min(max(0, x+3), 6) * x / 6 template struct HardSwishFunctor : public BaseActivationFunctor { @@ -4642,6 +4774,16 @@ struct CudaLogFunctor : public BaseActivationFunctor { } }; +template +struct CudaLogFunctor> + : public BaseActivationFunctor> { + // log(x) = log(x) + __device__ __forceinline__ ComplexType operator()( + const ComplexType arg_x) const { + return static_cast>(log(arg_x)); + } +}; + template struct CudaLogGradFunctor : public BaseActivationFunctor { // dx = dout / x @@ -4652,6 +4794,18 @@ struct CudaLogGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaLogGradFunctor> + : public BaseActivationFunctor> { + // dx = dout / conj(x) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return dout / conj(x); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaLog1pFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -4665,6 +4819,17 @@ struct CudaLog1pFunctor : public BaseActivationFunctor { } }; +template +struct CudaLog1pFunctor> + : public BaseActivationFunctor> { + // log1p(x) = log(1 + x) + __device__ __forceinline__ ComplexType operator()( + const ComplexType arg_x) const { + return static_cast>( + log(static_cast>(1) + arg_x)); + } +}; + template struct CudaLog1pGradFunctor : public BaseActivationFunctor { T one = static_cast(1.0f); @@ -4677,6 +4842,20 
@@ struct CudaLog1pGradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaLog1pGradFunctor> + : public BaseActivationFunctor> { + ComplexType one = static_cast>(1.0f); + + // dx = dout / conj(1 + x) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return dout / conj(one + x); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template __device__ __forceinline__ std::conditional_t::value, float, T> @@ -4709,6 +4888,17 @@ struct CudaLog2Functor : public BaseActivationFunctor { } }; +template +struct CudaLog2Functor> + : public BaseActivationFunctor> { + // log2(x) = log(x)/log(2) + __device__ __forceinline__ ComplexType operator()( + const ComplexType arg_x) const { + return static_cast>(log(arg_x) / + static_cast>(log(2.0f))); + } +}; + template struct CudaLog2GradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -4722,6 +4912,18 @@ struct CudaLog2GradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaLog2GradFunctor> + : public BaseActivationFunctor> { + // dx = dout / conj(x * log(2)) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return dout / conj(x * static_cast>(log(2.0f))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template __device__ __forceinline__ std::conditional_t::value, float, T> @@ -4754,6 +4956,17 @@ struct CudaLog10Functor : public BaseActivationFunctor { } }; +template +struct CudaLog10Functor> + : public BaseActivationFunctor> { + // log10(x) = log(x)/log(10) + __device__ __forceinline__ ComplexType operator()( + const ComplexType arg_x) const { + return static_cast>(log(arg_x) / + static_cast>(log(10.0f))); + } +}; + template struct CudaLog10GradFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; @@ -4767,6 +4980,18 @@ struct CudaLog10GradFunctor : public BaseActivationFunctor { static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } }; +template +struct CudaLog10GradFunctor> + : public BaseActivationFunctor> { + // dx = dout / conj(x * log(10)) + __device__ __forceinline__ ComplexType operator()( + const ComplexType dout, const ComplexType x) const { + return dout / conj(x * static_cast>(log(10.0f))); + } + + static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; } +}; + template struct CudaSwishFunctor : public BaseActivationFunctor { using MPType = typename phi::dtype::MPTypeTrait::Type; diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index 19f2fa1f2fac4..45a1024339ba3 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -52,7 +52,6 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, "Axis should be less than or equal to %d, but received axis is %d.", max_dim, axis)); - if (x_dims.size() > y_dims.size()) { std::fill(y_dims_array, y_dims_array + axis, 1); if (axis + y_dims.size() < max_dim) { @@ -68,7 +67,6 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis); std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array); } - for (int i = 0; i < max_dim; 
++i) { PADDLE_ENFORCE_EQ( x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.h b/paddle/phi/kernels/funcs/concat_and_split_functor.h index 9e3f663cb419c..562f85041e663 100644 --- a/paddle/phi/kernels/funcs/concat_and_split_functor.h +++ b/paddle/phi/kernels/funcs/concat_and_split_functor.h @@ -40,7 +40,8 @@ namespace funcs { * [5,6]] */ template -struct ConcatFunctor { +class ConcatFunctor { + public: void operator()(const Context& context, const std::vector& input, int axis, diff --git a/paddle/phi/kernels/funcs/data_layout_transform.h b/paddle/phi/kernels/funcs/data_layout_transform.h index 4bcc96d9c2ab7..3ecfaec6e0670 100644 --- a/paddle/phi/kernels/funcs/data_layout_transform.h +++ b/paddle/phi/kernels/funcs/data_layout_transform.h @@ -83,7 +83,8 @@ void TransDataLayoutFromOneDNN(DataLayout in_layout, DenseTensor* out, Place place, bool always_copy = false); -void* GetDataFromTensor(const DenseTensor& tensor, OneDNNDataType type); +TEST_API void* GetDataFromTensor(const DenseTensor& tensor, + OneDNNDataType type); dnnl::memory::desc make_memory_desc(const phi::DenseTensor& ref_tensor, phi::DataLayout target_layout); diff --git a/paddle/phi/kernels/funcs/detection/poly_util.h b/paddle/phi/kernels/funcs/detection/poly_util.h index 608f373f3d6a3..38a8ed8357c35 100644 --- a/paddle/phi/kernels/funcs/detection/poly_util.h +++ b/paddle/phi/kernels/funcs/detection/poly_util.h @@ -80,7 +80,7 @@ void Array2Poly(const T* box, template void PointVec2Poly(const std::vector>& vec, phi::funcs::gpc_polygon* poly) { - int pts_num = vec.size(); + size_t pts_num = vec.size(); (*poly).num_contours = 1; (*poly).hole = reinterpret_cast(malloc(sizeof(int))); (*poly).hole[0] = 0; diff --git a/paddle/phi/kernels/funcs/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h index 03bc6ca85efed..463272a37c00d 100644 --- a/paddle/phi/kernels/funcs/dropout_impl.cu.h +++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h @@ -368,7 +368,7 @@ void DropoutFwGPUKernelDriver( phi::backends::gpu::CUDAGraphNodeLauncher::parameterSetter_t parameterSetter = [offset, dev_ctx_p, state_index, is_fix_seed]( - phi::backends::gpu::CUDAKernelParams& params) { + phi::backends::gpu::gpuKernelParams& params) { if (!is_fix_seed) { // we assume seed is null pointer // seed copy to cpu is meaningless here @@ -389,7 +389,7 @@ void DropoutFwGPUKernelDriver( } }; - phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t + phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { void* functionPtr = reinterpret_cast(&(VectorizedRandomGenerator)); diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cc b/paddle/phi/kernels/funcs/eigen/broadcast.cc index 04e13a6799931..0bf9d37d60e4a 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cc +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cc @@ -73,7 +73,9 @@ struct EigenBroadcastGrad { template struct FUNCTOR; \ template struct FUNCTOR; \ template struct FUNCTOR; \ - template struct FUNCTOR + template struct FUNCTOR; \ + template struct FUNCTOR; \ + template struct FUNCTOR INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, dtype::float16); INSTANTIATION(EigenBroadcast, dtype::bfloat16); diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cu b/paddle/phi/kernels/funcs/eigen/broadcast.cu index 0c5a3408872c4..fe16588c9bce6 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cu +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cu @@ -72,7 +72,9 @@ 
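// [Illustrative sketch, not part of this diff] The two extra
// "template struct FUNCTOR;" lines added to the INSTANTIATION macro in
// broadcast.cc and broadcast.cu are explicit instantiation definitions,
// presumably for the complex dtypes given the rest of this diff: they force
// the templated Eigen broadcast functors to be compiled into the library for
// those element types. The general pattern, with hypothetical names:

#include <complex>

template <typename T>
struct Broadcast {  // stand-in for the real EigenBroadcast functor
  void Run(const T* in, T* out, int n) const {
    for (int i = 0; i < n; ++i) out[i] = in[0];  // trivial "broadcast"
  }
};

// Explicit instantiation definitions: one compiled copy per listed type.
template struct Broadcast<float>;
template struct Broadcast<std::complex<float>>;
template struct Broadcast<std::complex<double>>;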
struct EigenBroadcastGrad { template struct FUNCTOR; \ template struct FUNCTOR; \ template struct FUNCTOR; \ - template struct FUNCTOR + template struct FUNCTOR; \ + template struct FUNCTOR; \ + template struct FUNCTOR INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, dtype::float16); INSTANTIATION(EigenBroadcast, dtype::bfloat16); diff --git a/paddle/phi/kernels/funcs/fused_gemm_epilogue.h b/paddle/phi/kernels/funcs/fused_gemm_epilogue.h index d490b0abdff62..a81912ca1a8b7 100644 --- a/paddle/phi/kernels/funcs/fused_gemm_epilogue.h +++ b/paddle/phi/kernels/funcs/fused_gemm_epilogue.h @@ -646,7 +646,6 @@ void ComputeFusedGemmEpilogueBackwardImplDev( // NOTE(zengjinle): I do not know whether the 4MB workspace size is // "enough". I just followed the settings from the NVIDIA MLPerf BERT code. size_t workspace_size = static_cast(4) * 1024 * 1024; - const cublasLtMatmulAlgo_t* algo = nullptr; cudaStream_t stream = dev_ctx.stream(); MT alpha = static_cast(1.0); diff --git a/paddle/phi/kernels/funcs/gather.cu.h b/paddle/phi/kernels/funcs/gather.cu.h index a112680cf7dd0..b05500caba064 100644 --- a/paddle/phi/kernels/funcs/gather.cu.h +++ b/paddle/phi/kernels/funcs/gather.cu.h @@ -301,7 +301,7 @@ void GatherV2GradCUDAFunction(const DenseTensor* input, auto* out_data = ctx.Alloc(out); auto out_dim = out->dims(); int64_t out_index_dim_size = out_dim[axis_index]; - phi::funcs::set_constant(ctx, out, static_cast(0.0)); + phi::funcs::set_constant(ctx, out, static_cast(0.0)); auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, input_size); auto stream = ctx.stream(); diff --git a/paddle/phi/kernels/funcs/gather.h b/paddle/phi/kernels/funcs/gather.h index fb4e91f9b9b13..b637ef1f6f05d 100644 --- a/paddle/phi/kernels/funcs/gather.h +++ b/paddle/phi/kernels/funcs/gather.h @@ -247,7 +247,8 @@ void GatherV2GradFunction(const phi::CPUContext& ctx, auto* out_data = ctx.Alloc(out); auto out_dim = out->dims(); int64_t out_index_dim_size = out_dim[axis_index]; - phi::funcs::set_constant(ctx, out, static_cast(0.0)); + // set_constant only supports input of type float value + phi::funcs::set_constant(ctx, out, static_cast(0.0)); for (int64_t i = 0; i < inner_dim_size; i++) { for (int64_t j = 0; j < input_index_dim_size; j++) { diff --git a/paddle/phi/kernels/funcs/index_put_utils.h b/paddle/phi/kernels/funcs/index_put_utils.h index 983d33bedc72c..bc6eeb3382f3f 100644 --- a/paddle/phi/kernels/funcs/index_put_utils.h +++ b/paddle/phi/kernels/funcs/index_put_utils.h @@ -186,7 +186,7 @@ template T** GetDevicePointerArray(const Context& ctx, const std::vector& indices_v) { std::vector h_indices_v(indices_v.size()); - for (int i = 0; i < indices_v.size(); ++i) { + for (size_t i = 0; i < indices_v.size(); ++i) { h_indices_v[i] = indices_v[i]->data(); } auto d_indices_data = phi::memory_utils::Alloc( diff --git a/paddle/phi/kernels/funcs/jit/gen/blas.cc b/paddle/phi/kernels/funcs/jit/gen/blas.cc index 8c287efcf5ddd..1e29b7f4953fe 100644 --- a/paddle/phi/kernels/funcs/jit/gen/blas.cc +++ b/paddle/phi/kernels/funcs/jit/gen/blas.cc @@ -104,7 +104,7 @@ void VXXJitCode::genCode() { } else { vmovss(ptr[param3 + offset], xmm_dst); } - offset += sizeof(float) * block; + offset += sizeof(float) * block; // NOLINT rest -= block; } ret(); diff --git a/paddle/phi/kernels/funcs/jit/gen/gru.cc b/paddle/phi/kernels/funcs/jit/gen/gru.cc index 599564f431497..33dfaa6cd097c 100644 --- a/paddle/phi/kernels/funcs/jit/gen/gru.cc +++ b/paddle/phi/kernels/funcs/jit/gen/gru.cc @@ -39,7 +39,7 @@ void GRUJitCode::genCode() { 
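// [Illustrative note, not part of this diff] The // NOLINT annotations added
// in blas.cc above and in the gru/lstm JIT code just below suppress lint
// complaints (presumably about the implicit size_t-to-int narrowing) on
// expressions like num_ * sizeof(float): sizeof() yields size_t, so the
// product is unsigned and wider than the int it is assigned to. An
// equivalent, lint-clean spelling (hypothetical helper, shown only for
// comparison) makes the narrowing explicit instead:

#include <cstddef>

inline int BytesOfFloats(int num) {
  return static_cast<int>(static_cast<std::size_t>(num) * sizeof(float));
}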
vmovaps(ymm_one, ptr[reg_ptr_tmp + OFFSET_EXP_ONE]); } int offset = 0; - int d = num_ * sizeof(float); + int d = num_ * sizeof(float); // NOLINT for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { ymm_t ymm_u = ymm_t(1); ymm_t ymm_r = ymm_t(2); diff --git a/paddle/phi/kernels/funcs/jit/gen/lstm.cc b/paddle/phi/kernels/funcs/jit/gen/lstm.cc index e22a5a2880dff..4943989a50c79 100644 --- a/paddle/phi/kernels/funcs/jit/gen/lstm.cc +++ b/paddle/phi/kernels/funcs/jit/gen/lstm.cc @@ -42,7 +42,7 @@ void LSTMJitCode::genCode() { } int offset = 0; - int d = num_ * sizeof(float); + int d = num_ * sizeof(float); // NOLINT for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) { /* gates: W_ch, W_ih, W_fh, W_oh */ ymm_t ymm_c = ymm_t(0); diff --git a/paddle/phi/kernels/funcs/jit/kernel_base.h b/paddle/phi/kernels/funcs/jit/kernel_base.h index b8a638b48fc8d..e08f7821793c0 100644 --- a/paddle/phi/kernels/funcs/jit/kernel_base.h +++ b/paddle/phi/kernels/funcs/jit/kernel_base.h @@ -119,7 +119,7 @@ DECLARE_KERNELTUPLE(XYNTuple, VSigmoid); DECLARE_KERNELTUPLE(XYNTuple, VTanh); DECLARE_KERNELTUPLE(XYNTuple, VCopy); -typedef struct { +typedef struct lstm_t { void* gates; // gates: x_ch, x_ih, x_fh, x_oh const void* ct_1; void* ct; diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu index e5af38b4d2b79..3d69d11c4f839 100644 --- a/paddle/phi/kernels/funcs/pooling.cu +++ b/paddle/phi/kernels/funcs/pooling.cu @@ -2454,7 +2454,7 @@ class MaxPool3dWithIndexFunctor { int thread_y = 8; int thread_z = 1; dim3 threads(thread_x, thread_y, thread_z); - std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); int block_x = (output_width + threads.x - 1) / threads.x; int block_y = (output_height + threads.y - 1) / threads.y; int block_z = (ncd > max_grid_dim[2] * threads.z) @@ -2535,7 +2535,7 @@ class MaxPool3dWithIndexGradFunctor { int thread_y = 8; int thread_z = 1; dim3 threads(thread_x, thread_y, thread_z); - std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); int block_x = (output_width + threads.x - 1) / threads.x; int block_y = (output_height + threads.y - 1) / threads.y; int block_z = (ncd > max_grid_dim[2] * threads.z) @@ -2767,7 +2767,7 @@ class FractionalMaxPool2dFunctor { int thread_y = 1; int thread_z = 1; dim3 threads(thread_x, thread_y, thread_z); - std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); int block_x = (output_width + threads.x - 1) / threads.x; int block_y = (ncd > max_grid_dim[1] * threads.y) ? max_grid_dim[1] @@ -2839,7 +2839,7 @@ class FractionalMaxPool2dGradFunctor { int thread_y = 1; int thread_z = 1; dim3 threads(thread_x, thread_y, thread_z); - std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); int block_x = (output_width + threads.x - 1) / threads.x; int block_y = (ncd > max_grid_dim[1] * threads.y) ? 
max_grid_dim[1] @@ -3105,7 +3105,7 @@ class FractionalMaxPool3dFunctor { int thread_y = 8; int thread_z = 1; dim3 threads(thread_x, thread_y, thread_z); - std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); int block_x = (output_width + threads.x - 1) / threads.x; int block_y = (output_height + threads.y - 1) / threads.y; int block_z = (ncd > max_grid_dim[2] * threads.z) @@ -3183,7 +3183,7 @@ class FractionalMaxPool3dGradFunctor { int thread_y = 8; int thread_z = 1; dim3 threads(thread_x, thread_y, thread_z); - std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); + std::array max_grid_dim = context.GetCUDAMaxGridDimSize(); int block_x = (output_width + threads.x - 1) / threads.x; int block_y = (output_height + threads.y - 1) / threads.y; int block_z = (ncd > max_grid_dim[2] * threads.z) diff --git a/paddle/phi/kernels/funcs/segmented_array.h b/paddle/phi/kernels/funcs/segmented_array.h index e6ecb9819e505..4b4b1b59db66e 100644 --- a/paddle/phi/kernels/funcs/segmented_array.h +++ b/paddle/phi/kernels/funcs/segmented_array.h @@ -118,7 +118,7 @@ struct ArraySetterBase { phi::Stream(reinterpret_cast(ctx.stream()))); int8_t* restored = reinterpret_cast(src); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (use_cuda_graph) { restored = phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph( restored, num_bytes); diff --git a/paddle/phi/kernels/funcs/selected_rows_functor.cc b/paddle/phi/kernels/funcs/selected_rows_functor.cc index b37b5bec78d2f..b370c80311882 100644 --- a/paddle/phi/kernels/funcs/selected_rows_functor.cc +++ b/paddle/phi/kernels/funcs/selected_rows_functor.cc @@ -856,7 +856,6 @@ struct MergeAverage { auto input_height = has_value_input->height(); phi::SelectedRows& out = *output; std::set merged_row_set; - size_t row_num = 0; for (auto* input : inputs) { if (input->rows().empty()) { continue; @@ -870,7 +869,6 @@ struct MergeAverage { input_height, input->height(), phi::errors::InvalidArgument("All input should have same height.")); - row_num += input->rows().size(); merged_row_set.insert(input->rows().begin(), input->rows().end()); } diff --git a/paddle/phi/kernels/funcs/sequence_pooling.cc b/paddle/phi/kernels/funcs/sequence_pooling.cc index 004bef522ab16..f4ee9c323366e 100644 --- a/paddle/phi/kernels/funcs/sequence_pooling.cc +++ b/paddle/phi/kernels/funcs/sequence_pooling.cc @@ -417,7 +417,7 @@ class SequencePoolFunctor { int64_t h = static_cast(lod[i + 1] - lod[i]); auto in_e = EigenMatrix::From(in_t, common::make_ddim({h, w})); auto out_e = EigenVector::Flatten(out_t); - if (pooltype == "AVERAGE") { + if (pooltype == "AVERAGE") { // NOLINT out_e.device(place) = in_e.mean(Eigen::array({{0}})); } else if (pooltype == "SQRT") { out_e.device(place) = in_e.sum(Eigen::array({{0}})) / diff --git a/paddle/phi/kernels/funcs/strided_utils.h b/paddle/phi/kernels/funcs/strided_utils.h new file mode 100644 index 0000000000000..0842b52d7af9f --- /dev/null +++ b/paddle/phi/kernels/funcs/strided_utils.h @@ -0,0 +1,155 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
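[Editor's note] The pooling functors above now take the device's maximum grid dimensions as a std::array (via GetCUDAMaxGridDimSize) and clamp the launch grid against them, letting the kernels iterate over any remainder. A standalone sketch of that clamping idea using the plain CUDA runtime API; the helper name and the grid-stride assumption are mine, not Paddle's exact code.

#include <algorithm>
#include <array>
#include <cuda_runtime.h>

// Query the per-axis grid limits once, then clamp the requested block counts.
inline dim3 ClampedGrid(int want_x, int want_y, int want_z, int device_id) {
  std::array<int, 3> max_grid{};
  cudaDeviceGetAttribute(&max_grid[0], cudaDevAttrMaxGridDimX, device_id);
  cudaDeviceGetAttribute(&max_grid[1], cudaDevAttrMaxGridDimY, device_id);
  cudaDeviceGetAttribute(&max_grid[2], cudaDevAttrMaxGridDimZ, device_id);
  // Oversized requests are clamped; the kernel is then expected to loop
  // (grid-stride) over the work that no longer fits in one launch.
  return dim3(std::min(want_x, max_grid[0]),
              std::min(want_y, max_grid[1]),
              std::min(want_z, max_grid[2]));
}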
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "paddle/phi/backends/context_pool.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_factory.h" +#include "paddle/phi/core/visit_type.h" +#include "paddle/phi/kernels/contiguous_kernel.h" +#include "paddle/phi/kernels/fill_kernel.h" +#include "paddle/phi/kernels/strided_copy_kernel.h" + +namespace phi { +template +inline void StridedTensorCopy(const phi::DenseTensor& input, + const std::vector& dims, + const std::vector& out_stride, + int64_t offset, + phi::DenseTensor* out) { + auto& pool = phi::DeviceContextPool::Instance(); + if (input.place().GetType() == phi::AllocationType::CPU) { + auto* dev_ctx = static_cast(pool.Get(input.place())); + phi::StridedCopyKernel( + *dev_ctx, input, dims, out_stride, offset, out); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + } else if (input.place().GetType() == phi::AllocationType::GPU) { + auto* dev_ctx = static_cast(pool.Get(input.place())); + phi::StridedCopyKernel( + *dev_ctx, input, dims, out_stride, offset, out); +#endif +#ifdef PADDLE_WITH_XPU + } else if (input.place().GetType() == phi::AllocationType::XPU) { + auto* dev_ctx = static_cast(pool.Get(input.place())); + phi::StridedCopyKernel( + *dev_ctx, input, dims, out_stride, offset, out); +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + } else if (input.place().GetType() == phi::AllocationType::CUSTOM) { + auto* dev_ctx = static_cast(pool.Get(input.place())); + const phi::KernelKey& strided_copy_key = { + phi::TransToPhiBackend(dev_ctx->GetPlace()), + phi::DataLayout::ALL_LAYOUT, + input.dtype()}; + using strided_copy_signature = void (*)(const phi::DeviceContext&, + const phi::DenseTensor&, + const std::vector&, + const std::vector&, + int64_t, + phi::DenseTensor*); + PD_VISIT_KERNEL("strided_copy", + strided_copy_key, + strided_copy_signature, + false, + *dev_ctx, + input, + dims, + out_stride, + offset, + out); +#endif + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Place type is not supported when `strided_copy` kernel is called.")); + } +} + +template +inline void StridedTensorFill(const phi::DenseTensor& x, + const phi::Scalar& value, + phi::DenseTensor* out) { + auto& pool = phi::DeviceContextPool::Instance(); + if (x.place().GetType() == phi::AllocationType::CPU) { + auto* dev_ctx = static_cast(pool.Get(x.place())); + phi::FillKernel(*dev_ctx, x, value, out); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + } else if (x.place().GetType() == phi::AllocationType::GPU) { + auto* dev_ctx = static_cast(pool.Get(x.place())); + phi::FillKernel(*dev_ctx, x, value, out); +#endif +#ifdef PADDLE_WITH_XPU + } else if (x.place().GetType() == phi::AllocationType::XPU) { + auto* dev_ctx = static_cast(pool.Get(x.place())); + phi::FillKernel(*dev_ctx, x, value, out); +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + } else if (x.place().GetType() == phi::AllocationType::CUSTOM) { + auto* dev_ctx = static_cast(pool.Get(x.place())); + const phi::KernelKey& fill_key = { + phi::TransToPhiBackend(dev_ctx->GetPlace()), + phi::DataLayout::ALL_LAYOUT, + x.dtype()}; + using fill_signature = void (*)(const 
phi::DeviceContext&, + const phi::DenseTensor&, + const phi::Scalar&, + phi::DenseTensor*); + PD_VISIT_KERNEL( + "fill", fill_key, fill_signature, false, *dev_ctx, x, value, out); +#endif + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Place type is not supported when `fill` kernel is called.")); + } +} + +template +inline void StridedTensorContiguous(const phi::DenseTensor& input, + phi::DenseTensor* out) { + auto& pool = phi::DeviceContextPool::Instance(); + if (input.place().GetType() == phi::AllocationType::CPU) { + auto* dev_ctx = static_cast(pool.Get(input.place())); + phi::ContiguousKernel(*dev_ctx, input, out); +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + } else if (input.place().GetType() == phi::AllocationType::GPU) { + auto* dev_ctx = static_cast(pool.Get(input.place())); + phi::ContiguousKernel(*dev_ctx, input, out); +#endif +#ifdef PADDLE_WITH_XPU + } else if (input.place().GetType() == phi::AllocationType::XPU) { + auto* dev_ctx = static_cast(pool.Get(input.place())); + phi::ContiguousKernel(*dev_ctx, input, out); +#endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + } else if (input.place().GetType() == phi::AllocationType::CUSTOM) { + auto* dev_ctx = static_cast(pool.Get(input.place())); + const phi::KernelKey& contiguous_key = { + phi::TransToPhiBackend(dev_ctx->GetPlace()), + phi::DataLayout::ALL_LAYOUT, + input.dtype()}; + using contiguous_signature = void (*)( + const phi::DeviceContext&, const phi::DenseTensor&, phi::DenseTensor*); + PD_VISIT_KERNEL("contiguous", + contiguous_key, + contiguous_signature, + false, + *dev_ctx, + input, + out); +#endif + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "Place type is not supported when `contiguous` kernel is called.")); + } +} +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cpu/self_dp_attention_kernel.cc b/paddle/phi/kernels/fusion/cpu/self_dp_attention_kernel.cc index 56107c31d6d9c..0d3189187351c 100644 --- a/paddle/phi/kernels/fusion/cpu/self_dp_attention_kernel.cc +++ b/paddle/phi/kernels/fusion/cpu/self_dp_attention_kernel.cc @@ -161,8 +161,8 @@ void sgemm(const float* A, int ldc = n; float alpha = 1; float beta = 0; - char ta[] = "N"; - char tb[] = "N"; + std::array ta = {"N"}; + std::array tb = {"N"}; if (transa) ta[0] = 'T'; if (transb) tb[0] = 'T'; diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/CMakeLists.txt b/paddle/phi/kernels/fusion/cutlass/conv2d/CMakeLists.txt index cd82bbf1dc8b7..b77a565121bee 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/CMakeLists.txt +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/CMakeLists.txt @@ -21,15 +21,17 @@ execute_process(COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/generated_tmp") execute_process( - COMMAND ${PYTHON_EXECUTABLE} "${CMAKE_CURRENT_SOURCE_DIR}/conv2d_bias_act.py" + COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/conv2d_bias_act.py + --cuda_arch ${COMPUTE_CAPABILITY} + COMMAND + ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/conv2d_bias_residual.py + --cuda_arch ${COMPUTE_CAPABILITY} COMMAND ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_SOURCE_DIR}/conv2d_bias_residual.py" - COMMAND ${PYTHON_EXECUTABLE} - "${CMAKE_CURRENT_SOURCE_DIR}/conv2d_depthwise_bias_act.py" + ${CMAKE_CURRENT_SOURCE_DIR}/conv2d_depthwise_bias_act.py WORKING_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}") find_package(CUDA) - +# you can append -std=c++17 in CUDA_NVCC_FLAGS for compiling cutlass 3.0 set(CUDA_NVCC_FLAGS -gencode arch=compute_${COMPUTE_CAPABILITY},code=sm_${COMPUTE_CAPABILITY};) #set(CMAKE_CXX_FLAGS 
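[Editor's note] The new strided_utils.h above repeats one pattern three times (copy, fill, contiguous): look at the tensor's allocation place, fetch the matching device context from the global pool, and call the typed kernel directly, with each backend guarded by its build flag and custom devices routed through PD_VISIT_KERNEL. A stripped-down sketch of that dispatch shape; Place and the callbacks are placeholders rather than the phi API.

#include <functional>
#include <stdexcept>

enum class Place { kCPU, kGPU, kCustom };

inline void DispatchByPlace(Place place,
                            const std::function<void()>& on_cpu,
                            const std::function<void()>& on_gpu,
                            const std::function<void()>& on_custom) {
  if (place == Place::kCPU) {
    on_cpu();  // direct call into the typed CPU kernel
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (place == Place::kGPU) {
    on_gpu();  // direct call into the typed GPU kernel
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (place == Place::kCustom) {
    on_custom();  // resolved through the kernel registry at runtime
#endif
  } else {
    throw std::runtime_error("place not supported by this strided helper");
  }
}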
-fvisibility=hidden) diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/README.md b/paddle/phi/kernels/fusion/cutlass/conv2d/README.md index a717b3d692b91..4a2b6c6ac61aa 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/README.md +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/README.md @@ -23,3 +23,9 @@ compile.sh 脚本中会下载cutlass,执行CMakeLists.txt脚本,编译生成 step2. step1执行后,就可以看到在 build 目录生成了 `libCutlassConv2d.so` ,并将build目录添加到LD_LIBRARY_PATH中即可使用此库。 + + +step3. + +默认情况下,在处理conv2d类算子时,Paddle Inference 会调用cuDNN实现; +基于 cutlass 开发的conv2d类算子能够融合更多的后处理算子,用户可以通过python API `exp_enable_use_cutlass()` 和 C++ API `Exp_EnableUseCutlass()`来获得一定的速度和显存收益。 diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/compile.sh b/paddle/phi/kernels/fusion/cutlass/conv2d/compile.sh index 44c0fdf3a04da..d43bda262f543 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/compile.sh +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/compile.sh @@ -25,7 +25,7 @@ fi python_exe_path="python" cuda_root_path="/usr/local/cuda" -gpu_cc="75" +gpu_cc="80" cd $build_directory cmake .. -DPYTHON_EXECUTABLE=$python_exe_path -DCUDA_TOOLKIT_ROOT_DIR=$cuda_root_path -DCOMPUTE_CAPABILITY=$gpu_cc diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.py b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.py index 0cb925489f14a..9dd7e98a4109b 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.py +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.py @@ -21,7 +21,7 @@ CommonTail, GenerateFunctionForPhi, ) -from util import SubstituteTemplate, TileDesc +from util import SubstituteTemplate, TileDesc, parse_args, write_kernel_to_file # this is a file's header part @@ -54,10 +54,10 @@ + ''' typename ImplicitGemm::Arguments arguments{ problem_size, - {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, - {(cutlass::half_t *)(weight), {kc, kc * kw, kc * kw * kh}}, - {(cutlass::half_t *)(bias), {0, 0, 0}}, - {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, + {input, {ic, ic * iw, ic * iw * ih}}, + {weight, {kc, kc * kw, kc * kw * kh}}, + {bias, {0, 0, 0}}, + {output, {oc, oc * ow, oc * ow * oh}}, {1.f, 1.f}}; ''' + CommonCutlassConvKernelExecute @@ -170,10 +170,11 @@ def generate_sm75_1688(): sm75_code = "" for epi_func in SupportedAct: op_dict = {} - op_dict["func_name"] = UnderScoreName[epi_func].lower() + "_sm75" + op_dict["func_name"] = UnderScoreName[epi_func].lower() + "_sm75_fp16" op_dict["enum_op_name"] = UnderScoreName[epi_func].upper() # For a function, we record all its kernels into a std::vector in C++ code all_kernel_names = "" + all_kernel_declares = "" kernel_dict["epi_func"] = ActTag[epi_func] suffix = 0 for iterator_algorithm in iterator_algorithms: @@ -203,23 +204,291 @@ def generate_sm75_1688(): cba_kernel = cba_kernel_no_alpha if epi_func in [CbaAct.LeakyRelu]: cba_kernel = cba_kernel_alpha - sm75_code += SubstituteTemplate(cba_kernel, kernel_dict) + # sm75_code += SubstituteTemplate(cba_kernel, kernel_dict) + + kernel_str = ( + cba_header + + SubstituteTemplate(cba_kernel, kernel_dict) + + CommonTail + ) + file_name = ( + "generated_tmp/" + + kernel_dict["kernel_func_name"] + + ".cu" + ) + write_kernel_to_file(kernel_str, file_name) + all_kernel_names += ( kernel_dict["kernel_func_name"] + ", \n" ) + all_kernel_declares += ( + "cutlass::Status " + + kernel_dict["kernel_func_name"] + + "(const ConvAllParams& params);" + ) # Generate op code + op_dict["kernel_func_declare"] = all_kernel_declares op_dict["all_kernel_func_name"] = all_kernel_names sm75_code += 
SubstituteTemplate(CommonConvFunction, op_dict) return sm75_code +def generate_sm80_16816(cutlass_dtype="cutlass::half_t"): + kernel_dict = { + "element_a": cutlass_dtype, + "layout_a": "cutlass::layout::TensorNHWC", + "element_b": cutlass_dtype, + "layout_b": "cutlass::layout::TensorNHWC", + "element_c": cutlass_dtype, + "layout_c": "cutlass::layout::TensorNHWC", + "opcode_class": "cutlass::arch::OpClassTensorOp", + "arch": "cutlass::arch::Sm80", + "swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + # alpha is always float! + "element_epilogue": "float", + "math_operator": "cutlass::arch::OpMultiplyAdd", + } + + kernel_dict["stride_support"] = "cutlass::conv::StrideSupport::kStrided" + + # iterate over this loop + iterator_algorithms = [ + "cutlass::conv::IteratorAlgorithm::kOptimized", + ] + + math_instructions = [ + ( + "16,8,16", + cutlass_dtype, + cutlass_dtype, + "float", + ), + ] + + alignments = [8] + + kernel_dict["align_a"] = "8" + kernel_dict["align_b"] = "8" + # this should divided by oc + kernel_dict["epilogue_vector_length"] = "8" + kernel_dict["split_k_slices"] = "1" + + sm80_code = "" + for epi_func in SupportedAct: + op_dict = {} + op_dict["func_name"] = ( + UnderScoreName[epi_func].lower() + + "_sm80_" + + ("fp16" if "half" in cutlass_dtype else "bf16") + ) + op_dict["enum_op_name"] = UnderScoreName[epi_func].upper() + # For a function, we record all its kernels into a std::vector in C++ code + all_kernel_names = "" + all_kernel_declares = "" + kernel_dict["epi_func"] = ActTag[epi_func] + suffix = 0 + for iterator_algorithm in iterator_algorithms: + for alignment in alignments: + for math_inst in math_instructions: + tiles = [ + TileDesc("256, 128, 32", 3, "64, 64, 32", math_inst), + TileDesc("128, 256, 32", 3, "64, 64, 32", math_inst), + TileDesc("256, 64, 32", 3, "64, 64, 32", math_inst), + TileDesc("256, 64, 32", 4, "64, 64, 32", math_inst), + TileDesc("64, 256, 32", 4, "64, 64, 32", math_inst), + TileDesc("128, 128, 32", 3, "64, 64, 32", math_inst), + TileDesc("128, 128, 32", 4, "64, 64, 32", math_inst), + TileDesc("128, 128, 32", 5, "64, 64, 32", math_inst), + TileDesc("128, 64, 32", 6, "64, 32, 32", math_inst), + TileDesc("64, 128, 32", 6, "32, 64, 32", math_inst), + TileDesc("64, 64, 32", 10, "32, 32, 32", math_inst), + TileDesc("256, 128, 64", 3, "64, 64, 64", math_inst), + TileDesc("128, 256, 64", 3, "64, 64, 64", math_inst), + TileDesc("256, 64, 64", 4, "64, 64, 64", math_inst), + TileDesc("64, 256, 64", 4, "64, 64, 64", math_inst), + TileDesc("128, 128, 64", 4, "64, 64, 64", math_inst), + TileDesc("256, 64, 64", 3, "64, 64, 64", math_inst), + TileDesc("64, 256, 64", 3, "64, 64, 64", math_inst), + TileDesc("128, 128, 64", 3, "64, 64, 64", math_inst), + TileDesc("128, 64, 64", 3, "64, 32, 64", math_inst), + TileDesc("64, 128, 64", 3, "32, 64, 64", math_inst), + TileDesc("64, 64, 64", 5, "32, 32, 64", math_inst), + ] + for tile in tiles: + kernel_dict["iterator_algorithm"] = iterator_algorithm + kernel_dict["Tshape"] = tile.Tshape + kernel_dict["Wshape"] = tile.Wshape + kernel_dict["Ishape"] = tile.math_inst[0] + kernel_dict["stages"] = str(tile.stages) + kernel_dict["element_accum"] = tile.math_inst[3] + kernel_dict["kernel_func_name"] = op_dict[ + "func_name" + ] + str(suffix) + suffix += 1 + cba_kernel = cba_kernel_no_alpha + if epi_func in [CbaAct.LeakyRelu]: + cba_kernel = cba_kernel_alpha + # sm80_code += SubstituteTemplate(cba_kernel, kernel_dict) + + kernel_str = ( + cba_header + + SubstituteTemplate(cba_kernel, 
kernel_dict) + + CommonTail + ) + file_name = ( + "generated_tmp/" + + kernel_dict["kernel_func_name"] + + ".cu" + ) + write_kernel_to_file(kernel_str, file_name) + + all_kernel_names += ( + kernel_dict["kernel_func_name"] + ", \n" + ) + + all_kernel_declares += ( + "cutlass::Status " + + kernel_dict["kernel_func_name"] + + "(const ConvAllParams& params);" + ) + + # Generate op code + op_dict["kernel_func_declare"] = all_kernel_declares + op_dict["all_kernel_func_name"] = all_kernel_names + sm80_code += SubstituteTemplate(CommonConvFunction, op_dict) + return sm80_code + + +# hers is sm80 tf32. +def generate_sm80_1688(cutlass_dtype="cutlass::tfloat32_t"): + kernel_dict = { + "element_a": cutlass_dtype, + "layout_a": "cutlass::layout::TensorNHWC", + "element_b": cutlass_dtype, + "layout_b": "cutlass::layout::TensorNHWC", + "element_c": cutlass_dtype, + "layout_c": "cutlass::layout::TensorNHWC", + "opcode_class": "cutlass::arch::OpClassTensorOp", + "arch": "cutlass::arch::Sm80", + "swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + # alpha is always float! + "element_epilogue": "float", + "math_operator": "cutlass::arch::OpMultiplyAdd", + } + + kernel_dict["stride_support"] = "cutlass::conv::StrideSupport::kStrided" + + # iterate over this loop + iterator_algorithms = [ + "cutlass::conv::IteratorAlgorithm::kOptimized", + ] + + math_instructions = [ + ( + "16,8,8", + cutlass_dtype, + cutlass_dtype, + "float", + ), + ] + + alignments = [4] + + kernel_dict["align_a"] = "4" + kernel_dict["align_b"] = "4" + # this should divided by oc + kernel_dict["epilogue_vector_length"] = "4" + kernel_dict["split_k_slices"] = "1" + + sm80_code = "" + for epi_func in SupportedAct: + op_dict = {} + op_dict["func_name"] = UnderScoreName[epi_func].lower() + "_sm80_fp32" + op_dict["enum_op_name"] = UnderScoreName[epi_func].upper() + # For a function, we record all its kernels into a std::vector in C++ code + all_kernel_names = "" + all_kernel_declares = "" + kernel_dict["epi_func"] = ActTag[epi_func] + suffix = 0 + for iterator_algorithm in iterator_algorithms: + for alignment in alignments: + for math_inst in math_instructions: + tiles = [ + TileDesc("128, 128, 16", 4, "32, 64, 16", math_inst), + TileDesc("128, 128, 16", 3, "32, 64, 16", math_inst), + TileDesc("256, 64, 16", 3, "64, 32, 16", math_inst), + TileDesc("64, 256, 16", 3, "32, 64, 16", math_inst), + TileDesc("128, 64, 16", 4, "64, 32, 16", math_inst), + TileDesc("64, 128, 16", 4, "32, 64, 16", math_inst), + TileDesc("64, 64, 16", 3, "32, 32, 16", math_inst), + TileDesc("128, 128, 32", 3, "32, 64, 32", math_inst), + TileDesc("256, 64, 32", 3, "64, 32, 32", math_inst), + TileDesc("64, 256, 32", 3, "32, 64, 32", math_inst), + TileDesc("128, 64, 32", 3, "64, 32, 32", math_inst), + TileDesc("64, 128, 32", 3, "32, 64, 32", math_inst), + TileDesc("64, 64, 32", 3, "32, 32, 32", math_inst), + ] + for tile in tiles: + kernel_dict["iterator_algorithm"] = iterator_algorithm + kernel_dict["Tshape"] = tile.Tshape + kernel_dict["Wshape"] = tile.Wshape + kernel_dict["Ishape"] = tile.math_inst[0] + kernel_dict["stages"] = str(tile.stages) + kernel_dict["element_accum"] = tile.math_inst[3] + kernel_dict["kernel_func_name"] = op_dict[ + "func_name" + ] + str(suffix) + suffix += 1 + cba_kernel = cba_kernel_no_alpha + if epi_func in [CbaAct.LeakyRelu]: + cba_kernel = cba_kernel_alpha + kernel_str = ( + cba_header + + SubstituteTemplate(cba_kernel, kernel_dict) + + CommonTail + ) + file_name = ( + "generated_tmp/" + + 
kernel_dict["kernel_func_name"] + + ".cu" + ) + write_kernel_to_file(kernel_str, file_name) + all_kernel_names += ( + kernel_dict["kernel_func_name"] + ", \n" + ) + all_kernel_declares += ( + "cutlass::Status " + + kernel_dict["kernel_func_name"] + + "(const ConvAllParams& params);" + ) + + # Generate op code + op_dict["kernel_func_declare"] = all_kernel_declares + op_dict["all_kernel_func_name"] = all_kernel_names + sm80_code += SubstituteTemplate(CommonConvFunction, op_dict) + + return sm80_code + + if __name__ == "__main__": - sm_versions = ["75"] + sm_versions_and_types = [] + args = parse_args() + all_code = cba_header - all_code += generate_sm75_1688() + if args.cuda_arch == "75": + sm_versions_and_types.append(["75", "fp16"]) + all_code += generate_sm75_1688() + if args.cuda_arch in ["80", "86", "89"]: + sm_versions_and_types.append(["80", "fp16"]) + sm_versions_and_types.append(["80", "bf16"]) + sm_versions_and_types.append(["80", "fp32"]) + all_code += generate_sm80_16816() + all_code += generate_sm80_16816(cutlass_dtype="cutlass::bfloat16_t") + all_code += generate_sm80_1688(cutlass_dtype="cutlass::tfloat32_t") + all_code += GenerateFunctionForPhi( - sm_versions, SupportedAct, UnderScoreName, CamelName + sm_versions_and_types, SupportedAct, UnderScoreName, CamelName ) all_code += CommonTail with open("generated_tmp/conv2d_bias_act.cu", "w") as f: diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.py b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.py index 55fde0722b6b3..e243a64e1548d 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.py +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.py @@ -21,7 +21,7 @@ CommonTail, GenerateFunctionForPhi, ) -from util import SubstituteTemplate, TileDesc +from util import SubstituteTemplate, TileDesc, parse_args, write_kernel_to_file # this is a file's header part @@ -48,13 +48,12 @@ cbr_kernel = ( SubstituteTemplate(CommonCutlassConvKernelDeclare, dict_for_declare_part) + ''' - const half *residual = params.residual; typename ImplicitGemm::Arguments arguments{ problem_size, - {(cutlass::half_t *)input, {ic, ic * iw, ic * iw * ih}}, - {(cutlass::half_t *)(weight), {kc, kc * kw, kc * kw * kh}}, - {(cutlass::half_t *)residual, {oc, oc * ow, oc * ow * oh}}, - {(cutlass::half_t *)output, {oc, oc * ow, oc * ow * oh}}, + {input, {ic, ic * iw, ic * iw * ih}}, + {weight, {kc, kc * kw, kc * kw * kh}}, + {residual, {oc, oc * ow, oc * ow * oh}}, + {output, {oc, oc * ow, oc * ow * oh}}, {1.f, 1.f}, cutlass::conv::SplitKMode::kSerial, (cutlass::half_t *)(bias), nullptr, @@ -80,16 +79,19 @@ class CbrAct(enum.Enum): SupportedEpilogue = [ (CbrAct.Silu, "cutlass::plus", CbrAct.Identity), (CbrAct.Identity, "cutlass::plus", CbrAct.Relu), + (CbrAct.Identity, "cutlass::plus", CbrAct.Identity), ] UnderScoreName = { SupportedEpilogue[0]: "conv2d_bias_silu_add", SupportedEpilogue[1]: "conv2d_bias_add_relu", + SupportedEpilogue[2]: "conv2d_bias_add", } CamelName = { SupportedEpilogue[0]: "Conv2dBiasSiluAdd", SupportedEpilogue[1]: "Conv2dBiasAddRelu", + SupportedEpilogue[2]: "Conv2dBiasAdd", } # Generate sm75 TensorOp conv code. 
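[Editor's note] With the generator changes above, every cutlass instantiation is written to its own generated_tmp/<kernel_func_name>.cu (header + kernel + tail), and the per-op source keeps only forward declarations plus the function table that profiling iterates over, which allows the kernels to compile in parallel. The emitted C++ ends up roughly shaped like the sketch below; the names and the stand-in types are illustrative, not the verbatim generated source.

#include <functional>
#include <vector>

struct ConvAllParams {};                         // stand-in for the real struct
enum class Status { kSuccess, kErrorInternal };  // stand-in for cutlass::Status

// Each generated .cu defines exactly one launcher such as these.
Status conv2d_bias_relu_sm80_fp16_0(const ConvAllParams&) { return Status::kSuccess; }
Status conv2d_bias_relu_sm80_fp16_1(const ConvAllParams&) { return Status::kSuccess; }

// The aggregated per-op file only declares the launchers and registers them
// in the table that ProfileToGetBestConfig walks at runtime.
std::vector<std::function<Status(const ConvAllParams&)>>
    conv2d_bias_relu_sm80_fp16_all_func = {
        conv2d_bias_relu_sm80_fp16_0,
        conv2d_bias_relu_sm80_fp16_1,
};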
@@ -150,10 +152,13 @@ def generate_sm75_1688(): sm75_code = "" for epi_res_block in SupportedEpilogue: op_dict = {} - op_dict["func_name"] = UnderScoreName[epi_res_block].lower() + "_sm75" + op_dict["func_name"] = ( + UnderScoreName[epi_res_block].lower() + "_sm75_fp16" + ) op_dict["enum_op_name"] = UnderScoreName[epi_res_block].upper() # for a op, we record all its kernels into a std::vector in C++ code all_kernel_names = "" + all_kernel_declares = "" suffix = 0 for iterator_algorithm in iterator_algorithms: for alignment in alignments: @@ -188,23 +193,296 @@ def generate_sm75_1688(): kernel_dict["act2"] = ActTag[epi_res_block[2]] suffix += 1 - sm75_code += SubstituteTemplate(cbr_kernel, kernel_dict) + # sm75_code += SubstituteTemplate(cbr_kernel, kernel_dict) + + kernel_str = ( + cbr_header + + SubstituteTemplate(cbr_kernel, kernel_dict) + + CommonTail + ) + file_name = ( + "generated_tmp/" + + kernel_dict["kernel_func_name"] + + ".cu" + ) + write_kernel_to_file(kernel_str, file_name) + all_kernel_names += ( kernel_dict["kernel_func_name"] + ", \n" ) + all_kernel_declares += ( + "cutlass::Status " + + kernel_dict["kernel_func_name"] + + "(const ConvAllParams& params);" + ) - # Generate op code with sm_version + # Generate op code + op_dict["kernel_func_declare"] = all_kernel_declares op_dict["all_kernel_func_name"] = all_kernel_names sm75_code += SubstituteTemplate(CommonConvFunction, op_dict) return sm75_code +def generate_sm80_16816(cutlass_dtype="cutlass::half_t"): + kernel_dict = { + "conv_kind_name": "Fprop", + "element_a": cutlass_dtype, + "layout_a": "cutlass::layout::TensorNHWC", + "element_b": cutlass_dtype, + "layout_b": "cutlass::layout::TensorNHWC", + "element_c": cutlass_dtype, + "layout_c": "cutlass::layout::TensorNHWC", + "opcode_class": "cutlass::arch::OpClassTensorOp", + "arch": "cutlass::arch::Sm80", + "swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + # alpha is always float! 
+ "element_epilogue": "float", + "math_operator": "cutlass::arch::OpMultiplyAdd", + "element_residul": cutlass_dtype, + } + + kernel_dict["stride_support"] = "cutlass::conv::StrideSupport::kStrided" + + # iterate over this loop + iterator_algorithms = [ + "cutlass::conv::IteratorAlgorithm::kOptimized", + ] + + math_instructions = [ + ( + "16,8,16", + cutlass_dtype, + cutlass_dtype, + "float", + ), + ] + + alignments = [8] + + kernel_dict["align_a"] = "8" + kernel_dict["align_b"] = "8" + kernel_dict["epilogue_vector_length"] = "8" + kernel_dict["split_k_slices"] = "1" + + sm80_code = "" + for epi_res_block in SupportedEpilogue: + op_dict = {} + op_dict["func_name"] = ( + UnderScoreName[epi_res_block].lower() + + "_sm80_" + + ("fp16" if "half" in cutlass_dtype else "bf16") + ) + + op_dict["enum_op_name"] = UnderScoreName[epi_res_block].upper() + # for a op, we record all its kernels into a std::vector in C++ code + all_kernel_names = "" + all_kernel_declares = "" + suffix = 0 + for iterator_algorithm in iterator_algorithms: + for alignment in alignments: + for math_inst in math_instructions: + tiles = [ + TileDesc("256, 128, 32", 3, "64, 64, 32", math_inst), + TileDesc("128, 256, 32", 3, "64, 64, 32", math_inst), + TileDesc("256, 64, 32", 3, "64, 64, 32", math_inst), + TileDesc("256, 64, 32", 4, "64, 64, 32", math_inst), + TileDesc("64, 256, 32", 4, "64, 64, 32", math_inst), + TileDesc("128, 128, 32", 3, "64, 64, 32", math_inst), + TileDesc("128, 128, 32", 4, "64, 64, 32", math_inst), + TileDesc("128, 128, 32", 5, "64, 64, 32", math_inst), + TileDesc("128, 64, 32", 6, "64, 32, 32", math_inst), + TileDesc("64, 128, 32", 6, "32, 64, 32", math_inst), + TileDesc("64, 64, 32", 10, "32, 32, 32", math_inst), + TileDesc("256, 128, 64", 3, "64, 64, 64", math_inst), + TileDesc("128, 256, 64", 3, "64, 64, 64", math_inst), + TileDesc("256, 64, 64", 4, "64, 64, 64", math_inst), + TileDesc("64, 256, 64", 4, "64, 64, 64", math_inst), + TileDesc("128, 128, 64", 4, "64, 64, 64", math_inst), + TileDesc("256, 64, 64", 3, "64, 64, 64", math_inst), + TileDesc("64, 256, 64", 3, "64, 64, 64", math_inst), + TileDesc("128, 128, 64", 3, "64, 64, 64", math_inst), + TileDesc("128, 64, 64", 3, "64, 32, 64", math_inst), + TileDesc("64, 128, 64", 3, "32, 64, 64", math_inst), + TileDesc("64, 64, 64", 5, "32, 32, 64", math_inst), + ] + + for tile in tiles: + kernel_dict["iterator_algorithm"] = iterator_algorithm + kernel_dict["Tshape"] = tile.Tshape + kernel_dict["Wshape"] = tile.Wshape + kernel_dict["Ishape"] = tile.math_inst[0] + kernel_dict["stages"] = str(tile.stages) + kernel_dict["element_accum"] = tile.math_inst[3] + kernel_dict["kernel_func_name"] = op_dict[ + "func_name" + ] + str(suffix) + kernel_dict["act1"] = ActTag[epi_res_block[0]] + kernel_dict["binary"] = epi_res_block[1] + kernel_dict["act2"] = ActTag[epi_res_block[2]] + suffix += 1 + + # sm80_code += SubstituteTemplate(cbr_kernel, kernel_dict) + kernel_str = ( + cbr_header + + SubstituteTemplate(cbr_kernel, kernel_dict) + + CommonTail + ) + file_name = ( + "generated_tmp/" + + kernel_dict["kernel_func_name"] + + ".cu" + ) + write_kernel_to_file(kernel_str, file_name) + + all_kernel_names += ( + kernel_dict["kernel_func_name"] + ", \n" + ) + all_kernel_declares += ( + "cutlass::Status " + + kernel_dict["kernel_func_name"] + + "(const ConvAllParams& params);" + ) + + # Generate op code + op_dict["kernel_func_declare"] = all_kernel_declares + op_dict["all_kernel_func_name"] = all_kernel_names + sm80_code += SubstituteTemplate(CommonConvFunction, op_dict) + 
return sm80_code + + +def generate_sm80_1688(cutlass_dtype="cutlass::tfloat32_t"): + kernel_dict = { + "conv_kind_name": "Fprop", + "element_a": cutlass_dtype, + "layout_a": "cutlass::layout::TensorNHWC", + "element_b": cutlass_dtype, + "layout_b": "cutlass::layout::TensorNHWC", + "element_c": cutlass_dtype, + "layout_c": "cutlass::layout::TensorNHWC", + "opcode_class": "cutlass::arch::OpClassTensorOp", + "arch": "cutlass::arch::Sm80", + "swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + # alpha is always float! + "element_epilogue": "float", + "math_operator": "cutlass::arch::OpMultiplyAdd", + "element_residul": cutlass_dtype, + } + + kernel_dict["stride_support"] = "cutlass::conv::StrideSupport::kStrided" + + # iterate over this loop + iterator_algorithms = [ + "cutlass::conv::IteratorAlgorithm::kOptimized", + ] + + math_instructions = [ + ( + "16,8,8", + cutlass_dtype, + cutlass_dtype, + "float", + ), + ] + + alignments = [4] + + kernel_dict["align_a"] = "4" + kernel_dict["align_b"] = "4" + kernel_dict["epilogue_vector_length"] = "4" + kernel_dict["split_k_slices"] = "1" + + sm80_code = "" + for epi_res_block in SupportedEpilogue: + op_dict = {} + op_dict["func_name"] = ( + UnderScoreName[epi_res_block].lower() + "_sm80_fp32" + ) + op_dict["enum_op_name"] = UnderScoreName[epi_res_block].upper() + # for a op, we record all its kernels into a std::vector in C++ code + all_kernel_names = "" + all_kernel_declares = "" + suffix = 0 + for iterator_algorithm in iterator_algorithms: + for alignment in alignments: + for math_inst in math_instructions: + tiles = [ + TileDesc("128, 128, 16", 4, "32, 64, 16", math_inst), + TileDesc("128, 128, 16", 3, "32, 64, 16", math_inst), + TileDesc("256, 64, 16", 3, "64, 32, 16", math_inst), + TileDesc("64, 256, 16", 3, "32, 64, 16", math_inst), + TileDesc("128, 64, 16", 4, "64, 32, 16", math_inst), + TileDesc("64, 128, 16", 4, "32, 64, 16", math_inst), + TileDesc("64, 64, 16", 3, "32, 32, 16", math_inst), + TileDesc("128, 128, 32", 3, "32, 64, 32", math_inst), + TileDesc("256, 64, 32", 3, "64, 32, 32", math_inst), + TileDesc("64, 256, 32", 3, "32, 64, 32", math_inst), + TileDesc("128, 64, 32", 3, "64, 32, 32", math_inst), + TileDesc("64, 128, 32", 3, "32, 64, 32", math_inst), + TileDesc("64, 64, 32", 3, "32, 32, 32", math_inst), + ] + + for tile in tiles: + kernel_dict["iterator_algorithm"] = iterator_algorithm + kernel_dict["Tshape"] = tile.Tshape + kernel_dict["Wshape"] = tile.Wshape + kernel_dict["Ishape"] = tile.math_inst[0] + kernel_dict["stages"] = str(tile.stages) + kernel_dict["element_accum"] = tile.math_inst[3] + kernel_dict["kernel_func_name"] = op_dict[ + "func_name" + ] + str(suffix) + suffix += 1 + kernel_dict["act1"] = ActTag[epi_res_block[0]] + kernel_dict["binary"] = epi_res_block[1] + kernel_dict["act2"] = ActTag[epi_res_block[2]] + + # sm80_code += SubstituteTemplate(cbr_kernel, kernel_dict) + kernel_str = ( + cbr_header + + SubstituteTemplate(cbr_kernel, kernel_dict) + + CommonTail + ) + file_name = ( + "generated_tmp/" + + kernel_dict["kernel_func_name"] + + ".cu" + ) + write_kernel_to_file(kernel_str, file_name) + + all_kernel_names += ( + kernel_dict["kernel_func_name"] + ", \n" + ) + all_kernel_declares += ( + "cutlass::Status " + + kernel_dict["kernel_func_name"] + + "(const ConvAllParams& params);" + ) + + # Generate op code + op_dict["kernel_func_declare"] = all_kernel_declares + op_dict["all_kernel_func_name"] = all_kernel_names + sm80_code += SubstituteTemplate(CommonConvFunction, op_dict) + 
return sm80_code + + if __name__ == "__main__": - sm_versions = ["75"] + sm_versions_and_types = [] + args = parse_args() + all_code = cbr_header - all_code += generate_sm75_1688() + if args.cuda_arch == "75": + sm_versions_and_types.append(["75", "fp16"]) + all_code += generate_sm75_1688() + if args.cuda_arch in ["80", "86", "89"]: + sm_versions_and_types.append(["80", "fp16"]) + sm_versions_and_types.append(["80", "bf16"]) + sm_versions_and_types.append(["80", "fp32"]) + all_code += generate_sm80_16816() + all_code += generate_sm80_16816(cutlass_dtype="cutlass::bfloat16_t") + all_code += generate_sm80_1688(cutlass_dtype="cutlass::tfloat32_t") + all_code += GenerateFunctionForPhi( - sm_versions, SupportedEpilogue, UnderScoreName, CamelName + sm_versions_and_types, SupportedEpilogue, UnderScoreName, CamelName ) all_code += CommonTail with open("generated_tmp/conv2d_bias_residual.cu", "w") as f: diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py index 7c95892006c43..29f9e443d9c53 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py @@ -51,10 +51,14 @@ using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; - const half *input = params.input; - const half *weight = params.weight; - const half *bias = params.bias; - half *output = params.output; + + ${element_a} *input = (${element_a} *)(params.input); + ${element_b} *weight = (${element_b} *)(params.weight); + ${element_c} *bias = (${element_c} *)(params.bias); + ${element_c} *output = (${element_c} *)(params.output); + // only used by conv2d_bias_residual + auto residual = (${element_c} *)(params.residual); + int batch = params.batch; int ic = params.ic; int ih = params.ih; @@ -112,6 +116,9 @@ # ${enum_op_name} is like CONV2D_BIAS_SILU CommonConvFunction = """ + +${kernel_func_declare} + std::vector> ${func_name}_all_func = {${all_kernel_func_name}}; @@ -163,8 +170,17 @@ """ +def convert_c_data_type(dtype): + if dtype == "fp16": + return "Conv2dDataType::fp16" + elif dtype == "bf16": + return "Conv2dDataType::bf16" + elif dtype == "fp32": + return "Conv2dDataType::fp32" + + CommonDispatchTemp = ''' - if (params.sm_version == ${sm_code}) + if (params.sm_version == ${sm_code} && params.data_type == ${data_type}) { ${op_name_with_sm}(params); } @@ -182,16 +198,21 @@ # Wrap different sm versions into a function called by phi def GenerateFunctionForPhi( - sm_versions, support_epi_funcs, underscore_names, camel_names + sm_versions_and_types, support_epi_funcs, underscore_names, camel_names ): generated_code = "" for epi_func in support_epi_funcs: dispatch_body = "" - for sm_version in sm_versions: + for sm_version, data_type in sm_versions_and_types: sm_dicts = {} sm_dicts["sm_code"] = sm_version + sm_dicts["data_type"] = convert_c_data_type(data_type) sm_dicts["op_name_with_sm"] = ( - underscore_names[epi_func].lower() + "_sm" + sm_version + underscore_names[epi_func].lower() + + "_sm" + + sm_version + + "_" + + data_type ) dispatch_body += SubstituteTemplate(CommonDispatchTemp, sm_dicts) op_dicts = {} diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h index aaad46de5cb0d..b29ce65f5230a 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h @@ -20,12 +20,18 @@ namespace phi { namespace fusion { namespace cutlass_internal { 
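[Editor's note] Because CommonDispatchTemp above now keys on both sm_version and data_type, the dispatch that GenerateFunctionForPhi emits for one fused op expands roughly as follows; treat this as a sketch of the generated code, not the verbatim output.

// Pseudo-expansion of CommonDispatchTemp for a conv2d_bias_relu entry point.
void Conv2dBiasRelu(const ConvAllParams& params) {
  if (params.sm_version == 75 && params.data_type == Conv2dDataType::fp16) {
    conv2d_bias_relu_sm75_fp16(params);
  }
  if (params.sm_version == 80 && params.data_type == Conv2dDataType::fp16) {
    conv2d_bias_relu_sm80_fp16(params);
  }
  if (params.sm_version == 80 && params.data_type == Conv2dDataType::bf16) {
    conv2d_bias_relu_sm80_bf16(params);
  }
  if (params.sm_version == 80 && params.data_type == Conv2dDataType::fp32) {
    conv2d_bias_relu_sm80_fp32(params);
  }
}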
+typedef enum { + fp32, + fp16, + bf16, +} Conv2dDataType; + typedef struct { - const half *input; - const half *weight; - const half *bias; - const half *residual; - half *output; + const void *input; + const void *weight; + const void *bias; + const void *residual; + void *output; int batch; int ic; int ih; @@ -48,6 +54,7 @@ typedef struct { cudaStream_t stream; float alpha; // for leaky_relu use int sm_version = 75; + Conv2dDataType data_type; void *workspace = nullptr; } ConvAllParams; diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_depthwise_bias_act.py b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_depthwise_bias_act.py index fb2f2be096110..5114d69e97060 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_depthwise_bias_act.py +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_depthwise_bias_act.py @@ -208,6 +208,7 @@ def generate_conv2d_depthwise(): ) # generate op code op_dict["all_kernel_func_name"] = all_kernel_names + op_dict["kernel_func_declare"] = ";" all_code += SubstituteTemplate(CommonConvFunction, op_dict) return all_code diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu index 51bc71983105a..6aed60cf1c23b 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu @@ -26,10 +26,11 @@ struct logical_coord { int w; }; -float diff(const half *c, const float *c_baseline, int n) { +template +float diff(const T *c, const float *c_baseline, int n) { float max_diff = -1.; for (int i = 0; i < n; i++) { - float c_value = __half2float(c[i]); + float c_value = static_cast(c[i]); if (std::abs(c_baseline[i] - c_value) > max_diff) { max_diff = std::abs(c_baseline[i] - c_value); } @@ -42,10 +43,10 @@ __device__ int gpu_nhwc(struct logical_coord shape, return index.n * shape.h * shape.w * shape.c + index.h * shape.w * shape.c + index.w * shape.c + index.c; } - -__global__ void naive_conv2d_kernel(const half *input, - const half *weight, - const half *bias, +template +__global__ void naive_conv2d_kernel(const T *input, + const T *weight, + const T *bias, float *output, int batch, int ic, @@ -63,7 +64,7 @@ __global__ void naive_conv2d_kernel(const half *input, int oh, int ow, int groups, - const half *residual, + const T *residual, float alpha, // for leaky_relu OpType op_type) { int M = batch * oh * ow; @@ -100,12 +101,12 @@ __global__ void naive_conv2d_kernel(const half *input, if (iw_i < 0 || iw_i >= iw) continue; struct logical_coord input_index = {batch_i, ic_i, ih_i, iw_i}; - const half *weight_ptr = weight + gpu_nhwc(weight_shape, weight_index); - const half *in_ptr = input + gpu_nhwc(input_shape, input_index); - sum += __half2float(*in_ptr) * __half2float(*weight_ptr); + const T *weight_ptr = weight + gpu_nhwc(weight_shape, weight_index); + const T *in_ptr = input + gpu_nhwc(input_shape, input_index); + sum += static_cast(*in_ptr) * static_cast(*weight_ptr); } - sum += __half2float(*(bias + oc_i)); + sum += static_cast(*(bias + oc_i)); float x = sum; switch (op_type) { @@ -121,10 +122,19 @@ __global__ void naive_conv2d_kernel(const half *input, case CONV2D_DEPTHWISE_BIAS_SILU: *out_ptr = x * (1.f / (1 + exp(-x))); break; + case CONV2D_BIAS_SILU_ADD: + x = x * (1.f / (1 + exp(-x))); + x += static_cast(*(residual + out_offset)); + *out_ptr = x; + break; case CONV2D_BIAS_ADD_RELU: - x += __half2float(*(residual + out_offset)); + x += static_cast(*(residual + out_offset)); *out_ptr = x > 0 ? 
x : 0; break; + case CONV2D_BIAS_ADD: + x += static_cast(*(residual + out_offset)); + *out_ptr = x; + break; case CONV2D_BIAS_LEAKY_RELU: *out_ptr = x > 0 ? x : (x * alpha); break; @@ -136,12 +146,12 @@ __global__ void naive_conv2d_kernel(const half *input, break; } } - -float conv2d_diff_gpu(const ConvAllParams ¶ms, OpType op_type) { - const half *input = params.input; - const half *weight = params.weight; - const half *bias = params.bias; - half *output = params.output; +template +float conv2d_diff_gpu(const ConvAllParams ¶ms, OpType op_type, T a) { + const T *input = (const T *)(params.input); + const T *weight = (const T *)(params.weight); + const T *bias = (const T *)(params.bias); + T *output = static_cast(params.output); int batch = params.batch; int ic = params.ic; int ih = params.ih; @@ -155,7 +165,7 @@ float conv2d_diff_gpu(const ConvAllParams ¶ms, OpType op_type) { int stride_w = params.stride_w; int dilation_h = params.dilation_h; int dilation_w = params.dilation_w; - const half *residual = params.residual; + const T *residual = (const T *)(params.residual); int groups = params.groups; int oh = params.oh; @@ -169,11 +179,11 @@ float conv2d_diff_gpu(const ConvAllParams ¶ms, OpType op_type) { uint3 block = {blockM, blockN, 1}; int output_size = batch * oc * oh * ow; - half *output_from_cutlass = - reinterpret_cast(malloc(sizeof(half) * output_size)); + T *output_from_cutlass = + reinterpret_cast(malloc(sizeof(T) * output_size)); cudaMemcpy(output_from_cutlass, output, - output_size * sizeof(half), + output_size * sizeof(T), cudaMemcpyDeviceToHost); float *gpu_output; @@ -207,6 +217,13 @@ float conv2d_diff_gpu(const ConvAllParams ¶ms, OpType op_type) { gpu_output, output_size * sizeof(float), cudaMemcpyDeviceToHost); + + // cudaMemcpy(output, + // gpu_output, + // output_size * sizeof(T), + // cudaMemcpyDeviceToDevice); + // cudaMemset(output, 0, output_size * sizeof(T)); + float max_diff = diff(output_from_cutlass, output_from_gpu, output_size); free(output_from_cutlass); @@ -232,6 +249,12 @@ std::string OpType2String(OpType op_type) { case CONV2D_BIAS_ADD_RELU: return "conv2d_bias_add_relu"; break; + case CONV2D_BIAS_ADD: + return "conv2d_bias_add"; + break; + case CONV2D_BIAS_SILU_ADD: + return "conv2d_bias_silu_add"; + break; case CONV2D_BIAS_LEAKY_RELU: return "conv2d_bias_leaky_relu"; case CONV2D_DEPTHWISE_BIAS: @@ -253,7 +276,7 @@ int ProfileToGetBestConfig( const ConvAllParams ¶ms, OpType op_type) { constexpr int WARMUP = 10; - constexpr int REPEAT = 100; + constexpr int REPEAT = 10; float min_time = 100000.f; int min_time_index = -1; for (int i = 0; i < all_func.size(); i++) { @@ -286,11 +309,31 @@ int ProfileToGetBestConfig( if (elapsed_time < min_time && status == cutlass::Status::kSuccess) { min_time = elapsed_time; min_time_index = i; - // debug code - std::cout << OpType2String(op_type) << ": tactic " << i - << " has max diff " << conv2d_diff_gpu(params, op_type) - << " compared with baseline," - << "cost_time: " << elapsed_time << "ms." << std::endl; + + if (params.data_type == Conv2dDataType::fp16) { + // debug code + std::cout << OpType2String(op_type) << ": tactic " << i + << " has max diff " + << conv2d_diff_gpu(params, op_type, (half)(1.0)) + << " compared with baseline," + << "cost_time: " << elapsed_time << "ms." 
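[Editor's note] ProfileToGetBestConfig above warms each candidate kernel up, times REPEAT launches (now 10 instead of 100), keeps the fastest tactic, and with the new per-dtype branches also prints the max diff against the naive reference. A self-contained sketch of that timing loop with the plain CUDA event API; the candidate signature is a placeholder.

#include <cuda_runtime.h>
#include <functional>
#include <vector>

// Returns the index of the fastest successful candidate, or -1 if none ran.
inline int PickFastest(
    const std::vector<std::function<bool(cudaStream_t)>>& candidates,
    cudaStream_t stream, int warmup = 10, int repeat = 10) {
  float best_ms = 1e30f;
  int best = -1;
  for (size_t i = 0; i < candidates.size(); ++i) {
    for (int w = 0; w < warmup; ++w) candidates[i](stream);
    cudaEvent_t beg, end;
    cudaEventCreate(&beg);
    cudaEventCreate(&end);
    cudaEventRecord(beg, stream);
    bool ok = true;
    for (int r = 0; r < repeat; ++r) ok = candidates[i](stream) && ok;
    cudaEventRecord(end, stream);
    cudaEventSynchronize(end);
    float ms = 0.f;
    cudaEventElapsedTime(&ms, beg, end);
    cudaEventDestroy(beg);
    cudaEventDestroy(end);
    if (ok && ms < best_ms) {
      best_ms = ms;
      best = static_cast<int>(i);
    }
  }
  return best;
}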
<< std::endl; + } else if (params.data_type == Conv2dDataType::bf16) { + // debug code + std::cout << OpType2String(op_type) << ": tactic " << i + << " has max diff " + << conv2d_diff_gpu( + params, op_type, static_cast(1.0)) + << " compared with baseline," + << "cost_time: " << elapsed_time << "ms." << std::endl; + } else if (params.data_type == Conv2dDataType::fp32) { + // debug code + std::cout << OpType2String(op_type) << ": tactic " << i + << " has max diff " + << conv2d_diff_gpu( + params, op_type, static_cast(1.0)) + << " compared with baseline," + << "cost_time: " << elapsed_time << "ms." << std::endl; + } } } @@ -301,11 +344,6 @@ int ProfileToGetBestConfig( return min_time_index; } -__attribute__((dllexport)) int HelloFromCutlassConv2d(int a, int b) { - std::cout << "welcom using Cutlass Conv2d" << std::endl; - return 1; -} - } // namespace cutlass_internal } // namespace fusion } // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h index 80865e0e1cded..508b8a8f1ae3b 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h @@ -37,6 +37,7 @@ typedef enum { CONV2D_BIAS, CONV2D_BIAS_RELU, CONV2D_BIAS_ADD_RELU, + CONV2D_BIAS_ADD, CONV2D_BIAS_SILU, CONV2D_BIAS_LEAKY_RELU, CONV2D_BIAS_SIGMOID, diff --git a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py index 5847956020ceb..17911e4898220 100644 --- a/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py +++ b/paddle/phi/kernels/fusion/cutlass/cutlass_kernels/fpA_intB_gemm/generic_mixed_gemm_kernelLauncher.py @@ -234,9 +234,7 @@ def generate_source_cu( for arch in archs: for epilogue_tag in EpilogueTags.keys(): for stages in StagesList[arch]: - file_name = "autogen_tmp/generic_mixed_gemm_kernelLauncher_{}_sm{}_stages{}_{}.cu".format( - element_type, arch, stages, epilogue_tag - ) + file_name = f"autogen_tmp/generic_mixed_gemm_kernelLauncher_{element_type}_sm{arch}_stages{stages}_{epilogue_tag}.cu" all_code = generate_source_cu( element_type, arch, diff --git a/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu b/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu index dceaafd2e7172..79057bee76219 100644 --- a/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu +++ b/paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu @@ -51,19 +51,53 @@ void FusedConv2dAddActKernel(const Context& ctx, auto in_dims = x.dims(); auto filter_dims = filter.dims(); auto out_dims = output->dims(); - CHECK_EQ(in_dims.size() == 4UL, true); - CHECK_EQ(filter_dims.size() == 4UL, true); - CHECK_EQ(strides.size() == 2UL, true); - CHECK_EQ(dilations.size() == 2UL, true); + PADDLE_ENFORCE_EQ( + in_dims.size(), + 4UL, + phi::errors::InvalidArgument( + "The input tensor X's dimensions should be 4, but got %d.", + in_dims.size())); + PADDLE_ENFORCE_EQ( + filter_dims.size(), + 4UL, + phi::errors::InvalidArgument( + "The input tensor filter's dimensions must be 4, but got %d.", + filter_dims.size())); + PADDLE_ENFORCE_EQ( + strides.size(), + 2UL, + phi::errors::InvalidArgument("The size of strides must be 2, but got %d.", + strides.size())); + PADDLE_ENFORCE_EQ( + dilations.size(), + 2UL, + phi::errors::InvalidArgument( + "The size of dilations must be 2, but got 
%d.", dilations.size())); - CHECK_EQ(padding_algorithm == "EXPLICIT", true); - CHECK_EQ(data_format == "NHWC", true); + PADDLE_ENFORCE_EQ(padding_algorithm, + "EXPLICIT", + phi::errors::InvalidArgument( + "The padding_algorithm must be EXPLICIT, but got %s.", + padding_algorithm)); + PADDLE_ENFORCE_EQ( + data_format, + "NHWC", + phi::errors::InvalidArgument("The data_format must be NHWC, but got %s.", + data_format)); const int batch = in_dims[0]; const int ic = in_dims[3]; const int ih = in_dims[1]; const int iw = in_dims[2]; - CHECK_EQ(ic == groups * filter_dims[3], true); + PADDLE_ENFORCE_EQ( + ic, + groups * filter_dims[3], + phi::errors::InvalidArgument( + "The last dimension of X (%d) must be equal to " + "groups (%d) multiply the last dimension of filter (%d).", + ic, + groups, + filter_dims[3])); int pad_h0 = 0; int pad_h1 = 0; int pad_w0 = 0; @@ -94,38 +128,79 @@ void FusedConv2dAddActKernel(const Context& ctx, const int kh = filter_dims[1]; const int kw = filter_dims[2]; - CHECK_EQ(out_dims.size() == 4UL, true); + PADDLE_ENFORCE_EQ( + out_dims.size(), + 4UL, + phi::errors::InvalidArgument( + "The output's dimensions must be 4, but got %d.", out_dims.size())); const int oh = out_dims[1]; const int ow = out_dims[2]; - ConvAllParams params = {reinterpret_cast(x.data()), - reinterpret_cast(filter.data()), - reinterpret_cast(bias.data()), - nullptr, - reinterpret_cast(output->data()), - batch, - ic, - ih, - iw, - kh, - kw, - oc, - pad_h0, - pad_h1, - pad_w0, - pad_w1, - stride_h, - stride_w, - dilation_h, - dilation_w, - oh, - ow, - groups, - ctx.stream()}; + int64_t device_id = ctx.GetPlace().GetDeviceId(); + int sm_version = backends::gpu::GetGPUComputeCapability(device_id); + + auto get_conv2d_dtype = [&](decltype(x.dtype()) x_type) + -> phi::fusion::cutlass_internal::Conv2dDataType { + switch (x_type) { + case phi::DataType::FLOAT32: + return Conv2dDataType::fp32; + case phi::DataType::FLOAT16: + return Conv2dDataType::fp16; + case phi::DataType::BFLOAT16: + return Conv2dDataType::bf16; + } + }; + + auto cutlass_dispatch_sm_version = [&](int device_sm_version) -> int { + if (device_sm_version < 75) { + PADDLE_ENFORCE_GE( + device_sm_version, + 75, + phi::errors::PreconditionNotMet( + "fused_conv2d_add_act only supports sm >= 75, but got %d.", + device_sm_version)); + } else if (device_sm_version > 80) { + return 80; + } else { + return device_sm_version; + } + }; + + ConvAllParams params = { + reinterpret_cast(x.data()), + reinterpret_cast(filter.data()), + reinterpret_cast(bias.data()), + nullptr, + reinterpret_cast(output->data()), + batch, + ic, + ih, + iw, + kh, + kw, + oc, + pad_h0, + pad_h1, + pad_w0, + pad_w1, + stride_h, + stride_w, + dilation_h, + dilation_w, + oh, + ow, + groups, + ctx.stream(), + 0, // alpha + cutlass_dispatch_sm_version(sm_version), + get_conv2d_dtype(x.dtype()), + nullptr, + }; void* dlhandler = phi::dynload::GetCutlassConv2dHandle(); func conv_func = NULL; - CHECK_EQ(dlhandler == NULL, false); + PADDLE_ENFORCE_NOT_NULL( + dlhandler, phi::errors::NotFound("Fail to get CutlassConv2d handler.")); // conv2d_depthwise if (groups == ic && ic == oc) { @@ -137,7 +212,10 @@ void FusedConv2dAddActKernel(const Context& ctx, params.workspace = tmp_ptr->ptr(); // cutlass conv2d_depthwise not support residual if (residual) { - CHECK_EQ(residual->data() == nullptr, true); + PADDLE_ENFORCE_EQ(residual->data(), + nullptr, + phi::errors::InvalidArgument( + "The pointer of residual's data must be null.")); } if (activation == "relu") { conv_func = 
(func)(dlsym(dlhandler, "Conv2dDepthwiseBiasRelu")); @@ -158,14 +236,19 @@ void FusedConv2dAddActKernel(const Context& ctx, } // below: fused_conv2d_add_act && groups == 1 - CHECK_EQ(groups == 1, true); + PADDLE_ENFORCE_EQ(groups, + 1, + phi::errors::InvalidArgument( + "The groups must be 1, but got %d.", groups)); if (residual) { if (activation == "relu") { - params.residual = reinterpret_cast(residual->data()); + params.residual = reinterpret_cast(residual->data()); conv_func = (func)(dlsym(dlhandler, "Conv2dBiasAddRelu")); } else { PADDLE_THROW(phi::errors::InvalidArgument( - "Cutlass now only support relu activation in a residual block")); + "Cutlass now only support relu activation in a residual block, but " + "got %s.", + activation.c_str())); } } else if (activation == "relu") { conv_func = (func)(dlsym(dlhandler, "Conv2dBiasRelu")); @@ -194,4 +277,5 @@ PD_REGISTER_KERNEL(fused_conv2d_add_act, ALL_LAYOUT, phi::fusion::cutlass_internal::FusedConv2dAddActKernel, float, + phi::dtype::bfloat16, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/kernel_backward.h b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/kernel_backward.h index 31ce0bd3574ee..2bd3ac2db5f5b 100644 --- a/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/kernel_backward.h +++ b/paddle/phi/kernels/fusion/cutlass/memory_efficient_attention/kernel_backward.h @@ -492,8 +492,6 @@ struct AttentionBackwardKernel { scalar_t, // ElementC accum_t // ElementAccumulator >; - static constexpr auto kOptimalAlignement = - std::max(DefaultConfig::kAlignmentA, DefaultConfig::kAlignmentB); static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; struct MatmulQK { diff --git a/paddle/phi/kernels/fusion/cutlass/util.py b/paddle/phi/kernels/fusion/cutlass/util.py index 200960f39c56e..d3ffb648362f6 100644 --- a/paddle/phi/kernels/fusion/cutlass/util.py +++ b/paddle/phi/kernels/fusion/cutlass/util.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import re @@ -35,3 +36,28 @@ def SubstituteTemplate(template, values): changed = True text = newtext return text + + +def parse_args(): + parser = argparse.ArgumentParser( + description="The argument for generating the conv2d_bias_act kernels." 
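[Editor's note] The fused_conv2d_add_act kernel above resolves its cutlass launchers at runtime from the separately built libCutlassConv2d.so (GetCutlassConv2dHandle plus dlsym), and the new PADDLE_ENFORCE_NOT_NULL turns a missing handle into a readable error rather than a bare CHECK failure. A minimal standalone sketch of that loading pattern; the library and symbol names here are placeholders.

#include <dlfcn.h>
#include <stdexcept>
#include <string>

using ConvFunc = void (*)(const void* params);

inline ConvFunc LoadConvSymbol(const char* symbol) {
  void* handle = dlopen("libexample_kernels.so", RTLD_NOW | RTLD_LOCAL);
  if (handle == nullptr) {
    const char* err = dlerror();
    throw std::runtime_error(err ? err : "failed to dlopen kernel library");
  }
  void* fn = dlsym(handle, symbol);
  if (fn == nullptr) {
    const char* err = dlerror();
    throw std::runtime_error(std::string("missing symbol: ") + symbol +
                             (err ? std::string(": ") + err : ""));
  }
  return reinterpret_cast<ConvFunc>(fn);
}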
+ ) + + parser.add_argument( + "--cuda_arch", + type=str, + default=None, + help="The CUDA architecture to be generated.", + ) + args = parser.parse_args() + + return args + + +def write_kernel_to_file(kernel, file_name): + with open( + file_name, + "w", + ) as f: + f.write(kernel) + f.close() diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu index 0f93e21553a74..60a82cfe7c198 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_grad_kernel.cu @@ -30,7 +30,6 @@ namespace fusion { template void FusedBiasDropoutResidualLnGradKernel( const Context& dev_ctx, - const DenseTensor& y_grad, const DenseTensor& x, const DenseTensor& residual, const paddle::optional& bias, @@ -40,6 +39,7 @@ void FusedBiasDropoutResidualLnGradKernel( const DenseTensor& ln_variance, const DenseTensor& bias_dropout_residual_out, const DenseTensor& dropout_mask_out, + const DenseTensor& y_grad, const float dropout_rate, const bool is_test, const bool dropout_fix_seed, diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu index ff6380ceeec0a..801f070251fb2 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu @@ -218,7 +218,7 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx, // seed_offset_data should preserved by cudaGraph pool const phi::GPUContext* dev_ctx_p = &dev_ctx; auto parameterSetter = [offset, dev_ctx_p, seed_offset]( - phi::backends::gpu::CUDAKernelParams& params) { + phi::backends::gpu::gpuKernelParams& params) { const auto* seed_offset_data = seed_offset.data(); const uint64_t seed_data = static_cast(seed_offset_data[0]); const uint64_t increment = static_cast(seed_offset_data[1]); @@ -229,7 +229,7 @@ void FusedDropoutAddGradKernel(const Context& dev_ctx, << ", increment = " << increment; }; - phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t + phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { void* functionPtr = reinterpret_cast( &(VectorizedDropoutBackward>)); diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu index 5ec23e777211b..c95c5fbf0ca3d 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu @@ -211,7 +211,7 @@ void FusedDropoutAddKernel(const Context& dev_ctx, seed_offset_data, state_index, seed_tensor_ptr, - fix_seed](phi::backends::gpu::CUDAKernelParams& params) { + fix_seed](phi::backends::gpu::gpuKernelParams& params) { if (!fix_seed) { auto gen_cuda = dev_ctx_p->GetGenerator(); // ensure the generator use correct state index @@ -233,7 +233,7 @@ void FusedDropoutAddKernel(const Context& dev_ctx, seed_offset_data[1] = static_cast(increment); } }; - phi::backends::gpu::CUDAGraphNodeLauncher::cudaKernelCallback_t + phi::backends::gpu::CUDAGraphNodeLauncher::gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) { void* functionPtr = reinterpret_cast( &(VectorizedDropoutForward>)); diff --git a/paddle/phi/kernels/fusion/onednn/fused_transpose_kernel.cc b/paddle/phi/kernels/fusion/onednn/fused_transpose_kernel.cc index 
a7f9e49e32560..78fd2cfd964d7 100644 --- a/paddle/phi/kernels/fusion/onednn/fused_transpose_kernel.cc +++ b/paddle/phi/kernels/fusion/onednn/fused_transpose_kernel.cc @@ -34,7 +34,7 @@ void SetInMemDescWithSqueeze2FuseSupport( int j = 0; for (size_t i = 0; i < x_vec_dims.size(); ++i) { if (squeeze2_axes_set.count(i) || - squeeze2_axes_set.count(i - x_vec_dims.size())) { + squeeze2_axes_set.count(i - x_vec_dims.size())) { // NOLINT PADDLE_ENFORCE_EQ( x_vec_dims[i], 1, @@ -68,12 +68,12 @@ void FusedTransposeKernel(const Context& dev_ctx, if ((x_dims.size() >= 3) && (phi::OneDNNContext::tls().get_cur_paddle_data_layout() == phi::DataLayout::kNHWC)) { - int axis_size = axis.size(); - std::vector formated_axis = axis; + int axis_size = static_cast(axis.size()); + std::vector formatted_axis = axis; std::vector count(axis_size, 0); for (int i = 0; i < axis_size; i++) { if (axis[i] < 0) { - formated_axis[i] = axis[i] + axis_size; + formatted_axis[i] = axis[i] + axis_size; } } auto dims = common::vectorize(x_dims); @@ -85,7 +85,7 @@ void FusedTransposeKernel(const Context& dev_ctx, phi::DDim out_dims(x_dims); for (size_t i = 0; i < axis.size(); i++) { - out_dims[i] = x_dims[formated_axis[i]]; + out_dims[i] = x_dims[formatted_axis[i]]; // NOLINT } out->Resize(out_dims); } diff --git a/paddle/phi/kernels/fusion/xpu/bn_act_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/bn_act_xpu_kernel.cc index 82840ec1b3537..17ff819d346d3 100644 --- a/paddle/phi/kernels/fusion/xpu/bn_act_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/bn_act_xpu_kernel.cc @@ -69,7 +69,7 @@ void BNActXPUKernel(const Context& dev_ctx, 5, phi::errors::InvalidArgument( "The size of input X's dimensions should be less than 6." - "But received: the size of input X's dimensionss is [%d]", + "But received: the size of input X's dimensions is [%d]", x_dims.size())); bool is_nchw = data_layout_str == "NCHW"; diff --git a/paddle/phi/kernels/fusion/xpu/conv_transpose_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/conv_transpose_xpu_kernel.cc index 58f40f3040f74..cc66ee88b0787 100644 --- a/paddle/phi/kernels/fusion/xpu/conv_transpose_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/conv_transpose_xpu_kernel.cc @@ -39,7 +39,7 @@ void Conv2dTransposeXPUKernel(const Context& ctx, const std::string& act_type, DenseTensor* out, DenseTensor* out_max) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; ctx.template Alloc(out); ctx.template Alloc(out_max); @@ -71,11 +71,11 @@ void Conv2dTransposeXPUKernel(const Context& ctx, x_max.get_ptr() == nullptr ? nullptr : x_max.get_ptr()->data(); auto filter_max_data = filter_max.data(); - int r = xpu::conv2d_transpose_fusion_v2( + int r = xpu::conv2d_transpose_fusion_v2( ctx.x_context(), - reinterpret_cast(x.data()), + reinterpret_cast(x.data()), filter.data(), - reinterpret_cast(out->data()), + reinterpret_cast(out->data()), batch_size, img_yc, img_xh, diff --git a/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc new file mode 100644 index 0000000000000..d36d7416a023a --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/fused_bias_act_kernel.cc @@ -0,0 +1,138 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +static void DispatchComputeImpl(const phi::XPUContext *xpu_ctx, + const DenseTensor &x, + const DenseTensor *bias, + const DenseTensor &dequant_scales, + const DenseTensor &shift, + const DenseTensor &smooth, + const std::string &act_method, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + DenseTensor *out) { + PADDLE_THROW( + phi::errors::Unimplemented("fused_bias_act with smooth " + "quant on xpu is not implemented yet.")); +} + +template +static void ComputeImpl(const phi::XPUContext *xpu_ctx, + const DenseTensor &x, + const paddle::optional &bias, + const std::string &act_method, + DenseTensor *out) { + using XPUType = typename XPUTypeTrait::Type; + int rows = x.dims()[0]; + int cols = x.dims()[1]; + int r = 0; + if (bias) { + r = baidu::xpu::api::broadcast_add( + xpu_ctx->x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(bias.get().data()), + reinterpret_cast(const_cast(x.data())), + {rows, cols}, + {1, cols}); + PD_CHECK(r == 0, "baidu::xpu::api::broadcast_add failed."); + } + if (act_method == "geglu") { + PD_THROW( + "NOT supported GeGLU. " + "Currently Only Support SwiGLU, GeLU, ReLU"); + } else if (act_method == "swiglu") { + r = baidu::xpu::api::swiglu( + xpu_ctx->x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + {rows, cols}, + 1, + true); + PD_CHECK(r == 0, "baidu::xpu::api::swiglu failed."); + } else if (act_method == "gelu") { + r = baidu::xpu::api::gelu( + xpu_ctx->x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + rows * cols); + PD_CHECK(r == 0, "baidu::xpu::api::gelu failed."); + } else if (act_method == "relu") { + r = baidu::xpu::api::relu( + xpu_ctx->x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + rows * cols); + PD_CHECK(r == 0, "baidu::xpu::api::relu failed."); + } else { + PD_THROW( + "NOT supported. " + "Currently Only Support SwiGLU, GeLU, ReLU"); + } +} + +template +void FusedBiasActKernel(const Context &dev_ctx, + const DenseTensor &x, + const paddle::optional &bias, + const paddle::optional &dequant_scales, + const paddle::optional &shift, + const paddle::optional &smooth, + const std::string &act_method, + const std::string &compute_dtype, + float quant_scale, + int quant_round_type, + float quant_max_bound, + float quant_min_bound, + DenseTensor *out) { + auto xpu_ctx = static_cast(&dev_ctx); + dev_ctx.template Alloc(out); + + if (dequant_scales && dequant_scales.get().numel() > 0) { + return DispatchComputeImpl(xpu_ctx, + x, + bias ? 
&(bias.get()) : nullptr, + dequant_scales.get(), + shift.get(), + smooth.get(), + act_method, + quant_scale, + quant_round_type, + quant_max_bound, + quant_min_bound, + out); + } else { + return ComputeImpl(xpu_ctx, x, bias, act_method, out); + } +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(fused_bias_act, + XPU, + ALL_LAYOUT, + phi::fusion::FusedBiasActKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc index 29f74e8e1fe23..aeb5cb22cbe66 100644 --- a/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fused_feedforward_grad_kernel.cc @@ -231,7 +231,7 @@ void FFNGrad(const phi::XPUContext& dev_ctx, std::tie(info_d_dropout1, info_dw2, a_1, b_1, a_2, b_2) = fc_info; - // if l3_total_size >= dim_feedforward * bsz_seq * sizeof(T), first transpos + // if l3_total_size >= dim_feedforward * bsz_seq * sizeof(T), first transpose if (l3_total_size >= dim_feedforward * bsz_seq * sizeof(T) && info_dw2.trans_x) { r = xpu::transpose(xpu_ctx, diff --git a/paddle/phi/kernels/fusion/xpu/fused_layernorm_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_layernorm_kernel.cc new file mode 100644 index 0000000000000..833caa6688787 --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/fused_layernorm_kernel.cc @@ -0,0 +1,177 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
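Note on the new XPU fused_bias_act kernel above: it reduces to a two-step pattern, where an optional per-column bias is broadcast-added into x (shape [rows, cols] against [1, cols]) and then a single activation is picked by act_method, with "geglu" and anything unrecognised rejected. Below is a minimal CPU reference of that dispatch, for reading alongside the hunk; the real kernel uses baidu::xpu::api calls and also supports "swiglu", and fused_bias_act_reference is a name used only in this sketch.

#include <algorithm>
#include <cmath>
#include <stdexcept>
#include <string>
#include <vector>

// CPU sketch only: mirrors the bias-broadcast + activation dispatch of the
// XPU kernel above, not its baidu::xpu::api implementation.
void fused_bias_act_reference(std::vector<float>& x,           // [rows * cols], updated in place
                              const std::vector<float>* bias,  // [cols] or nullptr
                              const std::string& act_method,
                              int rows, int cols) {
  if (bias) {
    // broadcast_add over rows: x[r, c] += bias[c]
    for (int r = 0; r < rows; ++r)
      for (int c = 0; c < cols; ++c) x[r * cols + c] += (*bias)[c];
  }
  if (act_method == "relu") {
    for (auto& v : x) v = std::max(v, 0.0f);
  } else if (act_method == "gelu") {
    for (auto& v : x) v = 0.5f * v * (1.0f + std::erf(v / std::sqrt(2.0f)));
  } else {
    // The XPU kernel additionally handles "swiglu"; "geglu" and any other
    // value is rejected there as well.
    throw std::invalid_argument("Only SwiGLU, GeLU and ReLU are supported");
  }
}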
+ +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +namespace fusion { + +template +void FusedLayerNormKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& bias, + const paddle::optional& residual, + const paddle::optional& norm_weight, + const paddle::optional& norm_bias, + const float epsilon, + const float residual_alpha, + const int begin_norm_axis, + const float quant_scale, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + DenseTensor* out, + DenseTensor* residual_out, + DenseTensor* mean, + DenseTensor* variance) { + int r = xpu::SUCCESS; + auto xpu_ctx = static_cast(&dev_ctx); + using XPUType = typename XPUTypeTrait::Type; + auto x_shape = x.dims(); + int m = 1; + int n = 1; + for (int i = 0; i < begin_norm_axis; i++) { + m *= x_shape[i]; + } + for (int i = begin_norm_axis; i < x_shape.size(); i++) { + n *= x_shape[i]; + } + + dev_ctx.template Alloc(out); + dev_ctx.template Alloc(mean); + dev_ctx.template Alloc(variance); + + DenseTensor residual_alpha_tmp; + residual_alpha_tmp.Resize({1}); + + DenseTensor residual_alpha_ptr; + residual_alpha_ptr.Resize({1}); + + dev_ctx.template Alloc(&residual_alpha_tmp); + dev_ctx.template Alloc(&residual_alpha_ptr); + + r = baidu::xpu::api::constant(xpu_ctx->x_context(), + residual_alpha_tmp.data(), + 1, + residual_alpha); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); + + r = baidu::xpu::api::cast_v2( + xpu_ctx->x_context(), + residual_alpha_tmp.data(), + reinterpret_cast(residual_alpha_ptr.data()), + 1); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2"); + + if (residual) { + dev_ctx.template Alloc(residual_out); + r = baidu::xpu::api::broadcast_mul( + xpu_ctx->x_context(), + reinterpret_cast(residual.get().data()), + reinterpret_cast(residual_alpha_ptr.data()), + reinterpret_cast(const_cast(residual.get().data())), + {m, n}, + {1}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_mul"); + } + + if (!norm_weight && !norm_bias) { + if (bias) { + r = baidu::xpu::api::broadcast_add( + xpu_ctx->x_context(), + reinterpret_cast(out->data()), + reinterpret_cast(bias.get().data()), + reinterpret_cast(out->data()), + {m, n}, + {n}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); + } + if (residual) { + r = baidu::xpu::api::add( + xpu_ctx->x_context(), + reinterpret_cast(out->data()), + reinterpret_cast(residual.get().data()), + reinterpret_cast(out->data()), + m * n); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "add"); + } + + r = baidu::xpu::api::add(xpu_ctx->x_context(), + reinterpret_cast(out->data()), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + m * n); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "add"); + return; + } else { + if (bias) { + r = baidu::xpu::api::broadcast_add( + xpu_ctx->x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(bias.get().data()), + reinterpret_cast(const_cast((x.data()))), + {m, n}, + {n}); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); + } + if (residual) { + r = baidu::xpu::api::add_layer_norm_fusion( + xpu_ctx->x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(residual.get().data()), + reinterpret_cast(out->data()), + m, + n, + epsilon, + norm_weight.get().data(), + norm_bias.get().data(), + mean->data(), + variance->data(), + reinterpret_cast(residual_out->data())); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "add_layer_norm_fusion"); + } else { + r = baidu::xpu::api::layer_norm( + xpu_ctx->x_context(), + 
reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + m, + n, + epsilon, + norm_weight.get().data(), + norm_bias.get().data(), + mean->data(), + variance->data()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm"); + } + if (quant_scale > 0.0f) { + PD_THROW("NOT supported quant int8. "); + } else { + return; + } + } +} + +} // namespace fusion + +} // namespace phi + +PD_REGISTER_KERNEL(fused_bias_residual_layernorm, + XPU, + ALL_LAYOUT, + phi::fusion::FusedLayerNormKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_int8_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_int8_xpu_kernel.cc index 236e276cb937d..e252349ce186b 100755 --- a/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_int8_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_int8_xpu_kernel.cc @@ -465,9 +465,9 @@ void FusedMultiTransformerInt8XpuKernel( attn_layout); PADDLE_ENFORCE_XDNN_SUCCESS(r, "xft::fused_multi_transformer_gpt_int8"); #else - LOG(FATAL) - << "fused_multi_transformer_gpt_int8 is not supported since it's not " - "compiled with XPU_XFT"; + PADDLE_THROW( + phi::errors::Fatal("fused_multi_transformer_gpt_int8 is not supported " + "since it's not compiled with XPU_XFT")); #endif } diff --git a/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc index 8c151e0257e0e..7d26e056ed7f9 100644 --- a/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fused_multi_transformer_xpu_kernel.cc @@ -366,8 +366,9 @@ void FusedMultiTransformerXpuKernel( attn_layout); PADDLE_ENFORCE_XDNN_SUCCESS(r, "xft::fused_multi_transformer_gpt"); #else - LOG(FATAL) << "fused_multi_transformer_xpu is not supported since it's not " - "compiled with XPU_XFT"; + PADDLE_THROW( + phi::errors::Fatal("fused_multi_transformer_xpu is not supported since " + "it's not compiled with XPU_XFT")); #endif } diff --git a/paddle/phi/kernels/fusion/xpu/fused_rope_grad_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_rope_grad_kernel.cc index 1e988ca9ea03e..dba65efd0a179 100644 --- a/paddle/phi/kernels/fusion/xpu/fused_rope_grad_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fused_rope_grad_kernel.cc @@ -32,7 +32,7 @@ void FusedRopeGradKernel(const Context& dev_ctx, DenseTensor* dq, DenseTensor* dk, DenseTensor* dv) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; if (dout_q.numel() <= 0) { return; } @@ -48,8 +48,8 @@ void FusedRopeGradKernel(const Context& dev_ctx, xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); int64_t sin_cos_len = batch_size * seq_len * head_dim; - auto* sin_data = RAII_GUARD.alloc_l3_or_gm(sin_cos_len); - auto* cos_data = RAII_GUARD.alloc_l3_or_gm(sin_cos_len); + auto* sin_data = RAII_GUARD.alloc_l3_or_gm(sin_cos_len); + auto* cos_data = RAII_GUARD.alloc_l3_or_gm(sin_cos_len); if (sin.get_ptr() && cos.get_ptr()) { PADDLE_ENFORCE_EQ(sin.get_ptr()->dims(), @@ -61,9 +61,9 @@ void FusedRopeGradKernel(const Context& dev_ctx, cos.get_ptr()->dims())); } - XPUGetSinCosData( + XPUGetSinCosData( dev_ctx, sin, position_ids, sin_data, batch_size, seq_len, head_dim); - XPUGetSinCosData( + XPUGetSinCosData( dev_ctx, cos, position_ids, cos_data, batch_size, seq_len, head_dim); if (use_neox_rotary_style) { @@ -72,39 +72,58 @@ void FusedRopeGradKernel(const Context& dev_ctx, phi::errors::Unimplemented("XPU do not support rotary_embedding_grad " "with 
use_neox_rotary_style set.")); } else { - auto* dq_data = reinterpret_cast(dev_ctx.template Alloc(dq)); - XPUFusedRotaryHalf( - dev_ctx, - reinterpret_cast(dout_q.data()), - sin_data, - cos_data, - dq_data, - batch_size, - seq_len, - num_heads, - head_dim, - true); - - if (dout_k.get_ptr()) { - auto* dk_data = reinterpret_cast(dev_ctx.template Alloc(dk)); - XPUFusedRotaryHalf( - dev_ctx, - reinterpret_cast(dout_k->data()), + if (head_dim * sizeof(T) <= 1024 && head_dim % 64 == 0 && dout_k) { + auto* dq_data = reinterpret_cast(dev_ctx.template Alloc(dq)); + auto* dk_data = reinterpret_cast(dev_ctx.template Alloc(dk)); + int ret = xpu::rotary_no_freqs_qk_embedding_v2_grad( + dev_ctx.x_context(), + reinterpret_cast(dout_q.data()), + reinterpret_cast(dout_k->data()), sin_data, cos_data, + dq_data, dk_data, + {batch_size, seq_len, num_heads, head_dim}, + {batch_size, seq_len, 1, head_dim}, + {seq_len * num_heads * head_dim, num_heads * head_dim, head_dim, 1}, + {seq_len * head_dim, head_dim, head_dim, 1}); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rotary_no_freqs_qk_embedding_v2_grad"); + } else { + auto* dq_data = reinterpret_cast(dev_ctx.template Alloc(dq)); + XPUFusedRotaryHalf( + dev_ctx, + reinterpret_cast(dout_q.data()), + sin_data, + cos_data, + dq_data, batch_size, seq_len, num_heads, head_dim, true); + + if (dout_k.get_ptr()) { + auto* dk_data = + reinterpret_cast(dev_ctx.template Alloc(dk)); + XPUFusedRotaryHalf( + dev_ctx, + reinterpret_cast(dout_k->data()), + sin_data, + cos_data, + dk_data, + batch_size, + seq_len, + num_heads, + head_dim, + true); + } } if (dout_v.get_ptr()) { - auto* dv_data = reinterpret_cast(dev_ctx.template Alloc(dv)); - XPUFusedRotaryHalf( + auto* dv_data = reinterpret_cast(dev_ctx.template Alloc(dv)); + XPUFusedRotaryHalf( dev_ctx, - reinterpret_cast(dout_v->data()), + reinterpret_cast(dout_v->data()), sin_data, cos_data, dv_data, diff --git a/paddle/phi/kernels/fusion/xpu/fused_rope_kernel.cc b/paddle/phi/kernels/fusion/xpu/fused_rope_kernel.cc index c8980310fb0f9..38141a9bfaf6c 100644 --- a/paddle/phi/kernels/fusion/xpu/fused_rope_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/fused_rope_kernel.cc @@ -33,7 +33,7 @@ void FusedRopeKernel(const Context& dev_ctx, DenseTensor* out_q, DenseTensor* out_k, DenseTensor* out_v) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; if (q.numel() <= 0) { return; } @@ -54,8 +54,8 @@ void FusedRopeKernel(const Context& dev_ctx, xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); int64_t sin_cos_len = batch_size * seq_len * head_dim; - auto* sin_data = RAII_GUARD.alloc_l3_or_gm(sin_cos_len); - auto* cos_data = RAII_GUARD.alloc_l3_or_gm(sin_cos_len); + auto* sin_data = RAII_GUARD.alloc_l3_or_gm(sin_cos_len); + auto* cos_data = RAII_GUARD.alloc_l3_or_gm(sin_cos_len); if (sin.get_ptr() && cos.get_ptr()) { PADDLE_ENFORCE_EQ(sin.get_ptr()->dims(), @@ -67,9 +67,9 @@ void FusedRopeKernel(const Context& dev_ctx, cos.get_ptr()->dims())); } - XPUGetSinCosData( + XPUGetSinCosData( dev_ctx, sin, position_ids, sin_data, batch_size, seq_len, head_dim); - XPUGetSinCosData( + XPUGetSinCosData( dev_ctx, cos, position_ids, cos_data, batch_size, seq_len, head_dim); if (use_neox_rotary_style) { @@ -77,39 +77,60 @@ void FusedRopeKernel(const Context& dev_ctx, PADDLE_THROW(phi::errors::Unimplemented( "XPU do not support rotary_embedding with use_neox_rotary_style set.")); } else { - auto* outq_data = reinterpret_cast(dev_ctx.template Alloc(out_q)); - XPUFusedRotaryHalf( - dev_ctx, - reinterpret_cast(q.data()), - 
sin_data, - cos_data, - outq_data, - batch_size, - seq_len, - num_heads, - head_dim); - - if (k) { + if (head_dim * sizeof(T) <= 1024 && head_dim % 64 == 0 && k) { + auto* outq_data = + reinterpret_cast(dev_ctx.template Alloc(out_q)); auto* outk_data = - reinterpret_cast(dev_ctx.template Alloc(out_k)); - XPUFusedRotaryHalf( - dev_ctx, - reinterpret_cast(k->data()), + reinterpret_cast(dev_ctx.template Alloc(out_k)); + int ret = xpu::rotary_no_freqs_qk_embedding_v2( + dev_ctx.x_context(), + reinterpret_cast(q.data()), + reinterpret_cast(k->data()), sin_data, cos_data, + outq_data, outk_data, + {batch_size, seq_len, num_heads, head_dim}, + {batch_size, seq_len, 1, head_dim}, + {seq_len * num_heads * head_dim, num_heads * head_dim, head_dim, 1}, + {seq_len * head_dim, head_dim, head_dim, 1}); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "rotary_no_freqs_qk_embedding_v2"); + } else { + auto* outq_data = + reinterpret_cast(dev_ctx.template Alloc(out_q)); + XPUFusedRotaryHalf( + dev_ctx, + reinterpret_cast(q.data()), + sin_data, + cos_data, + outq_data, batch_size, seq_len, num_heads, head_dim); + + if (k) { + auto* outk_data = + reinterpret_cast(dev_ctx.template Alloc(out_k)); + XPUFusedRotaryHalf( + dev_ctx, + reinterpret_cast(k->data()), + sin_data, + cos_data, + outk_data, + batch_size, + seq_len, + num_heads, + head_dim); + } } if (v) { auto* outv_data = - reinterpret_cast(dev_ctx.template Alloc(out_v)); - XPUFusedRotaryHalf( + reinterpret_cast(dev_ctx.template Alloc(out_v)); + XPUFusedRotaryHalf( dev_ctx, - reinterpret_cast(v->data()), + reinterpret_cast(v->data()), sin_data, cos_data, outv_data, diff --git a/paddle/phi/kernels/fusion/xpu/fused_rope_utils.h b/paddle/phi/kernels/fusion/xpu/fused_rope_utils.h index 6432815b36489..393d6955d19a6 100644 --- a/paddle/phi/kernels/fusion/xpu/fused_rope_utils.h +++ b/paddle/phi/kernels/fusion/xpu/fused_rope_utils.h @@ -17,11 +17,11 @@ namespace phi { namespace fusion { -template +template void XPUGetSinCosData(const Context& dev_ctx, const paddle::optional& sin_cos, const paddle::optional& position_ids, - XPUT* sin_cos_data, + XPUType* sin_cos_data, int64_t batch_size, int64_t seq_len, int64_t head_dim) { @@ -68,22 +68,22 @@ void XPUGetSinCosData(const Context& dev_ctx, phi::errors::InvalidArgument( "The batch_size and seq_len of position_ids must be the same as " "those of q.")); - using XPUTFp16 = typename XPUTypeTrait::Type; - using XPUTBf16 = typename XPUTypeTrait::Type; - if (std::is_same::value) { - int ret = xpu::gather( + using XPUTypeFp16 = typename XPUTypeTrait::Type; + using XPUTypeBf16 = typename XPUTypeTrait::Type; + if (std::is_same::value) { + int ret = xpu::gather( dev_ctx.x_context(), - reinterpret_cast(sin_cos->data()), + reinterpret_cast(sin_cos->data()), position_ids->data(), - reinterpret_cast(sin_cos_data), + reinterpret_cast(sin_cos_data), {seq_len, head_dim}, batch_size * seq_len, 0); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather"); } else { - int ret = xpu::gather( + int ret = xpu::gather( dev_ctx.x_context(), - reinterpret_cast(sin_cos->data()), + reinterpret_cast(sin_cos->data()), position_ids->data(), sin_cos_data, {seq_len, head_dim}, @@ -92,37 +92,37 @@ void XPUGetSinCosData(const Context& dev_ctx, PADDLE_ENFORCE_XDNN_SUCCESS(ret, "gather"); } } else { - int ret = - xpu::broadcast(dev_ctx.x_context(), - reinterpret_cast(sin_cos->data()), - sin_cos_data, - {1, seq_len, head_dim}, - {batch_size, seq_len, head_dim}); + int ret = xpu::broadcast( + dev_ctx.x_context(), + reinterpret_cast(sin_cos->data()), + sin_cos_data, + {1, seq_len, 
head_dim}, + {batch_size, seq_len, head_dim}); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast"); } } else { int ret = xpu::constant(dev_ctx.x_context(), sin_cos_data, batch_size * seq_len * head_dim, - static_cast(0.0f)); + static_cast(0.0f)); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "constant"); } } -template +template void XPUFusedRotaryHalf(const Context& dev_ctx, - const XPUT* in_data, - const XPUT* sin_data, - const XPUT* cos_data, - XPUT* out_data, + const XPUType* in_data, + const XPUType* sin_data, + const XPUType* cos_data, + XPUType* out_data, int64_t batch_size, int64_t seq_len, int64_t num_heads, int64_t head_dim, bool is_bwd = false) { - auto func = &xpu::rotary_no_freqs_embedding_v2; + auto func = &xpu::rotary_no_freqs_embedding_v2; if (is_bwd) { - func = &xpu::rotary_no_freqs_embedding_v2_grad; + func = &xpu::rotary_no_freqs_embedding_v2_grad; } int ret = diff --git a/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc index 1f76fc3ef02d8..8b65964671b0b 100644 --- a/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc @@ -6,7 +6,7 @@ // // http://www.apache.org/licenses/LICENSE-2.0 // -// Unless required by applicable law or agreed to in writing, sofint16_tare +// Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and @@ -47,6 +47,7 @@ void MultiEncoderXPUKernel( const std::vector& ln_scale, const std::vector& ln_bias, const std::vector& smooth_scale_weight, + const std::vector& roformer_embedding, const paddle::optional& mask, const paddle::optional& seq_lod, const paddle::optional& max_seq_len, @@ -60,6 +61,7 @@ void MultiEncoderXPUKernel( int relative_type, int slice_idx, bool is_per_channel, + int max_pos_len, const std::vector& softmax_max_value, const std::vector& quant_types, DenseTensor* out, @@ -150,7 +152,6 @@ void MultiEncoderXPUKernel( } } - std::vector test_data(6, 0); for (size_t i = 0; i < fc_input_max.size(); i++) { fc_input_max_data.push_back(fc_input_max[i]->data()); } @@ -199,6 +200,16 @@ void MultiEncoderXPUKernel( qkv_attn_param.quant_type_.assign(set_quant_types.begin(), set_quant_types.end()); qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; + if (!roformer_embedding.empty()) { + std::vector roformer_embedding_data; + for (size_t i = 0; i < roformer_embedding.size(); i++) { + roformer_embedding_data.push_back(roformer_embedding[i]->data()); + } + qkv_attn_param.relative_type = relative_type; + qkv_attn_param.max_pos_len = max_pos_len; + qkv_attn_param.relative_pos.assign(roformer_embedding_data.begin(), + roformer_embedding_data.end()); + } if (!enable_int8) { if (local_quant) { TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, XPUTypeFP16, float) @@ -242,6 +253,16 @@ void MultiEncoderXPUKernel( qkv_attn_param.quant_type_.assign(set_quant_types.begin(), set_quant_types.end()); qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; + if (!roformer_embedding.empty()) { + std::vector roformer_embedding_data; + for (size_t i = 0; i < roformer_embedding.size(); i++) { + roformer_embedding_data.push_back(roformer_embedding[i]->data()); + } + qkv_attn_param.relative_type = relative_type; + qkv_attn_param.max_pos_len = max_pos_len; + qkv_attn_param.relative_pos.assign(roformer_embedding_data.begin(), 
+ roformer_embedding_data.end()); + } if (!enable_int8) { if (local_quant) { TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, XPUTypeFP16, float) @@ -288,6 +309,16 @@ void MultiEncoderXPUKernel( qkv_attn_param.quant_type_.assign(set_quant_types.begin(), set_quant_types.end()); qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; + if (!roformer_embedding.empty()) { + std::vector roformer_embedding_data; + for (size_t i = 0; i < roformer_embedding.size(); i++) { + roformer_embedding_data.push_back(roformer_embedding[i]->data()); + } + qkv_attn_param.relative_type = relative_type; + qkv_attn_param.max_pos_len = max_pos_len; + qkv_attn_param.relative_pos.assign(roformer_embedding_data.begin(), + roformer_embedding_data.end()); + } if (!enable_int8) { if (local_quant) { TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, XPUTypeFP16, float) @@ -319,6 +350,6 @@ PD_REGISTER_KERNEL(multi_encoder_xpu, phi::fusion::MultiEncoderXPUKernel, float, phi::dtype::float16) { - kernel->InputAt(9).SetBackend(phi::Backend::CPU); kernel->InputAt(10).SetBackend(phi::Backend::CPU); + kernel->InputAt(11).SetBackend(phi::Backend::CPU); } diff --git a/paddle/phi/kernels/fusion/xpu/qkv_attention_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/qkv_attention_xpu_kernel.cc index b08921e750a80..5c8562d6c3969 100644 --- a/paddle/phi/kernels/fusion/xpu/qkv_attention_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/qkv_attention_xpu_kernel.cc @@ -6,7 +6,7 @@ // // http://www.apache.org/licenses/LICENSE-2.0 // -// Unless required by applicable law or agreed to in writing, sofint16_tare +// Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and diff --git a/paddle/phi/kernels/fusion/xpu/roformer_relative_embedding_kernel.cc b/paddle/phi/kernels/fusion/xpu/roformer_relative_embedding_kernel.cc new file mode 100644 index 0000000000000..ae42b0eabc614 --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/roformer_relative_embedding_kernel.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
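The roformer wiring added to multi_encoder_xpu_kernel.cc above is repeated verbatim in three branches: collect the embedding data pointers, then set relative_type, max_pos_len and relative_pos on qkv_attn_param. If a fourth copy is ever needed, a small helper along these lines could factor it out; the member names follow the hunks above, while the helper name and the float element type are assumptions of this sketch.

#include <vector>

// Hypothetical helper, sketched from the three identical blocks above.
template <typename AttnParam, typename TensorPtrVec>
void SetRoformerRelativePos(AttnParam* qkv_attn_param,
                            const TensorPtrVec& roformer_embedding,
                            int relative_type,
                            int max_pos_len) {
  if (roformer_embedding.empty()) return;
  std::vector<const float*> roformer_embedding_data;
  roformer_embedding_data.reserve(roformer_embedding.size());
  for (const auto* t : roformer_embedding) {
    roformer_embedding_data.push_back(t->template data<float>());
  }
  qkv_attn_param->relative_type = relative_type;
  qkv_attn_param->max_pos_len = max_pos_len;
  qkv_attn_param->relative_pos.assign(roformer_embedding_data.begin(),
                                      roformer_embedding_data.end());
}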
+ +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +void RoformerRelativePosXPUKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& sin_emb, + const DenseTensor& cos_emb, + int max_pos_len, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + + auto* x_data = reinterpret_cast(x.data()); + auto* sin_emb_data = sin_emb.data(); + auto* cos_emb_data = cos_emb.data(); + auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + auto x_dims = x.dims(); + int batch = x_dims[0]; + int head_num = x_dims[1]; + int seqlen = x_dims[2]; + int head_dim = x_dims[3]; + if (seqlen > max_pos_len) { + PADDLE_THROW(phi::errors::InvalidArgument( + "The input sequence length should be less than or equal to the " + "maximum position length. But received seqlen: %d, max_pos_len: %d", + seqlen, + max_pos_len)); + } + std::vector lod; + lod.resize(batch + 1); + for (int i = 0; i < batch + 1; i++) { + lod[i] = i * seqlen; + } + int r = + xpu::rope(ctx.x_context(), + x_data, + out_data, + cos_emb_data, + sin_emb_data, + batch, + head_num, + head_dim, + head_num * head_dim, + lod, + max_pos_len, + false, // no vsl + true); // transpose to [n, seql, head_num, head_dim] + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "roformer_relative_embedding_xpu"); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(roformer_relative_embedding_xpu, + XPU, + ALL_LAYOUT, + phi::fusion::RoformerRelativePosXPUKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/fusion/xpu/variable_length_memory_efficient_attention_kernel.cc b/paddle/phi/kernels/fusion/xpu/variable_length_memory_efficient_attention_kernel.cc new file mode 100644 index 0000000000000..8f6a25ddc5c86 --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/variable_length_memory_efficient_attention_kernel.cc @@ -0,0 +1,122 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +void MultiHeadAttentionVariableForwardKernel( + const Context& ctx, + const DenseTensor& query, + const DenseTensor& key, + const DenseTensor& value, + const DenseTensor& seq_lens, + const DenseTensor& kv_seq_lens, + const paddle::optional& mask, + const float scale, + const bool causal, + const int pre_cache_length, + DenseTensor* output) { + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + + using XPUType = typename XPUTypeTrait::Type; + + int64_t num_batches = query.dims()[0]; + int64_t num_heads = query.dims()[1]; + int64_t kv_num_heads = key.dims()[1]; + int64_t query_seq_len = query.dims()[2]; + int64_t head_size = query.dims()[3]; + std::vector mask_shape = {}; + if (mask) { + // [B, 1, S, D] + auto mask_tensor = mask.get(); + mask_shape = common::vectorize(mask_tensor.dims()); + } + + xpu::QKVAttnParam qkv_attn_param( + num_batches, /* batch */ + query_seq_len, /* max_seqlen */ + num_heads, /* head_num */ + head_size, /* head_dim */ + mask_shape, /* mask_shape */ + xpu::Activation_t::RELU, /* act */ + -1, /* last_slice_seq */ + false, /* do_fc_qkv_fusion */ + -1, /* hidden_dim */ + false, /* is_pre_norm */ + false, /* is_perchannel */ + 2, /* qkv_shape */ + AttnMacMaxPtrType_t::ATTN_WHOLE_BATCH, /* max_ptr_type */ + -1, /* ldz */ + scale /* alpha */ + ); + qkv_attn_param.key_value_head_num = kv_num_heads; + + const XPUType* mask_ptr = + mask ? reinterpret_cast(mask.get().data()) : nullptr; + auto* out_data = reinterpret_cast(ctx.template Alloc(output)); + XPUType* qk_buf = RAII_GUARD.alloc_l3_or_gm( + num_batches * num_heads * query_seq_len * query_seq_len); + float* maxptr_buf = RAII_GUARD.alloc_l3_or_gm(32); + int r = xpu::qk_attention( + ctx.x_context(), /* ctx */ + reinterpret_cast(query.data()), /* q */ + reinterpret_cast(key.data()), /* k */ + qk_buf, /* qk */ + nullptr, /* max q */ + nullptr, /* max k */ + maxptr_buf, /* max qk */ + qkv_attn_param, /* param */ + mask_ptr /* mask */ + ); + PADDLE_ENFORCE_EQ( + r, 0, phi::errors::InvalidArgument("xpu::qk_attention run failed")); + XPUType* out_tmp_buf = RAII_GUARD.alloc_l3_or_gm( + num_batches * query_seq_len * num_heads * head_size); + r = xpu::qk_v_attention( + ctx.x_context(), /* ctx */ + qk_buf, /* qk */ + reinterpret_cast(value.data()), /* v */ + out_tmp_buf, /* output */ + maxptr_buf, /* max qk */ + nullptr, /* max v */ + nullptr, /* max qkv */ + qkv_attn_param /* mask */ + ); + PADDLE_ENFORCE_EQ( + r, 0, phi::errors::InvalidArgument("xpu::qk_v_attention run failed")); + r = xpu::transpose( + ctx.x_context(), + out_tmp_buf, + out_data, + {num_batches, query_seq_len, num_heads, head_size}, + {0, 2, 1, 3}); + PADDLE_ENFORCE_EQ( + r, 0, phi::errors::InvalidArgument("xpu::transpose run failed")); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(variable_length_memory_efficient_attention, + XPU, + ALL_LAYOUT, + phi::fusion::MultiHeadAttentionVariableForwardKernel, + float, + phi::dtype::float16) { + kernel->InputAt(3).SetDataType(phi::DataType::INT32); +} diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 7af857345cdd6..594eefe5b8de1 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -510,10 +510,10 @@ PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(sigmoid_triple_grad, PD_REGISTER_ACTIVATION_GRAD_KERNEL(hardsigmoid_grad, 
HardSigmoidGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(logsigmoid_grad, LogSigmoidGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(log2_grad, Log2GradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(log10_grad, Log10GradKernel) -PD_REGISTER_ACTIVATION_GRAD_KERNEL(log1p_grad, Log1pGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log_grad, LogGradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log2_grad, Log2GradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log10_grad, Log10GradKernel) +PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(log1p_grad, Log1pGradKernel) PD_REGISTER_KERNEL(log_double_grad, GPU, ALL_LAYOUT, @@ -521,7 +521,9 @@ PD_REGISTER_KERNEL(log_double_grad, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_ACTIVATION_GRAD_KERNEL_WITH_COMPLEX(hardswish_grad, HardSwishGradKernel) PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel) diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index e8dadf31fd945..1bf3d92d80620 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -315,7 +315,9 @@ PD_REGISTER_KERNEL(log, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(log2, GPU, ALL_LAYOUT, @@ -325,7 +327,9 @@ PD_REGISTER_KERNEL(log2, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(log10, GPU, ALL_LAYOUT, @@ -335,7 +339,9 @@ PD_REGISTER_KERNEL(log10, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(log1p, GPU, ALL_LAYOUT, @@ -345,7 +351,9 @@ PD_REGISTER_KERNEL(log1p, int, int64_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(pow, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index 5292d7d29c07b..56be43fecb0d1 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -46,12 +46,12 @@ __global__ void AdamKernelREG(MT beta1, T* param_out, const MT* master_param, MT* master_param_out, - int ndim) { + int64_t ndim) { MT lr = *lr_; MT beta1_pow = beta1_pow_; MT beta2_pow = beta2_pow_; - int id = blockIdx.x * blockDim.x + threadIdx.x; + int64_t id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { MT p = master_param ? master_param[id] : static_cast(param[id]); @@ -89,12 +89,12 @@ __global__ void AdamKernelMEM(MT beta1, T* param_out, const MT* master_param, MT* master_param_out, - int ndim) { + int64_t ndim) { MT lr = *lr_; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; - int id = blockIdx.x * blockDim.x + threadIdx.x; + int64_t id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { MT p = master_param ? 
master_param[id] : static_cast(param[id]); diff --git a/paddle/phi/kernels/gpu/adamw_kernel.cu b/paddle/phi/kernels/gpu/adamw_kernel.cu index d40fdf392b1a2..97d0563d51ff8 100644 --- a/paddle/phi/kernels/gpu/adamw_kernel.cu +++ b/paddle/phi/kernels/gpu/adamw_kernel.cu @@ -49,12 +49,12 @@ __global__ void AdamWKernelREG(MT beta1, T* param_out, const MT* master_param, MT* master_param_out, - int ndim) { + int64_t ndim) { MT lr = *lr_ * lr_ratio; MT beta1_pow = beta1_pow_; MT beta2_pow = beta2_pow_; - int id = blockIdx.x * blockDim.x + threadIdx.x; + int64_t id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { MT p = master_param ? master_param[id] : static_cast(param[id]); @@ -98,12 +98,12 @@ __global__ void AdamWKernelMEM(MT beta1, T* param_out, const MT* master_param, MT* master_param_out, - int ndim) { + int64_t ndim) { MT lr = *lr_ * lr_ratio; MT beta1_pow = *beta1_pow_; MT beta2_pow = *beta2_pow_; - int id = blockIdx.x * blockDim.x + threadIdx.x; + int64_t id = blockIdx.x * blockDim.x + threadIdx.x; for (; id < ndim; id += gridDim.x * blockDim.x) { MT p = master_param ? master_param[id] : static_cast(param[id]); diff --git a/paddle/phi/kernels/gpu/all_gather_kernel.cu b/paddle/phi/kernels/gpu/all_gather_kernel.cu index ca6bfd7b4517b..c8ec6c63c5a98 100644 --- a/paddle/phi/kernels/gpu/all_gather_kernel.cu +++ b/paddle/phi/kernels/gpu/all_gather_kernel.cu @@ -73,7 +73,9 @@ PD_REGISTER_KERNEL(all_gather, int64_t, bool, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} #else PD_REGISTER_KERNEL(all_gather, GPU, @@ -87,5 +89,7 @@ PD_REGISTER_KERNEL(all_gather, int16_t, int64_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} #endif diff --git a/paddle/phi/kernels/gpu/c_embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/c_embedding_grad_kernel.cu index cb766597c3142..9a34b9dd5bc26 100644 --- a/paddle/phi/kernels/gpu/c_embedding_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/c_embedding_grad_kernel.cu @@ -148,7 +148,9 @@ PD_REGISTER_KERNEL(c_embedding_grad, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} #else PD_REGISTER_KERNEL(c_embedding_grad, GPU, @@ -156,5 +158,7 @@ PD_REGISTER_KERNEL(c_embedding_grad, phi::CEmbeddingGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} #endif diff --git a/paddle/phi/kernels/gpu/c_embedding_kernel.cu b/paddle/phi/kernels/gpu/c_embedding_kernel.cu index 869d226445d85..50aebe82417d4 100644 --- a/paddle/phi/kernels/gpu/c_embedding_kernel.cu +++ b/paddle/phi/kernels/gpu/c_embedding_kernel.cu @@ -121,7 +121,9 @@ PD_REGISTER_KERNEL(c_embedding, float, double, phi::dtype::bfloat16, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} #else PD_REGISTER_KERNEL(c_embedding, GPU, @@ -129,5 +131,7 @@ PD_REGISTER_KERNEL(c_embedding, phi::CEmbeddingKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::complex, + phi::dtype::complex) {} #endif diff --git a/paddle/phi/kernels/gpu/cudnn_lstm_cache.h b/paddle/phi/kernels/gpu/cudnn_lstm_cache.h index 197049452f97f..c5b3873ce5504 100644 --- a/paddle/phi/kernels/gpu/cudnn_lstm_cache.h +++ b/paddle/phi/kernels/gpu/cudnn_lstm_cache.h @@ -67,7 +67,30 @@ class ScopedRNNBase { y_descs_.emplace_back(y_desc_.descriptor(dims_y, strides_y)); } -#if 
CUDNN_VERSION >= 7201 +#if CUDNN_VERSION >= 90000 + auto seqlen_is_empty = sequence_length.empty(); + if (seqlen_is_empty) { + std::vector seqlen_array(batch_size_); + for (int i = 0; i < batch_size_; ++i) { + seqlen_array[i] = seq_length_; + } + x_seq_desc_.descriptor( + seq_length_, batch_size_, input_size_, true, seqlen_array); + y_seq_desc_.descriptor(seq_length_, + batch_size_, + hidden_size_ * numDirections, + true, + seqlen_array); + } else { + x_seq_desc_.descriptor( + seq_length_, batch_size_, input_size_, true, sequence_length); + y_seq_desc_.descriptor(seq_length_, + batch_size_, + hidden_size_ * numDirections, + true, + sequence_length); + } +#elif CUDNN_VERSION >= 7201 if (!sequence_length.empty()) { x_seq_desc_.descriptor( seq_length_, batch_size_, input_size_, true, sequence_length); @@ -107,6 +130,25 @@ class ScopedRNNBase { state_size); // ------------------- cudnn rnn descriptors --------------------- +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNDescriptor_v8( + rnn_desc_.desc(), + CUDNN_RNN_ALGO_STANDARD, + CUDNN_LSTM, + CUDNN_RNN_DOUBLE_BIAS, + is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, + CUDNN_LINEAR_INPUT, + cudnn_type, + cudnn_type, + CUDNN_DEFAULT_MATH, + input_size_, + hidden_size_, + hidden_size_, + num_layers_, + dropout_desc_.desc(), + seqlen_is_empty ? CUDNN_RNN_PADDED_IO_DISABLED + : CUDNN_RNN_PADDED_IO_ENABLED)); +#else PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNDescriptor_v6( handle, rnn_desc_.desc(), @@ -118,8 +160,9 @@ class ScopedRNNBase { CUDNN_LSTM, CUDNN_RNN_ALGO_STANDARD, cudnn_type)); +#endif -#if CUDNN_VERSION >= 7201 +#if CUDNN_VERSION < 90000 && CUDNN_VERSION >= 7201 if (!sequence_length.empty()) { PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNPaddingMode( rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED)); @@ -127,9 +170,14 @@ class ScopedRNNBase { #endif // ------------------- cudnn weights_size --------------------- - size_t weights_size_; +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnGetRNNWeightSpaceSize( + handle, rnn_desc_.desc(), &weights_size_)); +#else PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnGetRNNParamsSize( handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type)); +#endif + PADDLE_ENFORCE_EQ( weights_size_, sizeof(T) * weight_numel_, @@ -142,6 +190,15 @@ class ScopedRNNBase { std::vector dim_w = {dim_tmp, 1, 1}; weight_desc_.descriptor(layout, dim_w); // ------------------- cudnn workspace, reserve size --------------------- +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnGetRNNTempSpaceSizes(handle, + rnn_desc_.desc(), + CUDNN_FWD_MODE_TRAINING, + x_seq_desc_.desc(), + workspace_size, + reserve_size)); +#else PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cudnnGetRNNWorkspaceSize(handle, rnn_desc_.desc(), @@ -150,6 +207,7 @@ class ScopedRNNBase { workspace_size)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnGetRNNTrainingReserveSize( handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), reserve_size)); +#endif } cudnnTensorDescriptor_t* x_descs() { return x_descs_.data(); } cudnnTensorDescriptor_t* y_descs() { return y_descs_.data(); } @@ -164,6 +222,7 @@ class ScopedRNNBase { cudnnRNNDescriptor_t rnn_desc() { return rnn_desc_.desc(); } cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); } cudnnFilterDescriptor_t weight_desc() { return weight_desc_.desc(); } + size_t weights_size() { return weights_size_; } private: int seq_length_; @@ -176,6 +235,7 @@ class ScopedRNNBase { int 
weight_numel_; bool initialized_; bool is_bidirec_; + size_t weights_size_; std::vector x_descs_; std::vector y_descs_; diff --git a/paddle/phi/kernels/gpu/cudnn_lstm_grad_kernel.cu b/paddle/phi/kernels/gpu/cudnn_lstm_grad_kernel.cu index 661a1dd90e7e9..5d3998849d118 100644 --- a/paddle/phi/kernels/gpu/cudnn_lstm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cudnn_lstm_grad_kernel.cu @@ -145,6 +145,50 @@ void CudnnLSTMGradKernel( ctx.template Alloc(&workspace_data_); const uint8_t *reserve_data = reserve.data(); +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardData_v8( + handle, + rnn.rnn_desc(), + nullptr, + rnn.y_seq_desc(), + out_data, + out_grad_data, + rnn.x_seq_desc(), + in_grad_data, + rnn.init_h_desc(), + init_h_data, + last_h_grad_data, + init_h_grad_data, + rnn.init_c_desc(), + init_c_data, + last_c_grad_data, + init_c_grad_data, + rnn.weights_size(), + weight_data, + workspace_size, + workspace_data_.data(), + reserve_size, + const_cast(reserve_data))); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardWeights_v8( + handle, + rnn.rnn_desc(), + CUDNN_WGRAD_MODE_ADD, + nullptr, + rnn.x_seq_desc(), + x.data(), + rnn.init_h_desc(), + init_h.data(), + rnn.y_seq_desc(), + out.data(), + rnn.weights_size(), + weight_grad_data, + workspace_size, + workspace_data_.data(), + reserve_size, + const_cast(reserve_data))); +#else + if (!has_seq_length) { // This interface is used when the input/output is unpadded. #ifdef PADDLE_WITH_HIP @@ -298,6 +342,8 @@ void CudnnLSTMGradKernel( "of cudnn is larger than 7.2.1")); #endif } + +#endif // end CUDNN_VERSION >= 90000 } } // namespace phi diff --git a/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu b/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu index f3a03727e0bc4..73d11244e8f06 100644 --- a/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu +++ b/paddle/phi/kernels/gpu/cudnn_lstm_kernel.cu @@ -40,6 +40,31 @@ void LSTMInferece(const bool &has_seq_length, T *last_c_data, phi::DenseTensor *workspace_data, const size_t &workspace_size) { +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForward(handle, + rnn->rnn_desc(), + CUDNN_FWD_MODE_INFERENCE, + nullptr, + rnn->x_seq_desc(), + x_data, + rnn->y_seq_desc(), + out_data, + rnn->init_h_desc(), + init_h_data, + last_h_data, + rnn->init_c_desc(), + init_c_data, + last_c_data, + rnn->weights_size(), + w_data, + workspace_size, + workspace_data->data(), + 0, + nullptr)); + +#else + if (!has_seq_length) { // for inference // This interface is used when the input/output is unpadded. @@ -125,6 +150,8 @@ void LSTMInferece(const bool &has_seq_length, "the version of cudnn is larger than 7.2.1")); #endif } + +#endif // end CUDNN_VERSION >= 90000 } template @@ -265,6 +292,30 @@ void CudnnLSTMKernel( &workspace_data_, workspace_size); } else { +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForward(handle, + rnn.rnn_desc(), + CUDNN_FWD_MODE_TRAINING, + nullptr, + rnn.x_seq_desc(), + x_data, + rnn.y_seq_desc(), + out_data, + rnn.init_h_desc(), + init_h_data, + last_h_data, + rnn.init_c_desc(), + init_c_data, + last_c_data, + rnn.weights_size(), + w_data, + workspace_size, + workspace_data_.data(), + reserve_size, + reserve_data)); +#else + if (!has_seq_length) { // for train // This interface is used when the input/output is unpadded. 
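One detail of the cuDNN >= 9.0 path introduced above: the RNN data descriptors always need per-sample sequence lengths, so when the caller supplies none, ScopedRNNBase fabricates a constant-length array and sets CUDNN_RNN_PADDED_IO_DISABLED in cudnnSetRNNDescriptor_v8, while explicit lengths keep padded IO enabled. A self-contained sketch of that fallback, with an illustrative function name:

#include <vector>

// Returns the per-sample lengths handed to the cuDNN 9 sequence descriptors:
// the caller's lengths if provided, otherwise "every sample runs for seq_length".
std::vector<int> SeqLenArrayForCudnn9(const std::vector<int>& sequence_length,
                                      int batch_size,
                                      int seq_length) {
  if (!sequence_length.empty()) {
    return sequence_length;  // padded-IO case: genuine variable lengths
  }
  return std::vector<int>(batch_size, seq_length);  // uniform full-length case
}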
@@ -355,6 +406,7 @@ void CudnnLSTMKernel( "the version of cudnn is larger than 7.2.1")); #endif } +#endif // end CUDNN_VERSION >= 90000 } } diff --git a/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu index 08b8b89afe4b3..f7953fcc3194f 100644 --- a/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu +++ b/paddle/phi/kernels/gpu/cum_maxmin_kernel.cu @@ -265,9 +265,9 @@ void ScanWithIndicesKernel(const Context& dev_ctx, int num_rows = x.numel() / row_size; dim3 threads(16, 32); - dim3 grid( - std::min(dev_ctx.GetCUDAMaxGridDimSize()[0], - static_cast(std::ceil(static_cast(num_rows) / + dim3 grid(std::min( + dev_ctx.GetCUDAMaxGridDimSize()[0], + static_cast(std::ceil(static_cast(num_rows) / static_cast(threads.y))))); KernelScanInnerWithIndices diff --git a/paddle/phi/kernels/gpu/data_kernel.cu b/paddle/phi/kernels/gpu/data_kernel.cu index e4bd9c58b75dd..e1634fce75274 100644 --- a/paddle/phi/kernels/gpu/data_kernel.cu +++ b/paddle/phi/kernels/gpu/data_kernel.cu @@ -35,6 +35,23 @@ PD_REGISTER_KERNEL(shadow_feed, phi::complex64, phi::complex128) {} +PD_REGISTER_KERNEL(shadow_feed_tensors, + GPU, + ALL_LAYOUT, + phi::ShadowFeedTensorsKernel, + bool, + uint8_t, + float, + int8_t, + int16_t, + int32_t, + int64_t, + double, + phi::float16, + phi::bfloat16, + phi::complex64, + phi::complex128) {} + PD_REGISTER_KERNEL(print_kernel, GPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/gpu/diag_grad_kernel.cu b/paddle/phi/kernels/gpu/diag_grad_kernel.cu index 71d451ba4f380..a4e0861f180ab 100644 --- a/paddle/phi/kernels/gpu/diag_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/diag_grad_kernel.cu @@ -136,4 +136,6 @@ PD_REGISTER_KERNEL(diag_grad, int, int64_t, float, - double) {} + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/diag_kernel.cu b/paddle/phi/kernels/gpu/diag_kernel.cu index 7548c822fa753..bc5c8a4017491 100644 --- a/paddle/phi/kernels/gpu/diag_kernel.cu +++ b/paddle/phi/kernels/gpu/diag_kernel.cu @@ -139,4 +139,6 @@ PD_REGISTER_KERNEL(diag, int, int64_t, float, - double) {} + double, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu index 7d95c6c050bbd..1f292d9854ed3 100644 --- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu @@ -266,7 +266,9 @@ PD_REGISTER_KERNEL(embedding_grad, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} PD_REGISTER_KERNEL(embedding_sparse_grad, GPU, @@ -275,4 +277,6 @@ PD_REGISTER_KERNEL(embedding_sparse_grad, float, double, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu index fdf453522e10d..328eb2484dee6 100644 --- a/paddle/phi/kernels/gpu/embedding_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -136,4 +136,6 @@ PD_REGISTER_KERNEL(embedding, double, int8_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/eye_kernel.cu b/paddle/phi/kernels/gpu/eye_kernel.cu index 04735aaa228a6..faf36495b28a7 100644 --- a/paddle/phi/kernels/gpu/eye_kernel.cu +++ b/paddle/phi/kernels/gpu/eye_kernel.cu @@ -26,4 +26,6 @@ PD_REGISTER_KERNEL(eye, int64_t, int, 
phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu index 4774bebf5620b..4f93288edaf14 100644 --- a/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu @@ -119,8 +119,10 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, dropout, scale, causal, + 0, // attn_mask_start_row q.dtype(), attn_mask, + nullptr, // attn_mask_start_row_indices seed_offset.data()); VLOG(10) << "FlashAttn bwd seed: " << params.seed @@ -174,22 +176,24 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx, RaiseNotSupportedError(); #endif } - template -void FlashAttnGradKernel(const Context& ctx, - const DenseTensor& q, - const DenseTensor& k, - const DenseTensor& v, - const DenseTensor& out, - const DenseTensor& softmax_lse, - const DenseTensor& seed_offset, - const paddle::optional& attn_mask, - const DenseTensor& dout, - float dropout, - bool causal, - DenseTensor* dq, - DenseTensor* dk, - DenseTensor* dv) { +void FlashAttnGradBaseKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const paddle::optional& attn_mask_start_row_indices, + const DenseTensor& dout, + float dropout, + bool causal, + int attn_mask_start_row, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv) { #ifdef PADDLE_WITH_FLASHATTN // q, k, v [batch_size, seq_len, num_heads, head_dim] const auto& dims = q.dims(); @@ -259,8 +263,10 @@ void FlashAttnGradKernel(const Context& ctx, dropout, softmax_scale, causal, + attn_mask_start_row, q.dtype(), attn_mask, + attn_mask_start_row_indices, seed_offset.data()); VLOG(10) << "[FlashAttn Forward] q.shape=[" << q.dims() << "], k.shape=[" @@ -308,7 +314,14 @@ void FlashAttnGradKernel(const Context& ctx, params.seed, params.offset, params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, - params.attn_mask_tensor ? params.mask_dims.data() : nullptr); + params.attn_mask_tensor ? params.mask_dims.data() : nullptr, + params.attn_mask_start_row_indices_tensor + ? params.attn_mask_start_row_indices_tensor->data() + : nullptr, + params.attn_mask_start_row_indices_tensor + ? 
params.attn_mask_start_row_indices_dims.data() + : nullptr, + params.attn_mask_start_row); CheckFlashAttnStatus(succ); if (!is_mha) { if (dk) { @@ -323,6 +336,73 @@ void FlashAttnGradKernel(const Context& ctx, #endif } +template +void FlashAttnGradKernel(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const paddle::optional& attn_mask, + const DenseTensor& dout, + float dropout, + bool causal, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv) { + FlashAttnGradBaseKernel(ctx, + q, + k, + v, + out, + softmax_lse, + seed_offset, + attn_mask, + paddle::none, + dout, + dropout, + causal, + 0, + dq, + dk, + dv); +} + +template +void FlashAttnWithSparseGradKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& attn_mask_start_row_indices, + const DenseTensor& out, + const DenseTensor& softmax_lse, + const DenseTensor& seed_offset, + const DenseTensor& dout, + float dropout, + bool causal, + int attn_mask_start_row, + DenseTensor* dq, + DenseTensor* dk, + DenseTensor* dv) { + FlashAttnGradBaseKernel(ctx, + q, + k, + v, + out, + softmax_lse, + seed_offset, + paddle::none, + attn_mask_start_row_indices, + dout, + dropout, + causal, + attn_mask_start_row, + dq, + dk, + dv); +} } // namespace phi PD_REGISTER_KERNEL(flash_attn_unpadded_grad, @@ -342,3 +422,12 @@ PD_REGISTER_KERNEL(flash_attn_grad, phi::dtype::bfloat16) { kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset } + +PD_REGISTER_KERNEL(flash_attn_with_sparse_mask_grad, + GPU, + ALL_LAYOUT, + phi::FlashAttnWithSparseGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset +} diff --git a/paddle/phi/kernels/gpu/flash_attn_kernel.cu b/paddle/phi/kernels/gpu/flash_attn_kernel.cu index 9f1ffd6bc4c69..7eb2d342feb79 100644 --- a/paddle/phi/kernels/gpu/flash_attn_kernel.cu +++ b/paddle/phi/kernels/gpu/flash_attn_kernel.cu @@ -65,25 +65,28 @@ void FlashAttnUnpaddedKernel( // TODO(umiswing): add shape check - FlashAttnFwdParamsV2 params = FlashAttnFwdParamsV2(ctx, - batch_size, - max_seqlen_q, - max_seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - scale, - causal, - return_softmax, - q.dtype(), - is_test, - rng_name, - fixed_seed_offset, - attn_mask, - softmax, - softmax_lse, - seed_offset); + FlashAttnFwdParamsV2 params = + FlashAttnFwdParamsV2(ctx, + batch_size, + max_seqlen_q, + max_seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + 0, // attn_mask_start_row + fixed_seed_offset, + attn_mask, + nullptr, // attn_mask_start_row_indices + softmax, + softmax_lse, + seed_offset); VLOG(10) << "FlashAttn fwd seed: " << params.seed << ", offset: " << params.offset; @@ -125,21 +128,24 @@ void FlashAttnUnpaddedKernel( } template -void FlashAttnKernel(const Context& ctx, - const DenseTensor& q, - const DenseTensor& k, - const DenseTensor& v, - const paddle::optional& fixed_seed_offset, - const paddle::optional& attn_mask, - float dropout, - bool causal, - bool return_softmax, - bool is_test, - const std::string& rng_name, - DenseTensor* out, - DenseTensor* softmax, - DenseTensor* softmax_lse, - DenseTensor* seed_offset) { +void FlashAttnBaseKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const 
paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + const paddle::optional& attn_mask_start_row_indices, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + int attn_mask_start_row, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { #ifdef PADDLE_WITH_FLASHATTN // q, k, v [batch_size, seq_len, num_heads, head_dim] const auto& dims = q.dims(); @@ -161,25 +167,28 @@ void FlashAttnKernel(const Context& ctx, const float softmax_scale = 1.0f / std::sqrt(head_size); const float softmax_unscale = std::sqrt(head_size); - FlashAttnFwdParamsV2 params = FlashAttnFwdParamsV2(ctx, - batch_size, - seqlen_q, - seqlen_k, - num_heads, - num_heads_k, - head_size, - dropout, - softmax_scale, - causal, - return_softmax, - q.dtype(), - is_test, - rng_name, - fixed_seed_offset, - attn_mask, - softmax, - softmax_lse, - seed_offset); + FlashAttnFwdParamsV2 params = + FlashAttnFwdParamsV2(ctx, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + softmax_scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + attn_mask_start_row, + fixed_seed_offset, + attn_mask, + attn_mask_start_row_indices, + softmax, + softmax_lse, + seed_offset); VLOG(10) << "[FlashAttn Forward] q.shape=[" << q.dims() << "], k.shape=[" << k.dims() << "], v.shape=[" << v.dims() << "]"; @@ -223,13 +232,92 @@ void FlashAttnKernel(const Context& ctx, params.seed, params.offset, params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, - params.mask_dims.data()); + params.mask_dims.data(), + params.attn_mask_start_row_indices_tensor + ? params.attn_mask_start_row_indices_tensor->data() + : nullptr, + params.attn_mask_start_row_indices_tensor + ? 
params.attn_mask_start_row_indices_dims.data() + : nullptr, + params.attn_mask_start_row); CheckFlashAttnStatus(succ); #else RaiseNotSupportedError(); #endif } +template +void FlashAttnKernel(const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { + FlashAttnBaseKernel(ctx, + q, + k, + v, + fixed_seed_offset, + attn_mask, + paddle::none, + dropout, + causal, + return_softmax, + is_test, + rng_name, + 0, + out, + softmax, + softmax_lse, + seed_offset); +} + +template +void FlashAttnWithSparseMaskKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& attn_mask_start_row_indices, + const paddle::optional& fixed_seed_offset, + float dropout, + bool causal, + int attn_mask_start_row, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { + FlashAttnBaseKernel(ctx, + q, + k, + v, + fixed_seed_offset, + paddle::none, + attn_mask_start_row_indices, + dropout, + causal, + return_softmax, + is_test, + rng_name, + attn_mask_start_row, + out, + softmax, + softmax_lse, + seed_offset); +} + } // namespace phi PD_REGISTER_KERNEL(flash_attn_unpadded, @@ -251,3 +339,13 @@ PD_REGISTER_KERNEL(flash_attn, kernel->InputAt(3).SetBackend( phi::Backend::ALL_BACKEND); // fixed_seed_offset } + +PD_REGISTER_KERNEL(flash_attn_with_sparse_mask, + GPU, + ALL_LAYOUT, + phi::FlashAttnWithSparseMaskKernel, + phi::dtype::float16, + phi::dtype::bfloat16) { + kernel->InputAt(4).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} diff --git a/paddle/phi/kernels/gpu/flash_attn_utils.h b/paddle/phi/kernels/gpu/flash_attn_utils.h index 8fdc51f1d1eeb..1cb99dbb98207 100644 --- a/paddle/phi/kernels/gpu/flash_attn_utils.h +++ b/paddle/phi/kernels/gpu/flash_attn_utils.h @@ -78,6 +78,58 @@ static std::vector GetAttnMaskDims(const DenseTensor* attn_mask) { return mask_dim_4d; } +static std::vector GetAttnSparseMaskDims( + const DenseTensor* attn_mask_start_row_indices, + int64_t attn_mask_start_row, + int max_seqlen_q) { + std::vector mask_dim_3d; + if (attn_mask_start_row_indices) { + const auto& dtype = attn_mask_start_row_indices->dtype(); + const auto& origin_dims = attn_mask_start_row_indices->dims(); + auto rank = origin_dims.size(); + PADDLE_ENFORCE_EQ(dtype, + DataType::INT32, + phi::errors::InvalidArgument( + "dtype of attn_mask_start_row_indices must be " + "int32, but recieved %d", + dtype)); + PADDLE_ENFORCE_GE( + rank, + 3, + phi::errors::InvalidArgument( + "The number of dimenstions of attn_mask_start_row_indices is " + "expected to be greater or " + "equal to 3, but recieved %d. 
The shape of " + "attn_mask_start_row_indices is [%s]", + rank, + origin_dims)); + PADDLE_ENFORCE_EQ(origin_dims[rank - 1], + max_seqlen_q, + phi::errors::InvalidArgument( + "The sparse_mask_dims[%d] of " + "attn_mask_start_row_indices is expected to be " + "equal to %d, but recieved %d.", + rank - 1, + max_seqlen_q, + origin_dims[2])); + PADDLE_ENFORCE_GE(attn_mask_start_row, + 0, + phi::errors::InvalidArgument( + "attn_mask_start_row should be greater or equal than " + "0 when using attn_mask_start_row_indices, " + "but recieved %d.", + attn_mask_start_row)); + + int64_t first_dim = 1; + for (int i = 0; i < rank - 2; i++) { + first_dim *= origin_dims[i]; + } + mask_dim_3d = {first_dim, origin_dims[rank - 2], origin_dims[rank - 1]}; + } + + return mask_dim_3d; +} + struct FlashAttnParamsBase { int batch_size; // for padded kernel, max_seqlen_q and seqlen_q is the same. @@ -100,16 +152,23 @@ struct FlashAttnParamsBase { std::vector mask_dims; const DenseTensor* attn_mask_tensor; - FlashAttnParamsBase(const int _batch_size, - const int64_t _max_seqlen_q, - const int64_t _max_seqlen_k, - const int _num_heads, - const int _num_heads_k, - const int _head_size, - const float _scale, - const bool _causal, - const DataType q_dtype, - const paddle::optional& attn_mask) + const DenseTensor* attn_mask_start_row_indices_tensor; + std::vector attn_mask_start_row_indices_dims; + int attn_mask_start_row; + + FlashAttnParamsBase( + const int _batch_size, + const int64_t _max_seqlen_q, + const int64_t _max_seqlen_k, + const int _num_heads, + const int _num_heads_k, + const int _head_size, + const float _scale, + const bool _causal, + const int _attn_mask_start_row, + const DataType q_dtype, + const paddle::optional& attn_mask, + const paddle::optional& attn_mask_start_row_indices) : batch_size(_batch_size), max_seqlen_q(_max_seqlen_q), max_seqlen_k(_max_seqlen_k), @@ -118,7 +177,10 @@ struct FlashAttnParamsBase { head_size(_head_size), softmax_scale(_scale), causal(_causal), - attn_mask_tensor(attn_mask.get_ptr()) { + attn_mask_start_row(_attn_mask_start_row), + attn_mask_tensor(attn_mask.get_ptr()), + attn_mask_start_row_indices_tensor( + attn_mask_start_row_indices.get_ptr()) { is_bf16 = q_dtype == DataType::BFLOAT16; auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; @@ -142,6 +204,15 @@ struct FlashAttnParamsBase { mask_dims = GetAttnMaskDims(attn_mask_tensor); } + + attn_mask_start_row_indices_dims = GetAttnSparseMaskDims( + attn_mask_start_row_indices_tensor, attn_mask_start_row, max_seqlen_q); + + PADDLE_ENFORCE_NE(attn_mask_tensor && attn_mask_start_row_indices, + true, + phi::errors::InvalidArgument( + "attn_mask and attn_mask_start_row_indices cannot be " + "set at same time.")); } }; @@ -156,25 +227,28 @@ struct FlashAttnFwdParamsV2 : public FlashAttnParamsBase { DenseTensor* softmax_lse; DenseTensor* seed_offset; - FlashAttnFwdParamsV2(const GPUContext& ctx, - const int _batch_size, - const int64_t _max_seqlen_q, - const int64_t _max_seqlen_k, - const int _num_heads, - const int _num_heads_k, - const int _head_size, - const float _dropout, - const float _scale, - const bool _causal, - const bool _return_softmax, - const DataType q_dtype, - const bool is_test, - const std::string& rng_name, - const paddle::optional& fixed_seed_offset, - const paddle::optional& attn_mask, - DenseTensor* _softmax, - DenseTensor* _softmax_lse, - DenseTensor* _seed_offset) + FlashAttnFwdParamsV2( + const GPUContext& ctx, + const int _batch_size, + const int64_t _max_seqlen_q, + const int64_t 
_max_seqlen_k, + const int _num_heads, + const int _num_heads_k, + const int _head_size, + const float _dropout, + const float _scale, + const bool _causal, + const bool _return_softmax, + const DataType q_dtype, + const bool is_test, + const std::string& rng_name, + const int _attn_mask_start_row, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + const paddle::optional& attn_mask_start_row_indices, + DenseTensor* _softmax, + DenseTensor* _softmax_lse, + DenseTensor* _seed_offset) : FlashAttnParamsBase(_batch_size, _max_seqlen_q, _max_seqlen_k, @@ -183,8 +257,10 @@ struct FlashAttnFwdParamsV2 : public FlashAttnParamsBase { _head_size, _scale, _causal, + _attn_mask_start_row, q_dtype, - attn_mask), + attn_mask, + attn_mask_start_row_indices), dropout(_dropout), return_softmax(_return_softmax), softmax(_softmax), @@ -231,19 +307,22 @@ struct FlashAttnBwdParamsV2 : public FlashAttnParamsBase { DenseTensor dq_accum; DenseTensor rng_state; - FlashAttnBwdParamsV2(const GPUContext& ctx, - const int _batch_size, - const int64_t _max_seqlen_q, - const int64_t _max_seqlen_k, - const int _num_heads, - const int _num_heads_k, - const int _head_size, - const float _dropout, - const float _scale, - const bool _causal, - const DataType q_dtype, - const paddle::optional& attn_mask, - const int64_t* seed_offset_data) + FlashAttnBwdParamsV2( + const GPUContext& ctx, + const int _batch_size, + const int64_t _max_seqlen_q, + const int64_t _max_seqlen_k, + const int _num_heads, + const int _num_heads_k, + const int _head_size, + const float _dropout, + const float _scale, + const bool _causal, + const int _attn_mask_start_row, + const DataType q_dtype, + const paddle::optional& attn_mask, + const paddle::optional& attn_mask_start_row_indices, + const int64_t* seed_offset_data) : FlashAttnParamsBase(_batch_size, _max_seqlen_q, _max_seqlen_k, @@ -252,8 +331,10 @@ struct FlashAttnBwdParamsV2 : public FlashAttnParamsBase { _head_size, _scale, _causal, + _attn_mask_start_row, q_dtype, - attn_mask), + attn_mask, + attn_mask_start_row_indices), dropout(_dropout) { seed = static_cast(seed_offset_data[0]); offset = static_cast(seed_offset_data[1]); diff --git a/paddle/phi/kernels/gpu/gather_grad_kernel.cu b/paddle/phi/kernels/gpu/gather_grad_kernel.cu index 23c3eb3997257..22a4a065dfb7c 100644 --- a/paddle/phi/kernels/gpu/gather_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_grad_kernel.cu @@ -72,4 +72,6 @@ PD_REGISTER_KERNEL(gather_grad, int64_t, int, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/gather_kernel.cu b/paddle/phi/kernels/gpu/gather_kernel.cu index 931f7b6431d9b..e824480229da3 100644 --- a/paddle/phi/kernels/gpu/gather_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_kernel.cu @@ -74,4 +74,6 @@ PD_REGISTER_KERNEL(gather, uint8_t, int8_t, phi::dtype::float16, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16, + phi::dtype::complex, + phi::dtype::complex) {} diff --git a/paddle/phi/kernels/gpu/gather_tree_kernel.cu b/paddle/phi/kernels/gpu/gather_tree_kernel.cu index 3ae71992d2423..adf892184223e 100644 --- a/paddle/phi/kernels/gpu/gather_tree_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_tree_kernel.cu @@ -37,11 +37,17 @@ __global__ void GatherTree(const T *ids_data, auto parent = parents_data[idx]; for (int step = max_length - 2; step >= 0; step--) { PADDLE_ENFORCE((parent < beam_size), - "The parents must be less than beam size, but received" + "The parents must 
be less than beam size, but received " "parents %ld is greater than or equal to beam size %ld. ", parent, beam_size); + PADDLE_ENFORCE( + (parent >= 0), + "The parents must be greater than or equal to 0, but received " + "parents %ld is less than 0. ", + parent); + idx = step * batch_size * beam_size + batch * beam_size; out_data[idx + beam] = ids_data[idx + parent]; parent = parents_data[idx + parent]; diff --git a/paddle/phi/kernels/gpu/graph_reindex_kernel.cu b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu index c0454619b657c..c1f635bfdf8aa 100644 --- a/paddle/phi/kernels/gpu/graph_reindex_kernel.cu +++ b/paddle/phi/kernels/gpu/graph_reindex_kernel.cu @@ -67,53 +67,34 @@ std::shared_ptr FillHashTable(const Context& dev_ctx, input, num_input, len_hashtable, keys, key_index); // Get item index count. - auto item_count = - phi::memory_utils::Alloc(place, (num_input + 1) * sizeof(int)); - int* item_count_ptr = reinterpret_cast(item_count->ptr()); -#ifdef PADDLE_WITH_HIP - hipMemset(item_count_ptr, 0, sizeof(int) * (num_input + 1)); -#else - cudaMemset(item_count_ptr, 0, sizeof(int) * (num_input + 1)); -#endif + thrust::device_vector item_count(num_input + 1, 0); GetItemIndexCount<<>>( - input, item_count_ptr, num_input, len_hashtable, keys, key_index); - - size_t temp_storage_bytes = 0; - cub::DeviceScan::ExclusiveSum( - NULL, temp_storage_bytes, item_count_ptr, item_count_ptr, num_input + 1); - auto d_temp_storage = phi::memory_utils::Alloc(place, temp_storage_bytes); - cub::DeviceScan::ExclusiveSum(d_temp_storage->ptr(), - temp_storage_bytes, - item_count_ptr, - item_count_ptr, - num_input + 1); - int total_unique_items = 0; -#ifdef PADDLE_WITH_HIP - hipMemcpy(&total_unique_items, - item_count_ptr + num_input, - sizeof(int), - hipMemcpyDeviceToHost); -#else - cudaMemcpy(&total_unique_items, - item_count_ptr + num_input, - sizeof(int), - cudaMemcpyDeviceToHost); -#endif + input, + thrust::raw_pointer_cast(item_count.data()), + num_input, + len_hashtable, + keys, + key_index); + thrust::exclusive_scan( + item_count.begin(), item_count.end(), item_count.begin()); + + int total_unique_items = item_count[num_input]; auto unique_items = phi::memory_utils::AllocShared(place, total_unique_items * sizeof(T)); T* unique_items_data = reinterpret_cast(unique_items->ptr()); *final_nodes_len = total_unique_items; // Get unique items - FillUniqueItems<<>>(input, - num_input, - len_hashtable, - unique_items_data, - item_count_ptr, - keys, - values, - key_index); + FillUniqueItems<<>>( + input, + num_input, + len_hashtable, + unique_items_data, + thrust::raw_pointer_cast(item_count.data()), + keys, + values, + key_index); return unique_items; } diff --git a/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu b/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu index 6e8b12c4b1b90..2b6ceff59afa7 100644 --- a/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu @@ -121,16 +121,13 @@ ComputePositionsWithMask(T coord, coord = ClipIndexesWithMask(coord, size, &grad_clip); *grad_in = (*grad_in) * grad_clip; } else if (padding_mode == PaddingMode::reflect) { - if (align_corners) { - coord = ReflectIndexesWithMask(coord, 0, 2 * (size - 1), &grad_refl); - } else { - coord = ReflectIndexesWithMask(coord, -1, 2 * size - 1, &grad_refl); - } + coord = align_corners + ? 
ReflectIndexesWithMask(coord, 0, 2 * (size - 1), &grad_refl) + : ReflectIndexesWithMask(coord, -1, 2 * size - 1, &grad_refl); coord = ClipIndexesWithMask(coord, size, &grad_clip); *grad_in = (*grad_in) * grad_refl * grad_clip; } - - return coord; + return SafeDownGradeToIntRange(coord); } template diff --git a/paddle/phi/kernels/gpu/grid_sample_kernel.cu b/paddle/phi/kernels/gpu/grid_sample_kernel.cu index 3809ae7d5c338..8499e371d10cf 100644 --- a/paddle/phi/kernels/gpu/grid_sample_kernel.cu +++ b/paddle/phi/kernels/gpu/grid_sample_kernel.cu @@ -27,16 +27,13 @@ template static __forceinline__ __device__ T Unnormalize(T coord, int size, bool align_corners) { - if (align_corners) { - return ((coord + 1.f) / 2) * (size - 1); - } else { - return ((coord + 1.f) * size - 1) / 2; - } + return align_corners ? ((coord + 1.f) / 2) * (size - 1) + : ((coord + 1.f) * size - 1) / 2; } template static __forceinline__ __device__ T ClipIndexes(T in, int max_value) { - return min(static_cast(max_value), max(in, static_cast(0))); + return min(static_cast(max_value - 1), max(in, static_cast(0))); } template @@ -51,11 +48,7 @@ static __forceinline__ __device__ T ReflectIndexes(T in, in = fabs(in - min); T extra = fmod(in, span); int flips = static_cast(floor(in / span)); - if (flips % 2 == 0) { - return extra + min; - } else { - return span - extra + min; - } + return (flips & 1) ? span - extra + min : extra + min; // cond ? odd : even } template @@ -65,16 +58,13 @@ static __forceinline__ __device__ T ComputePositions(T coord, bool align_corners) { coord = Unnormalize(coord, size, align_corners); if (padding_mode == PaddingMode::border) { - coord = ClipIndexes(coord, size - 1); + coord = ClipIndexes(coord, size); } else if (padding_mode == PaddingMode::reflect) { - if (align_corners) { - coord = ReflectIndexes(coord, 0, 2 * (size - 1)); - } else { - coord = ReflectIndexes(coord, -1, 2 * size - 1); - } - coord = ClipIndexes(coord, size - 1); + coord = align_corners ? ReflectIndexes(coord, 0, 2 * (size - 1)) + : ReflectIndexes(coord, -1, 2 * size - 1); + coord = ClipIndexes(coord, size); } - return coord; + return SafeDownGradeToIntRange(coord); } template diff --git a/paddle/phi/kernels/gpu/grid_sample_utils.h b/paddle/phi/kernels/gpu/grid_sample_utils.h index bd5e859a59d1d..415305efaa105 100644 --- a/paddle/phi/kernels/gpu/grid_sample_utils.h +++ b/paddle/phi/kernels/gpu/grid_sample_utils.h @@ -14,6 +14,8 @@ #pragma once +#include + namespace phi { enum class Mode { @@ -21,6 +23,13 @@ enum class Mode { nearest, }; +template +__forceinline__ __device__ T SafeDownGradeToIntRange(T x) { + bool unsafe_cond = + x > INT_MAX - 1 || x < INT_MIN || !::isfinite(static_cast(x)); + return unsafe_cond ? static_cast(-100.0) : x; +} + enum class PaddingMode { zeros, border, reflect }; static __forceinline__ __device__ bool InBounds(int h, int w, int H, int W) { diff --git a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu index 33de3c8e17876..9773db68362e8 100644 --- a/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu +++ b/paddle/phi/kernels/gpu/matrix_rank_tol_kernel.cu @@ -361,7 +361,7 @@ void MatrixRankTolKernel(const Context& dev_ctx, rtol_T = std::numeric_limits::epsilon() * std::max(rows, cols); } - // Must Copy X once, because the gesvdj will destory the content when exit. + // Must Copy X once, because the gesvdj will destroy the content when exit. 
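A host-side sketch of the grid_sample coordinate pipeline touched above — unnormalize, reflect, clip, then the new SafeDownGradeToIntRange guard — so the index math can be checked outside CUDA. The twice_low/twice_high reflection convention and the zero-span early return are assumptions inferred from the call sites ReflectIndexes(coord, 0, 2 * (size - 1)) and ReflectIndexes(coord, -1, 2 * size - 1); this is not the Paddle device code itself.

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdio>

// Map a coordinate from [-1, 1] to [0, size - 1] (align_corners) or to
// [-0.5, size - 0.5] (not align_corners), as in the device Unnormalize.
static float Unnormalize(float coord, int size, bool align_corners) {
  return align_corners ? ((coord + 1.f) / 2) * (size - 1)
                       : ((coord + 1.f) * size - 1) / 2;
}

// Reflect a coordinate into [twice_low / 2, twice_high / 2]; an odd number of
// flips mirrors the remainder back (the `flips & 1` form used above).
static float ReflectIndexes(float in, int twice_low, int twice_high) {
  if (twice_low == twice_high) return 0.f;  // degenerate span, assumed guard
  float low = twice_low / 2.f;
  float span = (twice_high - twice_low) / 2.f;
  in = std::fabs(in - low);
  float extra = std::fmod(in, span);
  int flips = static_cast<int>(std::floor(in / span));
  return (flips & 1) ? span - extra + low : extra + low;
}

// Clamp to the valid index range [0, max_value - 1], matching the new
// ClipIndexes(coord, size) convention.
static float ClipIndexes(float in, int max_value) {
  return std::min(static_cast<float>(max_value - 1), std::max(in, 0.f));
}

// Float-only simplification of SafeDownGradeToIntRange: coordinates that are
// not finite or would overflow int are replaced by -100, which the existing
// InBounds check then rejects.
static float SafeDownGradeToIntRange(float x) {
  bool unsafe = x > static_cast<float>(INT_MAX - 1) ||
                x < static_cast<float>(INT_MIN) || !std::isfinite(x);
  return unsafe ? -100.f : x;
}

int main() {
  const int size = 5;
  const bool align_corners = false;
  float coord = Unnormalize(-1.3f, size, align_corners);  // lands outside [0, 4]
  coord = ReflectIndexes(coord, -1, 2 * size - 1);         // reflect padding
  coord = ClipIndexes(coord, size);
  coord = SafeDownGradeToIntRange(coord);
  std::printf("sampled x index: %f\n", coord);
  return 0;
}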
DenseTensor x_tmp; phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, &x_tmp); auto info = phi::memory_utils::Alloc( diff --git a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu index c2989e6e6075f..61508285038a3 100644 --- a/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_grad_kernel.cu @@ -30,17 +30,13 @@ inline int GET_BLOCKS(const int N) { } template -__global__ void KernelNanmedianGrad(const T* x_data, - const int64_t* medians_ptr, - const T* out_grad_ptr, - T* dx_data, - int64_t stride, - int64_t pre_dim) { +__global__ void KernelNanmedianMeanGrad(const int64_t* medians_ptr, + const T* out_grad_ptr, + T* dx_data, + int64_t stride, + int64_t pre_dim) { CUDA_KERNEL_LOOP(index, pre_dim) { int64_t offset = index * stride; - printf("index: %d\n", index); - printf("medians_ptr[2 * index]: %d\n", medians_ptr[2 * index]); - printf("medians_ptr[2 * index+1]: %d\n", medians_ptr[2 * index + 1]); if (medians_ptr[2 * index] >= 0) { if (medians_ptr[2 * index] == medians_ptr[2 * index + 1]) { @@ -55,18 +51,34 @@ __global__ void KernelNanmedianGrad(const T* x_data, } } +template +__global__ void KernelNanmedianMinGrad(const int64_t* medians_ptr, + const T* out_grad_ptr, + T* dx_data, + int64_t stride, + int64_t pre_dim) { + CUDA_KERNEL_LOOP(index, pre_dim) { + int64_t offset = index * stride; + + if (medians_ptr[index] >= 0) { + dx_data[offset + medians_ptr[index]] = out_grad_ptr[index]; + } + } +} + template void CalcMedianGradKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& median_index, const DenseTensor& out_grad, + const std::string& mode, DenseTensor* x_grad) { T* dx_data = dev_ctx.template Alloc(x_grad); if (!dx_data) return; phi::funcs::SetConstant set_zero; set_zero(dev_ctx, x_grad, static_cast(0)); - VLOG(0) << "x_grad->dims(): " << x_grad->dims(); + // VLOG(0) << "x_grad->dims(): " << x_grad->dims(); auto stream = dev_ctx.stream(); const T* x_data = x.data(); @@ -79,9 +91,15 @@ void CalcMedianGradKernel(const Context& dev_ctx, int64_t stride = x_dim[x_rank - 1]; int64_t pre_dim = numel / stride; - KernelNanmedianGrad - <<>>( - x_data, m_data, out_grad_ptr, dx_data, stride, pre_dim); + if (mode == "avg") { + KernelNanmedianMeanGrad + <<>>( + m_data, out_grad_ptr, dx_data, stride, pre_dim); + } else { // mode == "min" + KernelNanmedianMinGrad + <<>>( + m_data, out_grad_ptr, dx_data, stride, pre_dim); + } } template @@ -91,6 +109,7 @@ void NanmedianGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, const IntArray& axes, bool keepdim UNUSED, + const std::string& mode, DenseTensor* x_grad) { DenseTensor tmp_x; auto rank = x.dims().size(); @@ -98,14 +117,14 @@ void NanmedianGradKernel(const Context& dev_ctx, tmp_x = x; tmp_x.Resize({x.numel()}); CalcMedianGradKernel( - dev_ctx, tmp_x, median_index, out_grad, x_grad); + dev_ctx, tmp_x, median_index, out_grad, mode, x_grad); } else { funcs::PreprocessMedianKernel(dev_ctx, x, axes, &tmp_x); DenseTensor tmp_x_grad; tmp_x_grad.Resize(x_grad->dims()); CalcMedianGradKernel( - dev_ctx, tmp_x, median_index, out_grad, &tmp_x_grad); + dev_ctx, tmp_x, median_index, out_grad, mode, &tmp_x_grad); dev_ctx.template Alloc(x_grad); funcs::PostprocessMedianGradKernel( diff --git a/paddle/phi/kernels/gpu/nanmedian_kernel.cu b/paddle/phi/kernels/gpu/nanmedian_kernel.cu index 01144442f3904..87f948152ac8d 100644 --- a/paddle/phi/kernels/gpu/nanmedian_kernel.cu +++ b/paddle/phi/kernels/gpu/nanmedian_kernel.cu @@ -69,14 +69,14 @@ __global__ void 
KernelNanCounts(const T* input, } template -__global__ void CalcMedianKernel(const T* sort_out_ptr, - const int64_t* sort_indices_ptr, - int64_t* median_val, - T* output, - T div_factor, - const bool is_odd, - const int64_t pre_dim, - const int64_t stride) { +__global__ void CalcMedianMeanKernel(const T* sort_out_ptr, + const int64_t* sort_indices_ptr, + int64_t* median_val, + T* output, + T div_factor, + const bool is_odd, + const int64_t pre_dim, + const int64_t stride) { CUDA_KERNEL_LOOP(index, pre_dim) { int64_t pos = static_cast((index + 1) * stride) - 1; if (is_odd) { @@ -84,28 +84,51 @@ __global__ void CalcMedianKernel(const T* sort_out_ptr, median_val[index * 2 + 1] = sort_indices_ptr[pos]; output[index] = sort_out_ptr[pos]; } else { + T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T median_val_right = sort_out_ptr[pos]; median_val[index * 2] = pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; median_val[index * 2 + 1] = sort_indices_ptr[pos]; - T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; - T median_val_right = sort_out_ptr[pos]; output[index] = (median_val_left + median_val_right) / div_factor; } } } template -__global__ void CalcNanmedianKernel(const T* sort_out_ptr, +__global__ void CalcMedianMinKernel(const T* sort_out_ptr, const int64_t* sort_indices_ptr, - int64_t* nan_counts, int64_t* median_val, T* output, + T div_factor, const bool is_odd, const int64_t pre_dim, - const int64_t max_valid_num, - const int64_t stride, - const T div_factor, - const T nan_val) { + const int64_t stride) { + CUDA_KERNEL_LOOP(index, pre_dim) { + int64_t pos = static_cast((index + 1) * stride) - 1; + if (is_odd) { + median_val[index] = sort_indices_ptr[pos]; + output[index] = sort_out_ptr[pos]; + } else { + T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + median_val[index] = + pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + output[index] = median_val_left; + } + } +} + +template +__global__ void CalcNanmedianMeanKernel(const T* sort_out_ptr, + const int64_t* sort_indices_ptr, + int64_t* nan_counts, + int64_t* median_val, + T* output, + const bool is_odd, + const int64_t pre_dim, + const int64_t max_valid_num, + const int64_t stride, + const T div_factor, + const T nan_val) { CUDA_KERNEL_LOOP(index, pre_dim) { int64_t pos = static_cast(index * max_valid_num); int64_t nan_cnt = nan_counts[index]; @@ -124,20 +147,58 @@ __global__ void CalcNanmedianKernel(const T* sort_out_ptr, median_val[index * 2 + 1] = sort_indices_ptr[pos]; output[index] = sort_out_ptr[pos]; } else { + T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + T median_val_right = sort_out_ptr[pos]; median_val[index * 2] = pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; median_val[index * 2 + 1] = sort_indices_ptr[pos]; - T median_val_left = pos > 0 ? 
sort_out_ptr[pos - 1] : sort_out_ptr[pos]; - T median_val_right = sort_out_ptr[pos]; output[index] = (median_val_left + median_val_right) / div_factor; } } } } +template +__global__ void CalcNanmedianMinKernel(const T* sort_out_ptr, + const int64_t* sort_indices_ptr, + int64_t* nan_counts, + int64_t* median_val, + T* output, + const bool is_odd, + const int64_t pre_dim, + const int64_t max_valid_num, + const int64_t stride, + const T div_factor, + const T nan_val) { + CUDA_KERNEL_LOOP(index, pre_dim) { + int64_t pos = static_cast(index * max_valid_num); + int64_t nan_cnt = nan_counts[index]; + if (nan_cnt == stride) { + median_val[index] = -1; + output[index] = nan_val; + } else { + int64_t nan_k = + nan_cnt > 0 ? static_cast(stride - nan_cnt) : max_valid_num; + int64_t row_pos = static_cast(nan_k >> 1); + pos += row_pos; + + if (nan_k & 1) { + median_val[index] = sort_indices_ptr[pos]; + output[index] = sort_out_ptr[pos]; + } else { + T median_val_left = pos > 0 ? sort_out_ptr[pos - 1] : sort_out_ptr[pos]; + median_val[index] = + pos > 0 ? sort_indices_ptr[pos - 1] : sort_indices_ptr[pos]; + output[index] = median_val_left; + } + } + } +} + template void ProcessMedianKernel(const Context& dev_ctx, const DenseTensor& x, + const std::string& mode, DenseTensor* out, DenseTensor* median_index) { auto stream = dev_ctx.stream(); @@ -231,30 +292,59 @@ void ProcessMedianKernel(const Context& dev_ctx, T div_factor = static_cast(2.0); T nan_val = std::numeric_limits::quiet_NaN(); if (ignore_nan) { - CalcNanmedianKernel - <<>>( - sort_out_ptr, - sort_indices_ptr, - nan_counts_ptr, - m_data, - out_data, - is_ori_odd, - pre_dim, - max_valid_num, - stride, - div_factor, - nan_val); + if (mode == "avg") { + CalcNanmedianMeanKernel + <<>>( + sort_out_ptr, + sort_indices_ptr, + nan_counts_ptr, + m_data, + out_data, + is_ori_odd, + pre_dim, + max_valid_num, + stride, + div_factor, + nan_val); + } else { // mode == "min" + CalcNanmedianMinKernel + <<>>( + sort_out_ptr, + sort_indices_ptr, + nan_counts_ptr, + m_data, + out_data, + is_ori_odd, + pre_dim, + max_valid_num, + stride, + div_factor, + nan_val); + } } else { - CalcMedianKernel - <<>>( - sort_out_ptr, - sort_indices_ptr, - m_data, - out_data, - div_factor, - is_ori_odd, - pre_dim, - sort_k); + if (mode == "avg") { + CalcMedianMeanKernel + <<>>( + sort_out_ptr, + sort_indices_ptr, + m_data, + out_data, + div_factor, + is_ori_odd, + pre_dim, + sort_k); + } else { // mode == "min" + CalcMedianMinKernel + <<>>( + sort_out_ptr, + sort_indices_ptr, + m_data, + out_data, + div_factor, + is_ori_odd, + pre_dim, + sort_k); + } } } @@ -263,6 +353,7 @@ void NanmedianKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, bool keepdim, + const std::string& mode, DenseTensor* out, DenseTensor* median_index) { DenseTensor tmp_x; @@ -274,7 +365,7 @@ void NanmedianKernel(const Context& dev_ctx, funcs::PreprocessMedianKernel(dev_ctx, x, axes, &tmp_x); } - ProcessMedianKernel(dev_ctx, tmp_x, out, median_index); + ProcessMedianKernel(dev_ctx, tmp_x, mode, out, median_index); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/rms_norm_funcs.h b/paddle/phi/kernels/gpu/rms_norm_funcs.h index a9601d7ce800e..2bf035d30e1dc 100644 --- a/paddle/phi/kernels/gpu/rms_norm_funcs.h +++ b/paddle/phi/kernels/gpu/rms_norm_funcs.h @@ -12,6 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ + +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + #pragma once #include diff --git a/paddle/phi/kernels/gpu/rms_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/rms_norm_grad_kernel.cu index bfc73faf21b9b..fab312470fe9f 100644 --- a/paddle/phi/kernels/gpu/rms_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/rms_norm_grad_kernel.cu @@ -12,6 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */ +/*This code is copied fron NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + #include #include "paddle/phi/backends/gpu/cuda/cudnn_helper.h" #include "paddle/phi/backends/gpu/gpu_context.h" diff --git a/paddle/phi/kernels/gpu/rnn_functor.h b/paddle/phi/kernels/gpu/rnn_functor.h index 359218bbcb75f..8870f7d407c57 100644 --- a/paddle/phi/kernels/gpu/rnn_functor.h +++ b/paddle/phi/kernels/gpu/rnn_functor.h @@ -75,7 +75,30 @@ class RNNDescriptors { y_descs_.emplace_back(y_desc_.descriptor(dims_y, strides_y)); } -#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201 +#if CUDNN_VERSION >= 90000 + auto seqlen_is_empty = sequence_length.empty(); + if (seqlen_is_empty) { + std::vector seqlen_array(batch_size_); + for (int i = 0; i < batch_size_; ++i) { + seqlen_array[i] = seq_length_; + } + x_seq_desc_.descriptor( + seq_length_, batch_size_, input_size_, true, seqlen_array); + y_seq_desc_.descriptor(seq_length_, + batch_size_, + hidden_size_ * numDirections, + true, + seqlen_array); + } else { + x_seq_desc_.descriptor( + seq_length_, batch_size_, input_size_, true, sequence_length); + y_seq_desc_.descriptor(seq_length_, + batch_size_, + hidden_size_ * numDirections, + true, + sequence_length); + } +#elif defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201 if (!sequence_length.empty()) { x_seq_desc_.descriptor( seq_length_, batch_size_, input_size_, true, sequence_length); @@ -97,7 +120,7 @@ class RNNDescriptors { last_c_desc_.descriptor(dims_hx, strides_hx); // ------------------- cudnn dropout descriptors --------------------- - size_t state_size; + size_t state_size = 0; bool is_initialized = dropout_state->initialized(); #ifdef PADDLE_WITH_HIP if (!is_initialized) { @@ -148,6 +171,24 @@ class RNNDescriptors { miopenRNNwithBias, miopenRNNdefault, cudnn_type)); +#elif CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNDescriptor_v8( + rnn_desc_.desc(), + CUDNN_RNN_ALGO_STANDARD, + mode_, + CUDNN_RNN_DOUBLE_BIAS, + is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, + CUDNN_LINEAR_INPUT, + cudnn_type, + cudnn_type, + CUDNN_DEFAULT_MATH, + input_size_, + hidden_size_, + hidden_size_, + num_layers_, + dropout_desc_.desc(), + seqlen_is_empty ? 
CUDNN_RNN_PADDED_IO_DISABLED + : CUDNN_RNN_PADDED_IO_ENABLED)); #elif CUDNN_VERSION >= 6000 PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNDescriptor_v6( handle, @@ -172,7 +213,7 @@ class RNNDescriptors { cudnn_type)); #endif -#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION >= 7201 +#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION < 90000 && CUDNN_VERSION >= 7201 if (!sequence_length.empty()) { PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSetRNNPaddingMode( rnn_desc_.desc(), CUDNN_RNN_PADDED_IO_ENABLED)); @@ -180,14 +221,17 @@ class RNNDescriptors { #endif // ------------------- cudnn weights_size --------------------- - size_t weights_size_; #ifdef PADDLE_WITH_HIP PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenGetRNNParamsSize( handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type)); +#elif CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnGetRNNWeightSpaceSize( + handle, rnn_desc_.desc(), &weights_size_)); #else PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnGetRNNParamsSize( handle, rnn_desc_.desc(), x_descs_[0], &weights_size_, cudnn_type)); #endif + PADDLE_ENFORCE_EQ( weights_size_, sizeof(T) * weight_numel_, @@ -208,6 +252,14 @@ class RNNDescriptors { workspace_size)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenGetRNNTrainingReserveSize( handle, rnn_desc_.desc(), seq_length_, x_descs_.data(), reserve_size)); +#elif CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnGetRNNTempSpaceSizes(handle, + rnn_desc_.desc(), + CUDNN_FWD_MODE_TRAINING, + x_seq_desc_.desc(), + workspace_size, + reserve_size)); #else PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::cudnnGetRNNWorkspaceSize(handle, @@ -244,6 +296,7 @@ class RNNDescriptors { cudnnDropoutDescriptor_t dropout_desc() { return dropout_desc_.desc(); } cudnnFilterDescriptor_t weight_desc() { return weight_desc_.desc(); } #endif + size_t weights_size() { return weights_size_; } private: int seq_length_; @@ -257,6 +310,7 @@ class RNNDescriptors { gpuRNNMode_t mode_; bool is_bidirec_; bool is_test_; + size_t weights_size_; #ifdef PADDLE_WITH_HIP std::vector x_descs_; std::vector y_descs_; diff --git a/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc b/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc index 3e8dfe813cad7..caf00a61fa7f9 100644 --- a/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/rnn_grad_kernel.cu.cc @@ -256,6 +256,55 @@ void RnnGradKernel(const Context &dev_ctx, Empty(dev_ctx, {static_cast(workspace_size)}); const uint8_t *reserve_data = reserve.data(); +#if CUDNN_VERSION >= 90000 + if (x_grad) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardData_v8( + handle, + rnn.rnn_desc(), + nullptr, + rnn.y_seq_desc(), + out_data, + out_grad_data, + rnn.x_seq_desc(), + x_grad_data, + rnn.init_h_desc(), + init_h_data, + last_h_grad_data, + init_h_grad_data, + rnn.init_c_desc(), + init_c_data, + last_c_grad_data, + init_c_grad_data, + rnn.weights_size(), + weight_data, + workspace_size, + workspace_data_.data(), + reserve_size, + const_cast(reserve_data))); + } + + if (!weight_grad_list.empty()) { + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnRNNBackwardWeights_v8( + handle, + rnn.rnn_desc(), + CUDNN_WGRAD_MODE_ADD, + nullptr, + rnn.x_seq_desc(), + x.data(), + rnn.init_h_desc(), + init_h_data, + rnn.y_seq_desc(), + out.data(), + rnn.weights_size(), + weight_grad_data, + workspace_size, + workspace_data_.data(), + reserve_size, + const_cast(reserve_data))); + } + +#else + if (!has_seq_length) { if (x_grad) { #ifdef PADDLE_WITH_HIP @@ -421,6 +470,8 @@ 
void RnnGradKernel(const Context &dev_ctx, "of cudnn is larger than 7.2.1")); #endif } + +#endif // end CUDNN_VERSION >= 90000 } } // namespace phi diff --git a/paddle/phi/kernels/gpu/rnn_kernel.cu.cc b/paddle/phi/kernels/gpu/rnn_kernel.cu.cc index 82800607bae9d..c098e2db2413a 100644 --- a/paddle/phi/kernels/gpu/rnn_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/rnn_kernel.cu.cc @@ -39,6 +39,31 @@ void RNNInferece(bool has_seq_length, T *last_c_data, DenseTensor *workspace_data, size_t workspace_size) { +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForward(handle, + rnn->rnn_desc(), + CUDNN_FWD_MODE_INFERENCE, + nullptr, + rnn->x_seq_desc(), + x_data, + rnn->y_seq_desc(), + out_data, + rnn->init_h_desc(), + init_h_data, + last_h_data, + rnn->init_c_desc(), + init_c_data, + last_c_data, + rnn->weights_size(), + w_data, + workspace_size, + workspace_data->data(), + 0, + nullptr)); + +#else + if (!has_seq_length) { // for inference // This interface is used when the input/output is unpadded. @@ -124,6 +149,8 @@ void RNNInferece(bool has_seq_length, "the version of cudnn is larger than 7.2.1")); #endif } + +#endif // end CUDNN_VERSION >= 90000 } template @@ -305,6 +332,30 @@ void RnnKernel(const Context &dev_ctx, &workspace_data_, workspace_size); } else { +#if CUDNN_VERSION >= 90000 + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnRNNForward(handle, + rnn.rnn_desc(), + CUDNN_FWD_MODE_TRAINING, + nullptr, + rnn.x_seq_desc(), + x_data, + rnn.y_seq_desc(), + out_data, + rnn.init_h_desc(), + init_h_data, + last_h_data, + rnn.init_c_desc(), + init_c_data, + last_c_data, + rnn.weights_size(), + w_data, + workspace_size, + workspace_data_.data(), + reserve_size, + reserve_data)); +#else + if (!has_seq_length) { // for train // This interface is used when the input/output is unpadded. @@ -395,6 +446,7 @@ void RnnKernel(const Context &dev_ctx, "the version of cudnn is larger than 7.2.1")); #endif } +#endif // end CUDNN_VERSION >= 90000 } } diff --git a/paddle/phi/kernels/gpu/scale_kernel.cu b/paddle/phi/kernels/gpu/scale_kernel.cu index 871ccaec19ee4..447e229977c21 100644 --- a/paddle/phi/kernels/gpu/scale_kernel.cu +++ b/paddle/phi/kernels/gpu/scale_kernel.cu @@ -45,7 +45,7 @@ template void ScaleKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale, DenseTensor* out) { using MT = typename phi::dtype::MPTypeTrait::Type; @@ -61,8 +61,7 @@ void ScaleKernel(const Context& dev_ctx, dev_ctx, inputs, &outputs, - ScaleFunctor( - scale.to(), static_cast(bias), bias_after_scale)); + ScaleFunctor(scale.to(), bias.to(), bias_after_scale)); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/shuffle_batch_utils.h b/paddle/phi/kernels/gpu/shuffle_batch_utils.h index 3a7c2230d3213..dfcbcf5716f04 100644 --- a/paddle/phi/kernels/gpu/shuffle_batch_utils.h +++ b/paddle/phi/kernels/gpu/shuffle_batch_utils.h @@ -27,7 +27,7 @@ struct CacheAllocator { place_ = place; } - ~CacheAllocator() { VLOG(2) << "destory allocator"; } + ~CacheAllocator() { VLOG(2) << "destroy allocator"; } char* allocate(std::ptrdiff_t num_bytes) { VLOG(2) << "allocate " << num_bytes << " bytes"; diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index 1d93ef1a2790f..d946bc50adfca 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -117,7 +117,7 @@ void TopkKernel(const Context& dev_ctx, out, indices, largest)) { - // Successed, return. 
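Stepping back to the ScaleKernel change above, where the bias parameter goes from a plain float to a Scalar converted through the same Scalar::to path as scale (bias_after_scale still selects between scale * x + bias and scale * (x + bias)): one plausible motivation — an assumption, not stated in this diff — is that narrowing an integer or double bias to float silently rounds it once it exceeds float precision, while a tagged Scalar can be converted directly to the kernel's compute type. The standalone C++ snippet below uses hypothetical values and no Paddle types; it only demonstrates the rounding that the float path introduces.

#include <cstdint>
#include <cstdio>

int main() {
  // A bias that needs more than float's 24-bit mantissa but fits in double.
  const int64_t bias = (1LL << 30) + 1;

  // Old-style path: the caller narrows the bias to float before the kernel
  // sees it, so the +1 is rounded away.
  float narrowed = static_cast<float>(bias);
  double old_path = static_cast<double>(narrowed);

  // Scalar-style path: convert once, directly to the compute type.
  double new_path = static_cast<double>(bias);

  std::printf("through float : %.1f\n", old_path);   // 1073741824.0
  std::printf("direct convert: %.1f\n", new_path);   // 1073741825.0
  return 0;
}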
+ // Succeed, return. return; } else { VLOG(4) << "TopKOP: Some errors happened when use cub sorting, use " diff --git a/paddle/phi/kernels/gpu/transpose_kernel.cu b/paddle/phi/kernels/gpu/transpose_kernel.cu index 323c228c16039..809d28ee616e6 100644 --- a/paddle/phi/kernels/gpu/transpose_kernel.cu +++ b/paddle/phi/kernels/gpu/transpose_kernel.cu @@ -31,10 +31,10 @@ void TransposeKernel(const Context& ctx, const std::vector& axis, DenseTensor* out) { size_t x_rank = x.dims().size(); - std::vector formated_axis = axis; + std::vector formatted_axis = axis; for (size_t i = 0; i < axis.size(); i++) { if (axis[i] < 0) { - formated_axis[i] = axis[i] + x_rank; + formatted_axis[i] = axis[i] + x_rank; } } @@ -42,11 +42,11 @@ void TransposeKernel(const Context& ctx, if (out->numel() == 0) { return; } - if (formated_axis.size() == 0) { + if (formatted_axis.size() == 0) { phi::Copy(ctx, x, ctx.GetPlace(), false, out); return; } - phi::funcs::TransposeGPUKernelDriver(ctx, x, formated_axis, out); + phi::funcs::TransposeGPUKernelDriver(ctx, x, formatted_axis, out); } } // namespace phi diff --git a/paddle/phi/kernels/gpudnn/mha_cudnn_frontend.h b/paddle/phi/kernels/gpudnn/mha_cudnn_frontend.h index 264491214d2c7..dcb031311ffaa 100644 --- a/paddle/phi/kernels/gpudnn/mha_cudnn_frontend.h +++ b/paddle/phi/kernels/gpudnn/mha_cudnn_frontend.h @@ -17,13 +17,14 @@ #ifdef PADDLE_WITH_CUDNN_FRONTEND #include "paddle/phi/backends/dynload/cudnn_frontend.h" -#define CUDNN_CALL(func) \ - { \ - auto status = func; \ - if (status != CUDNN_STATUS_SUCCESS) { \ - LOG(FATAL) << "CUDNN Error : " \ - << phi::dynload::cudnnGetErrorString(status); \ - } \ +#define CUDNN_CALL(func) \ + { \ + auto status = func; \ + if (status != CUDNN_STATUS_SUCCESS) { \ + std::stringstream ss; \ + ss << "CUDNN Error : " << phi::dynload::cudnnGetErrorString(status); \ + PADDLE_THROW(phi::errors::Fatal(ss.str())); \ + } \ } enum class MHA_Layout { diff --git a/paddle/phi/kernels/gpudnn/pool_kernel.cu b/paddle/phi/kernels/gpudnn/pool_kernel.cu index 5bd1e2d6a12c1..c6cd7151003d8 100644 --- a/paddle/phi/kernels/gpudnn/pool_kernel.cu +++ b/paddle/phi/kernels/gpudnn/pool_kernel.cu @@ -142,8 +142,8 @@ void PoolRawGPUDNNKernel(const Context& ctx, transformed_output = *output; } - const T* tranformed_input_data = transformed_input.data(); - T* tranformed_output_data = ctx.template Alloc(&transformed_output); + const T* transformed_input_data = transformed_input.data(); + T* transformed_output_data = ctx.template Alloc(&transformed_output); // ------------------- cudnn descriptors --------------------- ScopedTensorDescriptor input_desc; @@ -192,10 +192,10 @@ void PoolRawGPUDNNKernel(const Context& ctx, cudnn_pool_desc, &alpha, cudnn_input_desc, - tranformed_input_data, + transformed_input_data, &beta, cudnn_output_desc, - tranformed_output_data, + transformed_output_data, false, pool_workspace, pool_workernel_size_)); @@ -206,10 +206,10 @@ void PoolRawGPUDNNKernel(const Context& ctx, cudnn_pool_desc, &alpha, cudnn_input_desc, - tranformed_input_data, + transformed_input_data, &beta, cudnn_output_desc, - tranformed_output_data)); + transformed_output_data)); #endif // add if (data_format == str_NDHWC) { diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h index 5d61322e336dd..d93690a78baf5 100644 --- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -772,7 +772,6 @@ void SwitchWarpSoftmaxForward(const IndexType blocks, SOFTMAX_WARP_FORWARD_CASE(7, AccT); 
SOFTMAX_WARP_FORWARD_CASE(8, AccT); SOFTMAX_WARP_FORWARD_CASE(9, AccT); - SOFTMAX_WARP_FORWARD_CASE(10, AccT); default: PADDLE_THROW(phi::errors::Unimplemented( "Unsupported softmax dim: element_count=%d, log2_element_count=%d!", @@ -815,7 +814,6 @@ void SwitchWarpSoftmaxBackward(const int blocks, SOFTMAX_WARP_BACKWARD_CASE(7, AccT); SOFTMAX_WARP_BACKWARD_CASE(8, AccT); SOFTMAX_WARP_BACKWARD_CASE(9, AccT); - SOFTMAX_WARP_BACKWARD_CASE(10, AccT); default: // PADDLE_THROW(phi::errors::Unimplemented( // "Unsupported softmax dim: element_count=%d, @@ -1228,7 +1226,7 @@ bool UseCudnnSoftmax(const GPUContext& ctx, #endif } } - constexpr int max_dim = 1024; + constexpr int max_dim = 512; if (!cudnn_available || !last_dim || (softmax_dim <= max_dim && sizeof(T) <= 4)) { return false; @@ -1271,7 +1269,27 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx, using T4 = typename VecT4::Type; using T2 = typename VecT2::Type; - if (std::is_same::value) { + if (dim % 4 == 0) { + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + dim_log2); + } else if (dim % 2 == 0) { + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + dim_log2); + } else { SwitchWarpSoftmaxForward(blocks, threads, dev_ctx, @@ -1281,38 +1299,6 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx, dim, dim, dim_log2); - } else { - if (dim % 4 == 0) { - SwitchWarpSoftmaxForward(blocks, - threads, - dev_ctx, - out_data, - x.data(), - N, - dim, - dim, - dim_log2); - } else if (dim % 2 == 0) { - SwitchWarpSoftmaxForward(blocks, - threads, - dev_ctx, - out_data, - x.data(), - N, - dim, - dim, - dim_log2); - } else { - SwitchWarpSoftmaxForward(blocks, - threads, - dev_ctx, - out_data, - x.data(), - N, - dim, - dim, - dim_log2); - } } } else { LaunchSoftmaxForwardCudnnKernel(dev_ctx, x, axis, LogMode, out); diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h index c4bb7676381f7..3ba4b42a2eb77 100644 --- a/paddle/phi/kernels/impl/activation_grad_impl.h +++ b/paddle/phi/kernels/impl/activation_grad_impl.h @@ -669,7 +669,7 @@ void SquareDoubleGradKernel(const Context& dev_ctx, template void SinDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const paddle::optional& dout, + const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout) { @@ -680,7 +680,7 @@ void SinDoubleGradKernel(const Context& dev_ctx, dev_ctx.template Alloc(ddout); } phi::funcs::SinDoubleGradFunctor functor; - functor(dev_ctx, &x, dout.get_ptr(), &ddx, dx, ddout); + functor(dev_ctx, &x, &dout, &ddx, dx, ddout); } template @@ -717,7 +717,7 @@ void SinTripleGradKernel(const Context& dev_ctx, template void CosDoubleGradKernel(const Context& dev_ctx, const DenseTensor& x, - const paddle::optional& dout, + const DenseTensor& dout, const DenseTensor& ddx, DenseTensor* dx, DenseTensor* ddout) { @@ -728,7 +728,7 @@ void CosDoubleGradKernel(const Context& dev_ctx, dev_ctx.template Alloc(ddout); } phi::funcs::CosDoubleGradFunctor functor; - functor(dev_ctx, &x, dout.get_ptr(), &ddx, dx, ddout); + functor(dev_ctx, &x, &dout, &ddx, dx, ddout); } template diff --git a/paddle/phi/kernels/impl/data_impl.h b/paddle/phi/kernels/impl/data_impl.h index c5d2f7b309592..fb089d1664535 100644 --- a/paddle/phi/kernels/impl/data_impl.h +++ b/paddle/phi/kernels/impl/data_impl.h @@ -39,6 +39,15 @@ void ShadowFeedKernel(const Context& ctx, } } +template +void 
ShadowFeedTensorsKernel(const Context& ctx, + const std::vector& xs, + std::vector outs) { + for (size_t i = 0; i < xs.size(); ++i) { + ShadowFeedKernel(ctx, *(xs[i]), outs[i]); + } +} + template void PrintKernel(const Context& ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h index db6858bc9d7d7..16b927e83aabe 100644 --- a/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h @@ -21,10 +21,12 @@ limitations under the License. */ #include "paddle/phi/common/float16.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/expand_kernel.h" #include "paddle/phi/kernels/full_kernel.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/elementwise_utils.h" namespace phi { @@ -65,26 +67,63 @@ void AddDoubleGradImpl(const Context& dev_ctx, DenseTensor* ddout) { // ddOut = ddx + ddy if (ddout) { - DenseTensor ddx_safe, ddy_safe; - funcs::GetDoubleGradSafeTensor( - dev_ctx, dout, ddx.get_ptr(), &ddx_safe); - funcs::GetDoubleGradSafeTensor( - dev_ctx, y, ddy.get_ptr(), &ddy_safe); - + auto* ddx_tensor = ddx.get_ptr(); + auto* ddy_tensor = ddy.get_ptr(); + auto out_shape = dout.dims(); dev_ctx.template Alloc(ddout); - auto ddx_dims = ddx_safe.dims(); - auto ddy_dims = ddy_safe.dims(); - if (ddx_dims.size() >= ddy_dims.size()) { - funcs::ElementwiseCompute, T>( - dev_ctx, ddx_safe, ddy_safe, funcs::AddFunctor(), ddout, axis); + if (ddx_tensor == nullptr && ddy_tensor == nullptr) { + VLOG(4) << "Special case when ddx and ddy are not needed \n"; + ddout = nullptr; + } else if (ddx_tensor == nullptr && ddy_tensor != nullptr) { + if (ddy_tensor->dims() != out_shape) { + VLOG(4) << "Special case when ddx is not needed and ddy needs to " + "broadcast\n"; + std::vector ins = {ddy_tensor}; + std::vector outs = {ddout}; + ExpandKernel(dev_ctx, + *ddy_tensor, + IntArray{phi::vectorize(out_shape)}, + ddout); + } else { + VLOG(4) << "Special case when ddx is not needed and ddy doesn't need " + "to broadcast\n"; + phi::Copy(dev_ctx, *ddy_tensor, dev_ctx.GetPlace(), false, ddout); + } + } else if (ddx_tensor != nullptr && ddy_tensor == nullptr) { + if (ddx_tensor->dims() != out_shape) { + VLOG(4) << "Special case when ddy is not needed and ddx need to " + "broadcast\n"; + std::vector ins = {ddx_tensor}; + std::vector outs = {ddout}; + ExpandKernel(dev_ctx, + *ddx_tensor, + IntArray{phi::vectorize(out_shape)}, + ddout); + } else { + VLOG(4) << "Special case when ddx is not needed and ddy doesn't need " + "to broadcast\n"; + phi::Copy(dev_ctx, *ddx_tensor, dev_ctx.GetPlace(), false, ddout); + } } else { - funcs::ElementwiseCompute, T>( - dev_ctx, - ddx_safe, - ddy_safe, - funcs::InverseAddFunctor(), - ddout, - axis); + auto ddx_dims = ddx_tensor->dims(); + auto ddy_dims = ddy_tensor->dims(); + if (ddx_dims.size() >= ddy_dims.size()) { + funcs::ElementwiseCompute, T>( + dev_ctx, + *ddx_tensor, + *ddy_tensor, + funcs::AddFunctor(), + ddout, + axis); + } else { + funcs::ElementwiseCompute, T>( + dev_ctx, + *ddx_tensor, + *ddy_tensor, + funcs::InverseAddFunctor(), + ddout, + axis); + } } } } @@ -157,42 +196,325 @@ struct DivGradDY> { template struct DivDoubleDY { - HOSTDEVICE T operator()(T x, T y, T out, T dout) const { - return y * out * dout - x * 
dout; + HOSTDEVICE T operator()(const T& x, + const T& y, + const T& out, + const T& dout) const { + return (y * out - x) * dout; } }; +template +struct DivDoubleDY_Only_DDY { + HOSTDEVICE T operator()(const T& x, + const T& y, + const T& out, + const T& dout) const { + return y * out * dout; + } +}; + +template +struct DivDoubleDY_Only_DDX { + HOSTDEVICE T operator()(const T& x, + const T& y, + const T& out, + const T& dout) const { + return -x * dout; + } +}; + +// ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y +template +struct DivDoubleDDOut { + HOSTDEVICE T operator()(const T& ddx, + const T& ddy, + const T& y, + const T& out) const { + return (ddx - out * ddy) / y; + } +}; + +template +struct DivDoubleDDOut_Only_DDY { + HOSTDEVICE T operator()(const T& ddx, + const T& ddy, + const T& y, + const T& out) const { + return -out * ddy / y; + } +}; + +template +void ComputeDDoutWithoutBroadcast(const CPUContext& dev_ctx UNUSED, + const phi::DenseTensor& ddx, + const phi::DenseTensor& ddy, + const phi::DenseTensor& y, + const phi::DenseTensor& out, + phi::DenseTensor* ddout, + DDout_OP dout_op) { + auto out_numel = out.numel(); + auto* ddx_data = ddx.data(); + auto* ddy_data = ddy.data(); + auto* y_data = y.data(); + auto* out_data = out.data(); + auto* ddout_data = ddout->data(); + for (int i = 0; i < out_numel; i++) { + ddout_data[i] = dout_op(ddx_data[i], ddy_data[i], y_data[i], out_data[i]); + } +} + +template +void ComputeDDoutWithBroadcast(const CPUContext& dev_ctx UNUSED, + const phi::DenseTensor& ddx, + const phi::DenseTensor& ddy, + const phi::DenseTensor& y, + const phi::DenseTensor& out, + phi::DenseTensor* ddout, + const int* x_dims_array, + const int* y_dims_array, + const int* out_dims_array, + const int max_dim, + DDout_OP dout_op) { + auto out_numel = out.numel(); + auto* ddx_data = ddx.data(); + auto* ddy_data = ddy.data(); + auto* y_data = y.data(); + auto* out_data = out.data(); + auto* ddout_data = ddout->data(); + std::vector index_array(max_dim, 0); + for (int i = 0; i < out_numel; i++) { + int x_index = phi::funcs::GetElementwiseIndex( + x_dims_array, max_dim, index_array.data()); + int y_index = phi::funcs::GetElementwiseIndex( + y_dims_array, max_dim, index_array.data()); + ddout_data[i] = dout_op( + ddx_data[x_index], ddy_data[y_index], y_data[y_index], out_data[i]); + phi::funcs::UpdateElementwiseIndexArray( + out_dims_array, max_dim, index_array.data()); + } +} + +#if defined(__NVCC__) || defined(__HIPCC__) + +template +__global__ void ComputeDDoutWithoutBroadcastGPUKernel(const T* ddx_data, + const T* ddy_data, + const T* y_data, + const T* out_data, + T* ddout_data, + int numel, + DDout_OP dout_op) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + ddout_data[tid] = + dout_op(ddx_data[tid], ddy_data[tid], y_data[tid], out_data[tid]); +} +template +void ComputeDDoutWithoutBroadcast(const GPUContext& dev_ctx UNUSED, + const phi::DenseTensor& ddx, + const phi::DenseTensor& ddy, + const phi::DenseTensor& y, + const phi::DenseTensor& out, + phi::DenseTensor* ddout, + DDout_OP dout_op) { + auto out_numel = out.numel(); + auto* ddx_data = ddx.data(); + auto* ddy_data = ddy.data(); + auto* y_data = y.data(); + auto* out_data = out.data(); + auto* ddout_data = ddout->data(); + int block = 512; + int64_t grid = (out_numel + block - 1) / block; + auto stream = reinterpret_cast(dev_ctx).stream(); + ComputeDDoutWithoutBroadcastGPUKernel + <<>>( + ddx_data, ddy_data, y_data, out_data, ddout_data, out_numel, dout_op); +} + +template 
+__global__ void ComputeDDoutWithBroadcastGPUKernel(const T* ddx_data, + const T* ddy_data, + const T* y_data, + const T* out_data, + T* ddout_data, + int numel, + const int* x_dims_array, + const int* y_dims_array, + const int* out_dims_array, + const int max_dim, + DDout_OP dout_op) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= numel) return; + int x_index = 0, y_index = 0, x_index_prod = 1, y_index_prod = 1, + out_index = tid, dim_index; + for (int64_t i = max_dim - 1; i >= 0; i--) { + if (out_index == 0) break; + dim_index = out_index % out_dims_array[i]; + out_index = out_index / out_dims_array[i]; + if (x_dims_array[i] > 1) { + x_index += dim_index * x_index_prod; + x_index_prod *= x_dims_array[i]; + } + if (y_dims_array[i] > 1) { + y_index += dim_index * y_index_prod; + y_index_prod *= y_dims_array[i]; + } + } + ddout_data[tid] = dout_op( + ddx_data[x_index], ddy_data[y_index], y_data[y_index], out_data[tid]); +} + +template +void ComputeDDoutWithBroadcast(const GPUContext& dev_ctx UNUSED, + const phi::DenseTensor& ddx, + const phi::DenseTensor& ddy, + const phi::DenseTensor& y, + const phi::DenseTensor& out, + phi::DenseTensor* ddout, + const int* x_dims_array, + const int* y_dims_array, + const int* out_dims_array, + const int max_dim, + DDout_OP dout_op) { + auto out_numel = out.numel(); + auto* ddx_data = ddx.data(); + auto* ddy_data = ddy.data(); + auto* y_data = y.data(); + auto* out_data = out.data(); + auto* ddout_data = ddout->data(); + DenseTensor x_dims_array_gpu; + x_dims_array_gpu.Resize({max_dim}); + int* x_dims_array_gpu_data = dev_ctx.template Alloc(&x_dims_array_gpu); +#if defined(__NVCC__) + cudaMemcpy(x_dims_array_gpu_data, + x_dims_array, + sizeof(int) * max_dim, + cudaMemcpyHostToDevice); +#else + hipMemcpy(x_dims_array_gpu_data, + x_dims_array, + sizeof(int) * max_dim, + hipMemcpyHostToDevice); +#endif + DenseTensor y_dims_array_gpu; + y_dims_array_gpu.Resize({max_dim}); + int* y_dims_array_gpu_data = dev_ctx.template Alloc(&y_dims_array_gpu); +#if defined(__NVCC__) + cudaMemcpy(y_dims_array_gpu_data, + y_dims_array, + sizeof(int) * max_dim, + cudaMemcpyHostToDevice); +#else + hipMemcpy(y_dims_array_gpu_data, + y_dims_array, + sizeof(int) * max_dim, + hipMemcpyHostToDevice); +#endif + DenseTensor out_dims_array_gpu; + out_dims_array_gpu.Resize({max_dim}); + int* out_dims_array_gpu_data = + dev_ctx.template Alloc(&out_dims_array_gpu); +#if defined(__NVCC__) + cudaMemcpy(out_dims_array_gpu_data, + out_dims_array, + sizeof(int) * max_dim, + cudaMemcpyHostToDevice); +#else + hipMemcpy(out_dims_array_gpu_data, + out_dims_array, + sizeof(int) * max_dim, + hipMemcpyHostToDevice); +#endif + int block = 512; + int64_t grid = (out_numel + block - 1) / block; + auto stream = reinterpret_cast(dev_ctx).stream(); + ComputeDDoutWithBroadcastGPUKernel + <<>>(ddx_data, + ddy_data, + y_data, + out_data, + ddout_data, + out_numel, + x_dims_array_gpu_data, + y_dims_array_gpu_data, + out_dims_array_gpu_data, + max_dim, + dout_op); +} + +#endif + +template +void DivDoubleDDoutCompute(const DeviceContext& dev_ctx, + const phi::DenseTensor& ddx, + const phi::DenseTensor& ddy, + const phi::DenseTensor& y, + const phi::DenseTensor& out, + int axis, + phi::DenseTensor* ddout, + DDout_OP dout_op) { + auto x_dims = ddx.dims(); + auto y_dims = ddy.dims(); + if (x_dims == y_dims) { + ComputeDDoutWithoutBroadcast( + dev_ctx, ddx, ddy, y, out, ddout, dout_op); + } else { + int max_dim = std::max(x_dims.size(), y_dims.size()); + axis = (axis == -1 ? 
std::abs(x_dims.size() - y_dims.size()) : axis); + std::vector x_dims_array(max_dim, 0); + std::vector y_dims_array(max_dim, 0); + std::vector out_dims_array(max_dim, 0); + phi::funcs::GetBroadcastDimsArrays(x_dims, + y_dims, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + axis); + ComputeDDoutWithBroadcast(dev_ctx, + ddx, + ddy, + y, + out, + ddout, + x_dims_array.data(), + y_dims_array.data(), + out_dims_array.data(), + max_dim, + dout_op); + } +} + template void DivideDoubleGradKernel(const Context& dev_ctx, const DenseTensor& y, const DenseTensor& out, - const DenseTensor& dx, + const DenseTensor& grad_out, + const paddle::optional& dx, const paddle::optional& ddx, const paddle::optional& ddy, int axis, DenseTensor* dy, DenseTensor* dout, DenseTensor* ddout) { - if (dy) { - dy->Resize(y.dims()); - dev_ctx.template Alloc(dy); - } - if (dout) { - dout->Resize(out.dims()); - dev_ctx.template Alloc(dout); - } - if (ddout) { - ddout->Resize(out.dims()); - dev_ctx.template Alloc(ddout); + auto* ddx_tensor = ddx.get_ptr(); + auto* ddy_tensor = ddy.get_ptr(); + auto* dx_tensor = dx.get_ptr(); + DenseTensor dz_div_y; + if ((dy || dout) && (!dx_tensor || dx_tensor->dims() != out.dims())) { + dz_div_y.Resize(out.dims()); + dev_ctx.template Alloc(&dz_div_y); + funcs::DefaultElementwiseOperator, + funcs::InverseDivideFunctor>( + dev_ctx, grad_out, y, &dz_div_y, axis); + dx_tensor = &dz_div_y; } - // ddX_safe == null ? 0 : ddX - // ddY_safe == null ? 0 : ddY - DenseTensor ddX_safe, ddY_safe; - phi::funcs::GetDoubleGradSafeTensor( - dev_ctx, dx, ddx.get_ptr(), &ddX_safe); - phi::funcs::GetDoubleGradSafeTensor( - dev_ctx, y, ddy.get_ptr(), &ddY_safe); - // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y // dY = Out * dX * ddY / Y - dX * ddX / Y // dOut = - dX * ddY @@ -200,69 +522,169 @@ void DivideDoubleGradKernel(const Context& dev_ctx, // inplace ddx DenseTensor tmp; if (dout) { + dout->Resize(out.dims()); + dev_ctx.template Alloc(dout); tmp = *dout; } else { tmp.Resize(out.dims()); dev_ctx.template Alloc(&tmp); } if (dy) { - // dX_div_Y = dX / Y; - DenseTensor dX_div_Y = tmp; - funcs::DefaultElementwiseOperator, - funcs::InverseDivideFunctor>( - dev_ctx, dx, y, &dX_div_Y, axis); - - // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the - // first output tensor is nullptr, the branch to calculate first - // output tensor will not be activated, DivGradDx function will not - // be called and can be ignored, the first branch has little effect - // on running speed. 
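An aside on the ComputeDDoutWithBroadcast helpers added above: both the CPU loop and the GPU kernel walk the flattened output and, for each element, reconstruct the flattened index into each (possibly broadcast) input by peeling off one output dimension at a time and skipping dimensions of size 1. The host-only sketch below reproduces that index bookkeeping in plain C++; it is not the Paddle helper, and it assumes the input dims are already padded to the output's rank, trailing-aligned, row-major.

#include <cstdio>
#include <vector>

// Map a flattened output index to a flattened input index, where dims equal
// to 1 in the input are broadcast and therefore contribute no stride.
static int BroadcastIndex(int out_index,
                          const std::vector<int>& in_dims,
                          const std::vector<int>& out_dims) {
  int in_index = 0, in_stride = 1;
  for (int i = static_cast<int>(out_dims.size()) - 1; i >= 0; --i) {
    int dim_index = out_index % out_dims[i];  // coordinate along dimension i
    out_index /= out_dims[i];
    if (in_dims[i] > 1) {
      in_index += dim_index * in_stride;
      in_stride *= in_dims[i];
    }
  }
  return in_index;
}

int main() {
  // ddx has shape [2, 3], ddy has shape [1, 3]; the output broadcasts to [2, 3].
  std::vector<int> x_dims = {2, 3}, y_dims = {1, 3}, out_dims = {2, 3};
  std::vector<float> ddx = {1, 2, 3, 4, 5, 6};
  std::vector<float> ddy = {10, 20, 30};
  for (int i = 0; i < 6; ++i) {
    int xi = BroadcastIndex(i, x_dims, out_dims);
    int yi = BroadcastIndex(i, y_dims, out_dims);
    std::printf("out[%d] reads ddx[%d]=%g and ddy[%d]=%g\n",
                i, xi, ddx[xi], yi, ddy[yi]);
  }
  return 0;
}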
+ dy->Resize(y.dims()); + dev_ctx.template Alloc(dy); + if (!ddx_tensor && !ddy_tensor) { + FullLikeKernel( + dev_ctx, y, Scalar(static_cast(0.0)), y.dtype(), dy); + } else { + // pre-compute 'dX / Y' into 'tmp' for 'ddout' and/or 'dy' + funcs::DefaultElementwiseOperator, + funcs::InverseDivideFunctor>( + dev_ctx, *dx_tensor, y, &tmp, axis); + if (ddx_tensor && !ddy_tensor) { + // dy = -dX * ddX / Y + phi::funcs::ElemwiseGradCompute, + DivDoubleDY_Only_DDX>( + dev_ctx, + *ddx_tensor, // ddx + y, + out, // out + tmp, // dX /Y + axis, + nullptr, + dy, + DivGradDX(), + DivDoubleDY_Only_DDX()); + } else if (!ddx_tensor && ddy_tensor) { + // dY = Out * dX * ddY / Y + phi::funcs::ElemwiseGradCompute, + DivDoubleDY_Only_DDY>( + dev_ctx, + *dx_tensor, + *ddy_tensor, // ddy + out, // out + tmp, // dX / Y + axis, + nullptr, + dy, + DivGradDX(), + DivDoubleDY_Only_DDY()); + } else { + // dY = Out * dX * ddY / Y - dX * ddX / Y - // dY = Out * dX * ddY / Y - dX * ddX / Y - phi::funcs::ElemwiseGradCompute, DivDoubleDY>( - dev_ctx, - ddX_safe, - ddY_safe, - out, - dX_div_Y, - axis, - nullptr, - dy, - DivGradDX(), - DivDoubleDY()); + // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the + // first output tensor is nullptr, the branch to calculate first + // output tensor will not be activated, DivGradDx function will not + // be called and can be ignored, the first branch has little effect + // on running speed. + phi::funcs:: + ElemwiseGradCompute, DivDoubleDY>( + dev_ctx, + *ddx_tensor, // ddx + *ddy_tensor, // ddy + out, // out + tmp, // dX / Y + axis, + nullptr, + dy, + DivGradDX(), + DivDoubleDY()); + } + } } if (ddout) { + ddout->Resize(out.dims()); + dev_ctx.template Alloc(ddout); // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y - funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, out, ddY_safe, &tmp, axis); - funcs::DefaultElementwiseOperator, - funcs::InverseSubtractFunctor>( - dev_ctx, ddX_safe, tmp, &tmp, axis); - funcs::DefaultElementwiseOperator, - funcs::InverseDivideFunctor>( - dev_ctx, tmp, y, ddout, axis); + if (!ddx_tensor && !ddy_tensor) { + FullLikeKernel( + dev_ctx, out, Scalar(static_cast(0.0)), out.dtype(), ddout); + } else if (ddx_tensor != nullptr && ddy_tensor == nullptr) { + // ddOut = ddX / Y + funcs::DefaultElementwiseOperator, + funcs::InverseDivideFunctor>( + dev_ctx, *ddx_tensor, y, ddout, axis); + } else if (!ddx_tensor && ddy_tensor) { +// ddOut = - Out * ddY / Y +#if defined(__xpu__) + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, out, *ddy_tensor, &tmp, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseDivideFunctor>( + dev_ctx, tmp, y, ddout, axis); + auto& place = *dev_ctx.eigen_device(); + auto ddout_result = phi::EigenVector::Flatten(*ddout); + ddout_result.device(place) = static_cast(-1) * ddout_result; +#else + DivDoubleDDoutCompute, T>( + dev_ctx, + *dx_tensor, + *ddy_tensor, + y, + out, + axis, + ddout, + DivDoubleDDOut_Only_DDY()); +#endif + } else { +#if defined(__xpu__) + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, out, *ddy_tensor, &tmp, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseSubtractFunctor>( + dev_ctx, *ddx_tensor, tmp, &tmp, axis); + funcs::DefaultElementwiseOperator, + funcs::InverseDivideFunctor>( + dev_ctx, tmp, y, ddout, axis); +#else + DivDoubleDDoutCompute, T>( + dev_ctx, + *ddx_tensor, + *ddy_tensor, + y, + out, + axis, + ddout, + DivDoubleDDOut()); +#endif + } } if (dout) { - // dOut = - dX * ddY 
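
As a quick cross-check of the branch logic, here is a minimal scalar version of the three closed-form expressions quoted in the comments (ddOut = (ddX - Out * ddY) / Y, dY = Out * dX * ddY / Y - dX * ddX / Y, dOut = -dX * ddY). It is plain C++ with illustrative names and is not part of the kernel:

// Scalar reference for the divide double-grad formulas used above.
#include <cstdio>

struct DivDoubleGradRef {
  double ddout;
  double dy;
  double dout;
};

DivDoubleGradRef DivDoubleGradScalar(
    double y, double out, double dx, double ddx, double ddy) {
  DivDoubleGradRef r;
  r.ddout = (ddx - out * ddy) / y;           // ddOut = (ddX - Out * ddY) / Y
  r.dy = out * dx * ddy / y - dx * ddx / y;  // dY = Out*dX*ddY/Y - dX*ddX/Y
  r.dout = -dx * ddy;                        // dOut = -dX * ddY
  return r;
}

int main() {
  // x = 6, y = 2, so out = 3; dx stands in for the first-order grad of x.
  DivDoubleGradRef r = DivDoubleGradScalar(2.0, 3.0, 0.5, 1.0, 0.25);
  std::printf("ddout=%f dy=%f dout=%f\n", r.ddout, r.dy, r.dout);
  return 0;
}
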
- funcs::DefaultElementwiseOperator, - funcs::InverseMultiplyFunctor>( - dev_ctx, dx, ddY_safe, dout, axis); - auto& place = *dev_ctx.eigen_device(); - auto dout_result = phi::EigenVector::Flatten(*dout); - dout_result.device(place) = static_cast(-1) * dout_result; + if (!ddy_tensor) { + FullLikeKernel( + dev_ctx, out, Scalar(static_cast(0.0)), out.dtype(), dout); + } else { + // dOut = - dX * ddY + funcs::DefaultElementwiseOperator, + funcs::InverseMultiplyFunctor>( + dev_ctx, *dx_tensor, *ddy_tensor, dout, axis); + auto& place = *dev_ctx.eigen_device(); + auto dout_result = phi::EigenVector::Flatten(*dout); + dout_result.device(place) = static_cast(-1) * dout_result; + } } } template diff --git a/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h b/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h index 54ef6e0c1f9cb..2b1d0d60bee50 100644 --- a/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h @@ -116,10 +116,19 @@ void ExpandAsGradKernel(const Context& context, ExpandAsBackward( context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); break; + case 7: + ExpandAsBackward( + context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + case 8: + ExpandAsBackward( + context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; default: PADDLE_THROW(errors::InvalidArgument( - "Only support tensor with rank being between 1 and 6. But " + "Only support tensor with rank being between 1 and %d. But " "received tensor's rank = %d.", + MAX_RANK_SUPPORTED, dims)); } } diff --git a/paddle/phi/kernels/impl/expand_as_kernel_impl.h b/paddle/phi/kernels/impl/expand_as_kernel_impl.h index cee562b42778e..927cd73b3eb4e 100755 --- a/paddle/phi/kernels/impl/expand_as_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_as_kernel_impl.h @@ -20,7 +20,7 @@ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace phi { @@ -158,6 +158,12 @@ void ExpandAsKernel(const Context& ctx, case 6: ExpandAs(ctx, x, real_target_shape, out); break; + case 7: + ExpandAs(ctx, x, real_target_shape, out); + break; + case 8: + ExpandAs(ctx, x, real_target_shape, out); + break; } } diff --git a/paddle/phi/kernels/impl/expand_grad_kernel_impl.h b/paddle/phi/kernels/impl/expand_grad_kernel_impl.h index 4dd9dc4d50337..f24fff253558a 100644 --- a/paddle/phi/kernels/impl/expand_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_grad_kernel_impl.h @@ -128,10 +128,19 @@ void ExpandGradKernel(const Context& ctx, ExpandBackward( ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); break; + case 7: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + case 8: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; default: PADDLE_THROW(phi::errors::InvalidArgument( - "Only support tensor with rank being between 1 and 6. But " + "Only support tensor with rank being between 1 and %d. 
But " "received tensor's rank = %d.", + MAX_RANK_SUPPORTED, dims)); } } diff --git a/paddle/phi/kernels/impl/expand_kernel_impl.h b/paddle/phi/kernels/impl/expand_kernel_impl.h index 181dd2558fa38..7d675e036a55e 100644 --- a/paddle/phi/kernels/impl/expand_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_kernel_impl.h @@ -19,7 +19,7 @@ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace phi { using Tensor = DenseTensor; @@ -169,6 +169,12 @@ void ExpandKernel(const Context& ctx, case 6: Expand(ctx, x, shape, out); break; + case 7: + Expand(ctx, x, shape, out); + break; + case 8: + Expand(ctx, x, shape, out); + break; } } diff --git a/paddle/phi/kernels/impl/transpose_grad_kernel_impl.h b/paddle/phi/kernels/impl/transpose_grad_kernel_impl.h index f296ad995cf7f..72ed43f09e152 100644 --- a/paddle/phi/kernels/impl/transpose_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/transpose_grad_kernel_impl.h @@ -26,17 +26,17 @@ void TransposeGradKernel(const Context& dev_ctx, const std::vector& axis, DenseTensor* x_grad) { size_t axis_size = axis.size(); - std::vector formated_axis = axis; + std::vector formatted_axis = axis; for (size_t i = 0; i < axis_size; i++) { if (axis[i] < 0) { - formated_axis[i] = axis[i] + axis_size; + formatted_axis[i] = axis[i] + axis_size; } } std::vector reversed_axis(axis); dev_ctx.template Alloc(x_grad); for (size_t i = 0; i < axis_size; i++) { - reversed_axis[formated_axis[i]] = i; + reversed_axis[formatted_axis[i]] = i; } TransposeKernel(dev_ctx, out_grad, reversed_axis, x_grad); diff --git a/paddle/phi/kernels/isfinite_kernel.h b/paddle/phi/kernels/isfinite_kernel.h index e695a8e074223..291bec9b78436 100644 --- a/paddle/phi/kernels/isfinite_kernel.h +++ b/paddle/phi/kernels/isfinite_kernel.h @@ -20,7 +20,7 @@ namespace phi { #define DEFINE_ISFINITE_KERNEL(isfinite_kernel) \ template \ - void isfinite_kernel( \ + TEST_API void isfinite_kernel( \ const Context& ctx, const DenseTensor& x, DenseTensor* out); DEFINE_ISFINITE_KERNEL(IsinfKernel) diff --git a/paddle/phi/kernels/kps/reduce_kernel.cu b/paddle/phi/kernels/kps/reduce_kernel.cu index 74020a8f0975b..14b7c5809a14c 100644 --- a/paddle/phi/kernels/kps/reduce_kernel.cu +++ b/paddle/phi/kernels/kps/reduce_kernel.cu @@ -248,8 +248,25 @@ void SumRawKernel(const Context& dev_ctx, "now.")); #endif } else { - phi::Reduce( - dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); + if (x.dtype() == phi::DataType::BFLOAT16 && + out_dtype == phi::DataType::FLOAT32) { + std::vector reduce_dims = phi::funcs::details::GetReduceDim( + dims.GetData(), x.dims().size(), reduce_all); + + phi::funcs::ReduceKernel< + phi::dtype::bfloat16, + float, + kps::AddFunctor, + kps::IdentityFunctor>( + dev_ctx, + x, + out, + kps::IdentityFunctor(), + reduce_dims); + } else { + phi::Reduce( + dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out); + } } } } // namespace phi diff --git a/paddle/phi/kernels/legacy/cpu/elementwise_kernel.cc b/paddle/phi/kernels/legacy/cpu/elementwise_kernel.cc index dafbf2889277d..84ebbf04fee11 100644 --- a/paddle/phi/kernels/legacy/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/legacy/cpu/elementwise_kernel.cc @@ -55,7 +55,7 @@ void RemainderRawKernel(const Context& dev_ctx, dev_ctx.template Alloc(out); auto x_dims = x.dims(); auto y_dims = y.dims(); - if (x_dims.size() >= y_dims.size()) { + if (x_dims.size() >= y_dims.size()) { // NOLINT 
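
The expand changes above raise MAX_RANK_SUPPORTED from 6 to 8 by adding case 7 and case 8 branches to the rank switch. A minimal sketch of that dispatch pattern, where a runtime rank selects a compile-time-rank template instantiation (illustrative names, not the Paddle implementation):

#include <cstdio>

constexpr int kMaxRankSupported = 8;  // mirrors MAX_RANK_SUPPORTED above

template <int Rank>
void ExpandWithRank() {
  std::printf("expanding with compile-time rank %d\n", Rank);
}

void DispatchExpand(int rank) {
  switch (rank) {
    case 1: ExpandWithRank<1>(); break;
    case 2: ExpandWithRank<2>(); break;
    case 3: ExpandWithRank<3>(); break;
    case 4: ExpandWithRank<4>(); break;
    case 5: ExpandWithRank<5>(); break;
    case 6: ExpandWithRank<6>(); break;
    case 7: ExpandWithRank<7>(); break;
    case 8: ExpandWithRank<8>(); break;
    default:
      std::printf("Only support tensor with rank between 1 and %d, got %d\n",
                  kMaxRankSupported, rank);
  }
}

int main() {
  DispatchExpand(7);
  return 0;
}
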
funcs::ElementwiseCompute, T>( dev_ctx, x, y, funcs::RemainderFunctor(), out, axis); } else { @@ -74,7 +74,7 @@ void FloorDivideRawKernel(const Context& dev_ctx, dev_ctx.template Alloc(out); auto x_dims = x.dims(); auto y_dims = y.dims(); - if (x_dims.size() >= y_dims.size()) { + if (x_dims.size() >= y_dims.size()) { // NOLINT funcs::ElementwiseCompute, T>( dev_ctx, x, y, funcs::FloorDivideFunctor(), out, axis); } else { diff --git a/paddle/phi/kernels/logical_kernel.h b/paddle/phi/kernels/logical_kernel.h index 3ccc03a5b598a..69214ef1d4532 100644 --- a/paddle/phi/kernels/logical_kernel.h +++ b/paddle/phi/kernels/logical_kernel.h @@ -18,17 +18,17 @@ limitations under the License. */ namespace phi { -#define DECLEAR_LOGICAL_BINARY_KERNEL(type) \ +#define DECLARE_LOGICAL_BINARY_KERNEL(type) \ template \ void Logical##type##Kernel(const Context& dev_ctx, \ const DenseTensor& x, \ const DenseTensor& y, \ DenseTensor* out); -DECLEAR_LOGICAL_BINARY_KERNEL(And) -DECLEAR_LOGICAL_BINARY_KERNEL(Or) -DECLEAR_LOGICAL_BINARY_KERNEL(Xor) -#undef DECLEAR_LOGICAL_BINARY_KERNEL +DECLARE_LOGICAL_BINARY_KERNEL(And) +DECLARE_LOGICAL_BINARY_KERNEL(Or) +DECLARE_LOGICAL_BINARY_KERNEL(Xor) +#undef DECLARE_LOGICAL_BINARY_KERNEL template void LogicalNotKernel(const Context& dev_ctx, diff --git a/paddle/phi/kernels/nanmedian_grad_kernel.h b/paddle/phi/kernels/nanmedian_grad_kernel.h index e8fb01b7060a7..f76823cbfa3b1 100644 --- a/paddle/phi/kernels/nanmedian_grad_kernel.h +++ b/paddle/phi/kernels/nanmedian_grad_kernel.h @@ -26,5 +26,6 @@ void NanmedianGradKernel(const Context& dev_ctx, const DenseTensor& out_grad, const IntArray& axes, bool keep_dim, + const std::string& mode, DenseTensor* x_grad); } // namespace phi diff --git a/paddle/phi/kernels/nanmedian_kernel.h b/paddle/phi/kernels/nanmedian_kernel.h index 4bb382a443144..95fecafde12cf 100644 --- a/paddle/phi/kernels/nanmedian_kernel.h +++ b/paddle/phi/kernels/nanmedian_kernel.h @@ -24,6 +24,7 @@ void NanmedianKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, bool keep_dim, + const std::string& mode, DenseTensor* out, DenseTensor* medians); } // namespace phi diff --git a/paddle/phi/kernels/onednn/add_n_kernel.cc b/paddle/phi/kernels/onednn/add_n_kernel.cc index f852254043e87..454d6851cfeac 100644 --- a/paddle/phi/kernels/onednn/add_n_kernel.cc +++ b/paddle/phi/kernels/onednn/add_n_kernel.cc @@ -17,6 +17,19 @@ #include "paddle/phi/core/kernel_registry.h" namespace phi { +bool AddNCheckIfOneDNNSupport(const KernelContext* ctx) { + for (size_t i = 0; i < ctx->InputsSize(); i++) { + if (!DenseTensor::classof(ctx->MutableIutputAt(i))) { + return false; + } + } + KernelContext* ctx_tmp = const_cast(ctx); + if (!DenseTensor::classof(ctx_tmp->MutableOutputAt(0))) { + return false; + } + return true; +} + namespace funcs { template class SumOneDNNHandler : public OneDNNHandlerNoCachingT { @@ -122,4 +135,6 @@ void AddNKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - add_n, OneDNN, ONEDNN, phi::AddNKernel, float, phi::dtype::bfloat16) {} + add_n, OneDNN, ONEDNN, phi::AddNKernel, float, phi::dtype::bfloat16) { + kernel->check_if_onednn_kernel_support_ = phi::AddNCheckIfOneDNNSupport; +} diff --git a/paddle/phi/kernels/onednn/concat_grad_kernel.cc b/paddle/phi/kernels/onednn/concat_grad_kernel.cc index fc36fa4ab0fd8..9563f73f0ba92 100644 --- a/paddle/phi/kernels/onednn/concat_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/concat_grad_kernel.cc @@ -40,7 +40,7 @@ void ConcatGradKernel(const Context& dev_ctx, auto 
out_grad_vec_dims = common::vectorize(out_grad.dims()); - axis = funcs::ComputeAxis(axis, out_grad_vec_dims.size()); + axis = static_cast(funcs::ComputeAxis(axis, out_grad_vec_dims.size())); std::vector offset(out_grad_vec_dims.size(), 0); @@ -60,7 +60,7 @@ void ConcatGradKernel(const Context& dev_ctx, auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( grad, x_grad_vec_dims, - funcs::GetPlainOneDNNFormat(x_grad_vec_dims.size()), + funcs::GetPlainOneDNNFormat(static_cast(x_grad_vec_dims.size())), dev_ctx.GetPlace()); auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p); diff --git a/paddle/phi/kernels/onednn/conv_transpose_kernel.cc b/paddle/phi/kernels/onednn/conv_transpose_kernel.cc index 208b0f3f6e9be..f79f2f8619c9b 100644 --- a/paddle/phi/kernels/onednn/conv_transpose_kernel.cc +++ b/paddle/phi/kernels/onednn/conv_transpose_kernel.cc @@ -356,15 +356,13 @@ template void Execute(const OneDNNContext& dev_ctx, const DenseTensor* x, const DenseTensor* filter, + const DenseTensor* bias, const std::vector& strides, const std::vector& paddings, const std::string& padding_algorithm, int groups, const std::vector& dilations, DenseTensor* out) { - const auto* bias = - dev_ctx.HasDnnInput("Bias") ? dev_ctx.GetDnnInput("Bias") : nullptr; - std::shared_ptr conv_p; std::shared_ptr src_memory_p; std::shared_ptr weights_memory_p; @@ -407,6 +405,23 @@ void Execute(const OneDNNContext& dev_ctx, args.insert({DNNL_ARG_BIAS, *bias_memory_p}); } } else { + // Check if bias obey the rules + if (bias) { + PADDLE_ENFORCE_EQ( + bias->layout(), + DataLayout::ONEDNN, + phi::errors::InvalidArgument( + "The Bias tensor's layout should be %d, but got %d.", + DataLayout::ONEDNN, + bias->layout())); + + PADDLE_ENFORCE_EQ( + bias->dims().size(), + 1, + phi::errors::InvalidArgument("Bias must only have 1 dimension, " + "i.e. X, but got dimension = %d .", + bias->dims().size())); + } // Caching Key for weights is needed std::string key = funcs::CreateKey(dev_ctx, @@ -494,6 +509,63 @@ void Conv2dTransposeKernel(const Context& dev_ctx, Execute(dev_ctx, &x, &filter, + nullptr, + strides, + paddings, + padding_algorithm, + groups, + dilations, + out); + } else { + Execute(dev_ctx, + &x, + &filter, + nullptr, + strides, + paddings, + padding_algorithm, + groups, + dilations, + out); + } +} + +template +void Conv2dTransposeBiasKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& filter, + const paddle::optional& bias, + const std::vector& strides, + const std::vector& paddings, + const std::vector& output_padding UNUSED, + const IntArray& output_size UNUSED, + const std::string& padding_algorithm, + int groups, + const std::vector& dilations, + const std::string& data_format UNUSED, + DenseTensor* out) { + PADDLE_ENFORCE_EQ(dev_ctx.GetPlace().GetType(), + AllocationType::CPU, + phi::errors::PreconditionNotMet( + "Operator oneDNN Conv must use CPUPlace")); + + const bool is_BFLOAT16 = + dev_ctx.HasDnnAttr("mkldnn_data_type") + ? PADDLE_GET_CONST(std::string, + dev_ctx.GetDnnAttr("mkldnn_data_type")) == + "bfloat16" + : false; + const bool force_fp32_output = + dev_ctx.HasDnnAttr("force_fp32_output") + ? 
PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output")) + : false; + const bool use_bfloat16 = (!force_fp32_output && is_BFLOAT16); + + if (use_bfloat16) { + Execute(dev_ctx, + &x, + &filter, + bias.get_ptr(), strides, paddings, padding_algorithm, @@ -504,6 +576,7 @@ void Conv2dTransposeKernel(const Context& dev_ctx, Execute(dev_ctx, &x, &filter, + bias.get_ptr(), strides, paddings, padding_algorithm, @@ -547,3 +620,12 @@ PD_REGISTER_KERNEL(conv2d_transpose, phi::dtype::bfloat16) { kernel->get_kerneltype_forvar_fn_ = phi::ConvTransposeGetKernelTypeForVar; } + +PD_REGISTER_KERNEL(conv2d_transpose_bias, + OneDNN, + ONEDNN, + phi::Conv2dTransposeBiasKernel, + float, + phi::dtype::bfloat16) { + kernel->get_kerneltype_forvar_fn_ = phi::ConvTransposeGetKernelTypeForVar; +} diff --git a/paddle/phi/kernels/onednn/expand_grad_kernel.cc b/paddle/phi/kernels/onednn/expand_grad_kernel.cc index a8b1beb45832f..7de901df9561d 100644 --- a/paddle/phi/kernels/onednn/expand_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/expand_grad_kernel.cc @@ -50,7 +50,7 @@ void ExpandGradKernel(const Context& dev_ctx, auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( in_grad, - funcs::GetPlainOneDNNFormat(in_grad_vec_dims.size()), + funcs::GetPlainOneDNNFormat(static_cast(in_grad_vec_dims.size())), dev_ctx.GetPlace()); auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, diff --git a/paddle/phi/kernels/onednn/matmul_grad_kernel.cc b/paddle/phi/kernels/onednn/matmul_grad_kernel.cc index 3866a2d06ae45..46a2a7450d41c 100644 --- a/paddle/phi/kernels/onednn/matmul_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/matmul_grad_kernel.cc @@ -51,8 +51,10 @@ void CalculateMatrixDims(const std::vector &x_dims, for (size_t i = 0; i < x_bd_dims->size() - 2; ++i) { (*out_bd_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]); } - int h_idx = trans_x ? x_bd_dims->size() - 1 : x_bd_dims->size() - 2; - int w_idx = trans_y ? y_bd_dims->size() - 2 : y_bd_dims->size() - 1; + int h_idx = + trans_x ? x_bd_dims->size() - 1 : x_bd_dims->size() - 2; // NOLINT + int w_idx = + trans_y ? y_bd_dims->size() - 2 : y_bd_dims->size() - 1; // NOLINT (*out_bd_dims)[x_bd_dims->size() - 2] = (*x_bd_dims)[h_idx]; (*out_bd_dims)[y_bd_dims->size() - 1] = (*y_bd_dims)[w_idx]; diff --git a/paddle/phi/kernels/onednn/matmul_kernel.cc b/paddle/phi/kernels/onednn/matmul_kernel.cc index b7b31ff479b30..342fce6f2be02 100644 --- a/paddle/phi/kernels/onednn/matmul_kernel.cc +++ b/paddle/phi/kernels/onednn/matmul_kernel.cc @@ -124,7 +124,7 @@ void MatmulKernel(const Context &dev_ctx, auto x_dims = common::vectorize(x.dims()); auto y_dims = common::vectorize(y.dims()); - int ndims = std::max(x_dims.size(), y_dims.size()); + int ndims = std::max(x_dims.size(), y_dims.size()); // NOLINT ndims = std::max(ndims, 3); std::vector x_bd_dims(ndims, 1); @@ -266,7 +266,7 @@ class MulPrimitiveFactory { auto scale_out_data = force_fp32_output ? 1.0f : scale_out; bool is_multi_channel = scale_y_data.size() > 1; - int count = is_multi_channel ? scale_y_data.size() : 1; + int count = is_multi_channel ? 
scale_y_data.size() : 1; // NOLINT std::vector output_shift_scale(count); for (int i = 0; i < count; i++) { if (scale_y_data[i] == 0.0) diff --git a/paddle/phi/kernels/onednn/scale_kernel.cc b/paddle/phi/kernels/onednn/scale_kernel.cc index 68bee7a39c8a5..4d65358f96749 100644 --- a/paddle/phi/kernels/onednn/scale_kernel.cc +++ b/paddle/phi/kernels/onednn/scale_kernel.cc @@ -23,11 +23,11 @@ template void ScaleKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale, DenseTensor* out) { float alpha = scale.to(); - float beta = bias_after_scale ? bias : bias * alpha; + float beta = bias_after_scale ? bias.to() : bias.to() * alpha; funcs::ActivationOneDNNHandler handler(dnnl::algorithm::eltwise_linear, alpha, diff --git a/paddle/phi/kernels/onednn/sgd_kernel.cc b/paddle/phi/kernels/onednn/sgd_kernel.cc index 6ceba6b2cf7b7..007af969e2787 100644 --- a/paddle/phi/kernels/onednn/sgd_kernel.cc +++ b/paddle/phi/kernels/onednn/sgd_kernel.cc @@ -20,6 +20,22 @@ namespace phi { +bool SgdCheckIfOneDNNSupport(const KernelContext* ctx) { + if (DenseTensor::classof(ctx->MutableIutputAt(0)) && + DenseTensor::classof(ctx->MutableIutputAt(2))) { + return true; + } + return false; +} + +bool SgdSparseCheckIfOneDNNSupport(const KernelContext* ctx) { + if (DenseTensor::classof(ctx->MutableIutputAt(0)) && + SelectedRows::classof(ctx->MutableIutputAt(2))) { + return true; + } + return false; +} + template void SGDDenseKernel(const Context& dev_ctx, const DenseTensor& param, @@ -82,11 +98,15 @@ void SGDDenseParamSparseGradKernel( } // namespace phi PD_REGISTER_KERNEL( - sgd, OneDNN, ONEDNN, phi::SGDDenseKernel, float, phi::dtype::bfloat16) {} + sgd, OneDNN, ONEDNN, phi::SGDDenseKernel, float, phi::dtype::bfloat16) { + kernel->check_if_onednn_kernel_support_ = phi::SgdCheckIfOneDNNSupport; +} PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad, OneDNN, ONEDNN, phi::SGDDenseParamSparseGradKernel, float, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16) { + kernel->check_if_onednn_kernel_support_ = phi::SgdSparseCheckIfOneDNNSupport; +} diff --git a/paddle/phi/kernels/onednn/slice_grad_kernel.cc b/paddle/phi/kernels/onednn/slice_grad_kernel.cc index 7f8f6b815b4f0..e2d4aa59c9d46 100644 --- a/paddle/phi/kernels/onednn/slice_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/slice_grad_kernel.cc @@ -19,6 +19,13 @@ namespace phi { +bool SliceGradCheckIfOneDNNSupport(const KernelContext* ctx) { + if (ctx->InputAt(1).mem_desc().get_inner_nblks() == 0) { + return true; + } + return false; +} + template void SliceGradKernel(const Context& dev_ctx, const DenseTensor& input UNUSED, @@ -60,7 +67,7 @@ void SliceGradKernel(const Context& dev_ctx, auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( input_grad, dx_dims, - funcs::GetPlainOneDNNFormat(dx_dims.size()), + funcs::GetPlainOneDNNFormat(static_cast(dx_dims.size())), dev_ctx.GetPlace()); memset(input_grad->data(), 0, reorder_dst_memory_p->get_desc().get_size()); @@ -83,4 +90,6 @@ PD_REGISTER_KERNEL(slice_grad, ONEDNN, phi::SliceGradKernel, float, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16) { + kernel->check_if_onednn_kernel_support_ = phi::SliceGradCheckIfOneDNNSupport; +} diff --git a/paddle/phi/kernels/onednn/slice_kernel.cc b/paddle/phi/kernels/onednn/slice_kernel.cc index bd59d61c17e79..41116033d7237 100644 --- a/paddle/phi/kernels/onednn/slice_kernel.cc +++ b/paddle/phi/kernels/onednn/slice_kernel.cc @@ -19,6 +19,18 @@ namespace phi { +bool SliceCheckIfOneDNNSupport(const KernelContext* 
ctx) { + auto x = ctx->InputAt(0); + auto vec_dims = common::vectorize(x.dims()); + bool all_zero_dims = std::all_of( + vec_dims.cbegin(), vec_dims.cend(), [](int64_t i) { return i == 0; }); + + if (!all_zero_dims && x.mem_desc().get_inner_nblks() == 0) { + return true; + } + return false; +} + template void SliceKernel(const Context& dev_ctx, const DenseTensor& x, @@ -69,7 +81,7 @@ void SliceKernel(const Context& dev_ctx, auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( out, slice_dims, - funcs::GetPlainOneDNNFormat(x_vec_dims.size()), + funcs::GetPlainOneDNNFormat(static_cast(x_vec_dims.size())), dev_ctx.GetPlace()); auto reorder_p = @@ -106,4 +118,6 @@ PD_REGISTER_KERNEL(slice, float, int8_t, uint8_t, - phi::dtype::bfloat16) {} + phi::dtype::bfloat16) { + kernel->check_if_onednn_kernel_support_ = phi::SliceCheckIfOneDNNSupport; +} diff --git a/paddle/phi/kernels/onednn/split_kernel.cc b/paddle/phi/kernels/onednn/split_kernel.cc index cf0cd1d62a020..713324774ab20 100644 --- a/paddle/phi/kernels/onednn/split_kernel.cc +++ b/paddle/phi/kernels/onednn/split_kernel.cc @@ -19,6 +19,13 @@ namespace phi { +bool SplitCheckIfOneDNNSupport(const KernelContext* ctx) { + if (ctx->InputAt(0).mem_desc().get_inner_nblks() == 0) { + return true; + } + return false; +} + const std::vector get_slice_strides( const std::vector& out_vec_dims, const dnnl::memory::desc& full_md, @@ -104,7 +111,9 @@ PD_REGISTER_KERNEL(split, float, phi::dtype::bfloat16, int8_t, - uint8_t) {} + uint8_t) { + kernel->check_if_onednn_kernel_support_ = phi::SplitCheckIfOneDNNSupport; +} PD_REGISTER_KERNEL(split_with_num, OneDNN, @@ -113,4 +122,6 @@ PD_REGISTER_KERNEL(split_with_num, float, phi::dtype::bfloat16, int8_t, - uint8_t) {} + uint8_t) { + kernel->check_if_onednn_kernel_support_ = phi::SplitCheckIfOneDNNSupport; +} diff --git a/paddle/phi/kernels/onednn/squeeze_grad_kernel.cc b/paddle/phi/kernels/onednn/squeeze_grad_kernel.cc index d8ff4e72c1b11..78a3c4dce6bd3 100644 --- a/paddle/phi/kernels/onednn/squeeze_grad_kernel.cc +++ b/paddle/phi/kernels/onednn/squeeze_grad_kernel.cc @@ -37,7 +37,7 @@ void SqueezeGradKernel(const Context& dev_ctx, dout.mem_desc(), funcs::to_void_cast(dout.data())); auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( dx, - funcs::GetPlainOneDNNFormat(dout_vec_dims.size()), + funcs::GetPlainOneDNNFormat(static_cast(dout_vec_dims.size())), dev_ctx.GetPlace()); auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); diff --git a/paddle/phi/kernels/onednn/squeeze_kernel.cc b/paddle/phi/kernels/onednn/squeeze_kernel.cc index 2d9522277d857..a3c1beb710740 100644 --- a/paddle/phi/kernels/onednn/squeeze_kernel.cc +++ b/paddle/phi/kernels/onednn/squeeze_kernel.cc @@ -62,7 +62,7 @@ void SqueezeInferKernel(const Context& dev_ctx, auto x_dims_tz = x_dims.size(); std::vector tmp(axes.GetData().begin(), axes.GetData().end()); - // Currently there is only tranformation for tensors, while attr axes still + // Currently there is only transformation for tensors, while attr axes still // follows default dtype instead of oneDNN dtype, so here manually change it if ((x_dims_tz >= 3) && (phi::OneDNNContext::tls().get_cur_paddle_data_layout() == diff --git a/paddle/phi/kernels/onednn/transpose_kernel.cc b/paddle/phi/kernels/onednn/transpose_kernel.cc index ef1f3b0d87fdb..c0faaf5e6c7ba 100644 --- a/paddle/phi/kernels/onednn/transpose_kernel.cc +++ b/paddle/phi/kernels/onednn/transpose_kernel.cc @@ -33,11 +33,11 @@ void TransposeKernel(const Context& dev_ctx, 
(phi::OneDNNContext::tls().get_cur_paddle_data_layout() == phi::DataLayout::kNHWC)) { int axis_size = static_cast(axis.size()); - std::vector formated_axis = axis; + std::vector formatted_axis = axis; std::vector count(axis_size, 0); for (int i = 0; i < axis_size; i++) { if (axis[i] < 0) { - formated_axis[i] = axis[i] + axis_size; + formatted_axis[i] = axis[i] + axis_size; } } auto dims = common::vectorize(x_dims); @@ -49,7 +49,7 @@ void TransposeKernel(const Context& dev_ctx, phi::DDim out_dims(x_dims); for (size_t i = 0; i < axis.size(); i++) { - out_dims[i] = x_dims[formated_axis[i]]; // NOLINT + out_dims[i] = x_dims[formatted_axis[i]]; // NOLINT } out->Resize(out_dims); } diff --git a/paddle/phi/kernels/prior_box_kernel.h b/paddle/phi/kernels/prior_box_kernel.h index 45a741c7a3a72..132efb7b6cc72 100644 --- a/paddle/phi/kernels/prior_box_kernel.h +++ b/paddle/phi/kernels/prior_box_kernel.h @@ -35,25 +35,25 @@ void PriorBoxKernel(const Context& ctx, DenseTensor* out, DenseTensor* var); -inline void ExpandAspectRatios(const std::vector& input_aspect_ratior, +inline void ExpandAspectRatios(const std::vector& input_aspect_ratio, bool flip, - std::vector* output_aspect_ratior) { + std::vector* output_aspect_ratio) { constexpr float epsilon = 1e-6; - output_aspect_ratior->clear(); - output_aspect_ratior->push_back(1.0f); - for (size_t i = 0; i < input_aspect_ratior.size(); ++i) { - float ar = input_aspect_ratior[i]; + output_aspect_ratio->clear(); + output_aspect_ratio->push_back(1.0f); + for (size_t i = 0; i < input_aspect_ratio.size(); ++i) { + float ar = input_aspect_ratio[i]; bool already_exist = false; - for (size_t j = 0; j < output_aspect_ratior->size(); ++j) { - if (fabs(ar - output_aspect_ratior->at(j)) < epsilon) { + for (size_t j = 0; j < output_aspect_ratio->size(); ++j) { + if (fabs(ar - output_aspect_ratio->at(j)) < epsilon) { already_exist = true; break; } } if (!already_exist) { - output_aspect_ratior->push_back(ar); + output_aspect_ratio->push_back(ar); if (flip) { - output_aspect_ratior->push_back(1.0f / ar); + output_aspect_ratio->push_back(1.0f / ar); } } } diff --git a/paddle/phi/kernels/reduce_all_kernel.h b/paddle/phi/kernels/reduce_all_kernel.h index af34a0a5d4c6f..3610ec245ac98 100644 --- a/paddle/phi/kernels/reduce_all_kernel.h +++ b/paddle/phi/kernels/reduce_all_kernel.h @@ -27,10 +27,10 @@ void AllRawKernel(const Context& dev_ctx, DenseTensor* out); template -void AllKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - DenseTensor* out); +TEST_API void AllKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/reduce_any_kernel.h b/paddle/phi/kernels/reduce_any_kernel.h index 9514d02dbdf94..d6a9392e4996b 100644 --- a/paddle/phi/kernels/reduce_any_kernel.h +++ b/paddle/phi/kernels/reduce_any_kernel.h @@ -26,10 +26,10 @@ void AnyRawKernel(const Context& dev_ctx, DenseTensor* out); template -void AnyKernel(const Context& dev_ctx, - const DenseTensor& x, - const std::vector& dims, - bool keep_dim, - DenseTensor* out); +TEST_API void AnyKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& dims, + bool keep_dim, + DenseTensor* out); } // namespace phi diff --git a/paddle/phi/kernels/scale_kernel.h b/paddle/phi/kernels/scale_kernel.h index 7537dc1130b83..5cf95ff207085 100644 --- a/paddle/phi/kernels/scale_kernel.h +++ b/paddle/phi/kernels/scale_kernel.h @@ -24,7 +24,7 @@ template void 
ScaleKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale, DenseTensor* out); @@ -32,7 +32,7 @@ template DenseTensor Scale(const Context& dev_ctx, const DenseTensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale) { DenseTensor dense_out; MetaTensor meta_out(&dense_out); diff --git a/paddle/phi/kernels/selected_rows/scale_kernel.cc b/paddle/phi/kernels/selected_rows/scale_kernel.cc index 38a0cb75101b7..6eded1219b283 100644 --- a/paddle/phi/kernels/selected_rows/scale_kernel.cc +++ b/paddle/phi/kernels/selected_rows/scale_kernel.cc @@ -26,7 +26,7 @@ template void ScaleKernel(const Context& dev_ctx, const SelectedRows& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale, SelectedRows* out) { if (x.value().Holder() != out->value().Holder() || diff --git a/paddle/phi/kernels/selected_rows/scale_kernel.h b/paddle/phi/kernels/selected_rows/scale_kernel.h index 85c2c4ddff033..611d61e1aa56d 100644 --- a/paddle/phi/kernels/selected_rows/scale_kernel.h +++ b/paddle/phi/kernels/selected_rows/scale_kernel.h @@ -24,7 +24,7 @@ template void ScaleKernel(const Context& dev_ctx, const SelectedRows& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale, SelectedRows* out); diff --git a/paddle/phi/kernels/shape_kernel.cc b/paddle/phi/kernels/shape_kernel.cc index e4610f51b9247..939515edd725e 100644 --- a/paddle/phi/kernels/shape_kernel.cc +++ b/paddle/phi/kernels/shape_kernel.cc @@ -105,7 +105,8 @@ PD_REGISTER_KERNEL(shape, double, phi::dtype::complex, phi::dtype::complex, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->OutputAt(0).SetBackend(phi::Backend::CPU); kernel->OutputAt(0).SetDataType(phi::DataType::INT32); diff --git a/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu b/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu index 472777d7f3515..7ae8814470f41 100644 --- a/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/addmm_kernel.cu @@ -132,7 +132,8 @@ PD_REGISTER_KERNEL(addmm_coo_dense, ALL_LAYOUT, phi::sparse::AddmmCooDenseKernel, float, - double) { + double, + phi::dtype::float16) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_COO); } @@ -141,6 +142,7 @@ PD_REGISTER_KERNEL(addmm_csr_dense, ALL_LAYOUT, phi::sparse::AddmmCsrDenseKernel, float, - double) { + double, + phi::dtype::float16) { kernel->InputAt(0).SetDataLayout(phi::DataLayout::SPARSE_CSR); } diff --git a/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py index bc17ae6eb2c13..b8f3254292bb4 100644 --- a/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py +++ b/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_operation.py @@ -305,7 +305,4 @@ def __init__( } def layout_name(self): - return "{}{}".format( - self.ShortLayoutTypeNames[self.A.layout], - self.ShortLayoutTypeNames[self.B.layout], - ) + return f"{self.ShortLayoutTypeNames[self.A.layout]}{self.ShortLayoutTypeNames[self.B.layout]}" diff --git a/paddle/phi/kernels/stride/as_complex_kernel.cc b/paddle/phi/kernels/stride/as_complex_kernel.cc index 173371283e683..e6d589d8c3a8b 100644 --- a/paddle/phi/kernels/stride/as_complex_kernel.cc +++ b/paddle/phi/kernels/stride/as_complex_kernel.cc @@ -66,3 +66,10 @@ PD_REGISTER_KERNEL( 
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); } #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL( + as_complex, Custom, STRIDED, phi::AsComplexStridedKernel, float, double) { + kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} +#endif diff --git a/paddle/phi/kernels/stride/as_real_kernel.cc b/paddle/phi/kernels/stride/as_real_kernel.cc index bde22763e91c6..403d2991644a7 100644 --- a/paddle/phi/kernels/stride/as_real_kernel.cc +++ b/paddle/phi/kernels/stride/as_real_kernel.cc @@ -62,3 +62,14 @@ PD_REGISTER_KERNEL(as_real, kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); } #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(as_real, + Custom, + STRIDED, + phi::AsRealStridedKernel, + phi::dtype::complex, + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED); +} +#endif diff --git a/paddle/phi/kernels/stride/as_strided_grad_kernel.cc b/paddle/phi/kernels/stride/as_strided_grad_kernel.cc index edf72e5da026c..08f9dd3d0390a 100644 --- a/paddle/phi/kernels/stride/as_strided_grad_kernel.cc +++ b/paddle/phi/kernels/stride/as_strided_grad_kernel.cc @@ -16,8 +16,7 @@ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/as_strided_kernel.h" -#include "paddle/phi/kernels/fill_kernel.h" -#include "paddle/phi/kernels/strided_copy_kernel.h" +#include "paddle/phi/kernels/funcs/strided_utils.h" namespace phi { @@ -32,15 +31,14 @@ void AsStridedGradKernel(const Context& dev_ctx, dev_ctx.Alloc(input_grad, input_grad->dtype()); input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims())); PD_VISIT_ALL_TYPES(input_grad->dtype(), "AsStridedGradKernel", ([&] { - phi::FillKernel( - dev_ctx, *input_grad, 0, input_grad); + phi::StridedTensorFill( + *input_grad, 0, input_grad); })); DenseTensor tmp; tmp.set_meta(out_grad.meta()); AsStridedKernel(dev_ctx, *input_grad, dims, stride, offset, &tmp); PD_VISIT_ALL_TYPES(out_grad.dtype(), "AsStridedGradKernel", ([&] { - phi::StridedCopyKernel( - dev_ctx, + phi::StridedTensorCopy( out_grad, common::vectorize(tmp.dims()), common::vectorize(tmp.strides()), @@ -48,7 +46,8 @@ void AsStridedGradKernel(const Context& dev_ctx, &tmp); })); } - } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - as_strided_grad, STRIDED, phi::AsStridedGradKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(as_strided_grad, + STRIDED, + phi::AsStridedGradKernel) {} diff --git a/paddle/phi/kernels/stride/as_strided_kernel.cc b/paddle/phi/kernels/stride/as_strided_kernel.cc index 28ea8f4e63842..c1ce1c1167344 100644 --- a/paddle/phi/kernels/stride/as_strided_kernel.cc +++ b/paddle/phi/kernels/stride/as_strided_kernel.cc @@ -34,6 +34,7 @@ void AsStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(as_strided, - STRIDED, - phi::AsStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(as_strided, + STRIDED, + phi::AsStridedKernel) {} diff --git a/paddle/phi/kernels/stride/complex_grad_kernel.cc b/paddle/phi/kernels/stride/complex_grad_kernel.cc index 800e484ea7eb8..528b4aef1a797 100644 --- a/paddle/phi/kernels/stride/complex_grad_kernel.cc +++ b/paddle/phi/kernels/stride/complex_grad_kernel.cc @@ -16,8 +16,7 @@ #include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/complex_kernel.h" -#include "paddle/phi/kernels/fill_kernel.h" -#include 
"paddle/phi/kernels/strided_copy_kernel.h" +#include "paddle/phi/kernels/funcs/strided_utils.h" namespace phi { @@ -28,14 +27,13 @@ void RealGradStridedKernel(const Context& dev_ctx, dev_ctx.Alloc(dx, dx->dtype()); dx->set_strides(DenseTensorMeta::calc_strides(dx->dims())); PD_VISIT_ALL_TYPES(dx->dtype(), "RealGradStridedKernel", ([&] { - phi::FillKernel(dev_ctx, *dx, 0, dx); + phi::StridedTensorFill(*dx, 0, dx); })); DenseTensor tmp; tmp.set_meta(dout.meta()); RealStridedKernel(dev_ctx, *dx, &tmp); PD_VISIT_ALL_TYPES(dout.dtype(), "RealGradStridedKernel", ([&] { - phi::StridedCopyKernel( - dev_ctx, + phi::StridedTensorCopy( dout, common::vectorize(tmp.dims()), common::vectorize(tmp.strides()), @@ -51,15 +49,14 @@ void ImagGradStridedKernel(const Context& dev_ctx, dev_ctx.Alloc(dx, dx->dtype()); dx->set_strides(DenseTensorMeta::calc_strides(dx->dims())); PD_VISIT_ALL_TYPES(dx->dtype(), "ImagGradStridedKernel", ([&] { - phi::FillKernel(dev_ctx, *dx, 0, dx); + phi::StridedTensorFill(*dx, 0, dx); })); DenseTensor tmp; tmp.set_meta(dout.meta()); ImagStridedKernel(dev_ctx, *dx, &tmp); PD_VISIT_ALL_TYPES(dout.dtype(), "ImagGradStridedKernel", ([&] { - phi::StridedCopyKernel( - dev_ctx, + phi::StridedTensorCopy( dout, common::vectorize(tmp.dims()), common::vectorize(tmp.strides()), @@ -107,3 +104,23 @@ PD_REGISTER_KERNEL(imag_grad, kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); } #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(real_grad, + Custom, + STRIDED, + phi::RealGradStridedKernel, + phi::dtype::complex, + phi::dtype::complex) { + kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} + +PD_REGISTER_KERNEL(imag_grad, + Custom, + STRIDED, + phi::ImagGradStridedKernel, + phi::dtype::complex, + phi::dtype::complex) { + kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} +#endif diff --git a/paddle/phi/kernels/stride/complex_kernel.cc b/paddle/phi/kernels/stride/complex_kernel.cc index d72bfec2b09f0..815ca06f46ac3 100644 --- a/paddle/phi/kernels/stride/complex_kernel.cc +++ b/paddle/phi/kernels/stride/complex_kernel.cc @@ -97,3 +97,23 @@ PD_REGISTER_KERNEL(imag, kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); } #endif + +#ifdef PADDLE_WITH_CUSTOM_DEVICE +PD_REGISTER_KERNEL(real, + Custom, + STRIDED, + phi::RealStridedKernel, + phi::dtype::complex, + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} + +PD_REGISTER_KERNEL(imag, + Custom, + STRIDED, + phi::ImagStridedKernel, + phi::dtype::complex, + phi::dtype::complex) { + kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype())); +} +#endif diff --git a/paddle/phi/kernels/stride/diagonal_grad_kernel.cc b/paddle/phi/kernels/stride/diagonal_grad_kernel.cc index fc44c09118fad..b3365b9d6022f 100644 --- a/paddle/phi/kernels/stride/diagonal_grad_kernel.cc +++ b/paddle/phi/kernels/stride/diagonal_grad_kernel.cc @@ -16,8 +16,7 @@ #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/diagonal_kernel.h" -#include "paddle/phi/kernels/fill_kernel.h" -#include "paddle/phi/kernels/strided_copy_kernel.h" +#include "paddle/phi/kernels/funcs/strided_utils.h" namespace phi { @@ -32,8 +31,7 @@ void DiagonalGradStridedKernel(const Context& dev_ctx, dev_ctx.Alloc(in_grad, in_grad->dtype()); in_grad->set_strides(DenseTensorMeta::calc_strides(in_grad->dims())); PD_VISIT_ALL_TYPES(in_grad->dtype(), "DiagonalGradStridedKernel", ([&] { - 
phi::FillKernel( - dev_ctx, *in_grad, 0, in_grad); + phi::StridedTensorFill(*in_grad, 0, in_grad); })); DenseTensor tmp; tmp.set_layout(out_grad.layout()); @@ -43,8 +41,7 @@ void DiagonalGradStridedKernel(const Context& dev_ctx, DiagonalStridedKernel(dev_ctx, *in_grad, offset, axis1, axis2, &tmp); PD_VISIT_ALL_TYPES(out_grad.dtype(), "DiagonalGradStridedKernel", ([&] { - phi::StridedCopyKernel( - dev_ctx, + phi::StridedTensorCopy( out_grad, common::vectorize(tmp.dims()), common::vectorize(tmp.strides()), @@ -54,5 +51,7 @@ void DiagonalGradStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - diagonal_grad, STRIDED, phi::DiagonalGradStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(diagonal_grad, + STRIDED, + phi::DiagonalGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/diagonal_kernel.cc b/paddle/phi/kernels/stride/diagonal_kernel.cc index f21ea6c24ac6f..31c250ee2880a 100644 --- a/paddle/phi/kernels/stride/diagonal_kernel.cc +++ b/paddle/phi/kernels/stride/diagonal_kernel.cc @@ -82,5 +82,7 @@ void DiagonalStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - diagonal, STRIDED, phi::DiagonalStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(diagonal, + STRIDED, + phi::DiagonalStridedKernel) {} diff --git a/paddle/phi/kernels/stride/flatten_grad_kernel.cc b/paddle/phi/kernels/stride/flatten_grad_kernel.cc index be7ed0721fdd2..3bf337797bc0f 100644 --- a/paddle/phi/kernels/stride/flatten_grad_kernel.cc +++ b/paddle/phi/kernels/stride/flatten_grad_kernel.cc @@ -33,5 +33,7 @@ void FlattenGradStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - flatten_grad, STRIDED, phi::FlattenGradStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(flatten_grad, + STRIDED, + phi::FlattenGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/flatten_kernel.cc b/paddle/phi/kernels/stride/flatten_kernel.cc index 94b4ae0a89890..f2240aa9bff87 100644 --- a/paddle/phi/kernels/stride/flatten_kernel.cc +++ b/paddle/phi/kernels/stride/flatten_kernel.cc @@ -43,8 +43,11 @@ void FlattenStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - flatten_infer, STRIDED, phi::FlattenInferStridedKernel) {} -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - flatten, STRIDED, phi::FlattenStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(flatten_infer, + STRIDED, + phi::FlattenInferStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(flatten, + STRIDED, + phi::FlattenStridedKernel) {} diff --git a/paddle/phi/kernels/stride/index_select_grad_kernel.cc b/paddle/phi/kernels/stride/index_select_grad_kernel.cc index 99705b396f19e..51b690f78d978 100644 --- a/paddle/phi/kernels/stride/index_select_grad_kernel.cc +++ b/paddle/phi/kernels/stride/index_select_grad_kernel.cc @@ -15,9 +15,9 @@ #include "paddle/phi/kernels/index_select_grad_kernel.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/fill_kernel.h" +#include "paddle/phi/kernels/funcs/strided_utils.h" #include "paddle/phi/kernels/index_select_kernel.h" -#include "paddle/phi/kernels/strided_copy_kernel.h" + namespace phi { template @@ -30,8 +30,7 @@ void IndexSelectGradStridedKernel(const Context& dev_ctx, dev_ctx.Alloc(x_grad, x_grad->dtype()); 
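
The strided grad kernels in this group (as_strided_grad, real/imag grad, diagonal_grad, index_select_grad) now share one pattern: zero-fill the full input gradient with StridedTensorFill, rebuild the forward strided view over that buffer, and copy out_grad through the view with StridedTensorCopy. A self-contained C++ sketch of the same idea on a bare array, using a 1-D every-third-element view (names and values are illustrative):

// Sketch of the "zero-fill, then scatter through a strided view" pattern
// used by the strided grad kernels, on a plain array instead of DenseTensor.
#include <cstdio>
#include <vector>

int main() {
  // Forward pass produced a stride-3 view with 3 elements over a length-9 buffer.
  const int kStride = 3;
  std::vector<double> x_grad(9, 0.0);        // step 1: fill input grad with 0
  std::vector<double> out_grad = {1, 2, 3};  // gradient of the strided view

  // Steps 2 and 3: walk the same strided view over x_grad and copy out_grad in.
  for (size_t i = 0; i < out_grad.size(); ++i) {
    x_grad[i * kStride] = out_grad[i];
  }

  for (double v : x_grad) std::printf("%g ", v);
  std::printf("\n");  // prints: 1 0 0 2 0 0 3 0 0
  return 0;
}
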
x_grad->set_strides(DenseTensorMeta::calc_strides(x_grad->dims())); PD_VISIT_ALL_TYPES(x_grad->dtype(), "IndexSelectGradStridedKernel", ([&] { - phi::FillKernel( - dev_ctx, *x_grad, 0, x_grad); + phi::StridedTensorFill(*x_grad, 0, x_grad); })); DenseTensor tmp; tmp.set_layout(out_grad.layout()); @@ -41,8 +40,7 @@ void IndexSelectGradStridedKernel(const Context& dev_ctx, IndexSelectStridedKernel(dev_ctx, *x_grad, index, dim, &tmp); PD_VISIT_ALL_TYPES(out_grad.dtype(), "IndexSelectGradStridedKernel", ([&] { - phi::StridedCopyKernel( - dev_ctx, + phi::StridedTensorCopy( out_grad, common::vectorize(tmp.dims()), common::vectorize(tmp.strides()), @@ -52,5 +50,7 @@ void IndexSelectGradStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - index_select_grad_strided, STRIDED, phi::IndexSelectGradStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(index_select_strided_grad, + STRIDED, + phi::IndexSelectGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/index_select_kernel.cc b/paddle/phi/kernels/stride/index_select_kernel.cc index ea278226ee6c2..a391fcf14bcd2 100644 --- a/paddle/phi/kernels/stride/index_select_kernel.cc +++ b/paddle/phi/kernels/stride/index_select_kernel.cc @@ -57,5 +57,7 @@ void IndexSelectStridedKernel(const Context& ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - index_select_strided, STRIDED, phi::IndexSelectStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(index_select_strided, + STRIDED, + phi::IndexSelectStridedKernel) {} diff --git a/paddle/phi/kernels/stride/reshape_grad_kernel.cc b/paddle/phi/kernels/stride/reshape_grad_kernel.cc index 4d55c67fbcf0b..9edbb46711757 100644 --- a/paddle/phi/kernels/stride/reshape_grad_kernel.cc +++ b/paddle/phi/kernels/stride/reshape_grad_kernel.cc @@ -40,7 +40,10 @@ void ReshapeDoubleGradStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - reshape_grad, STRIDED, phi::ReshapeGradStridedKernel) {} -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - reshape_double_grad, STRIDED, phi::ReshapeDoubleGradStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(reshape_grad, + STRIDED, + phi::ReshapeGradStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(reshape_double_grad, + STRIDED, + phi::ReshapeDoubleGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/reshape_kernel.cc b/paddle/phi/kernels/stride/reshape_kernel.cc index 9d94e53314193..02d36d825c36a 100644 --- a/paddle/phi/kernels/stride/reshape_kernel.cc +++ b/paddle/phi/kernels/stride/reshape_kernel.cc @@ -16,8 +16,8 @@ #include #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/contiguous_kernel.h" #include "paddle/phi/kernels/funcs/strided_reshape_utils.h" +#include "paddle/phi/kernels/funcs/strided_utils.h" namespace phi { template @@ -49,8 +49,7 @@ void ReshapeStridedKernel(const Context& dev_ctx, tmp_x.set_strides(x_stride); tmp.set_meta(tmp_x.meta()); PD_VISIT_ALL_TYPES(x.dtype(), "ReshapeStridedKernel", ([&] { - phi::ContiguousKernel( - dev_ctx, tmp_x, &tmp); + phi::StridedTensorContiguous(tmp_x, &tmp); })); out->set_strides(DenseTensorMeta::calc_strides(out->dims())); out->set_offset(0); @@ -59,5 +58,7 @@ void ReshapeStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - reshape, STRIDED, phi::ReshapeStridedKernel) {} + 
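
ReshapeStridedKernel above falls back to a contiguous copy (StridedTensorContiguous) when the input strides cannot describe the requested shape directly. A minimal sketch of the kind of contiguity check behind such a decision, assuming row-major layout (illustrative helper, not the actual ReshapeStride logic):

// Decide whether a strided view can be reinterpreted without copying:
// the strides must already describe a contiguous row-major block.
#include <cstdio>
#include <vector>

bool IsContiguousRowMajor(const std::vector<long>& dims,
                          const std::vector<long>& strides) {
  long expected = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    if (dims[i] != 1 && strides[i] != expected) return false;
    expected *= dims[i];
  }
  return true;
}

int main() {
  // A 2x3 row-major block can be viewed as [6] without copying...
  std::printf("%d\n", IsContiguousRowMajor({2, 3}, {3, 1}));  // 1
  // ...but a transposed (column-strided) view cannot, so it needs a copy.
  std::printf("%d\n", IsContiguousRowMajor({3, 2}, {1, 3}));  // 0
  return 0;
}
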
+PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(reshape, + STRIDED, + phi::ReshapeStridedKernel) {} diff --git a/paddle/phi/kernels/stride/slice_grad_kernel.cc b/paddle/phi/kernels/stride/slice_grad_kernel.cc index 171c20b3b83ac..5e519ceed4c82 100644 --- a/paddle/phi/kernels/stride/slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/slice_grad_kernel.cc @@ -15,9 +15,8 @@ #include "paddle/phi/kernels/slice_grad_kernel.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/fill_kernel.h" +#include "paddle/phi/kernels/funcs/strided_utils.h" #include "paddle/phi/kernels/slice_kernel.h" -#include "paddle/phi/kernels/strided_copy_kernel.h" namespace phi { @@ -34,8 +33,8 @@ void SliceGradStridedKernel(const Context& dev_ctx, dev_ctx.Alloc(input_grad, input_grad->dtype()); input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims())); PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { - phi::FillKernel( - dev_ctx, *input_grad, 0, input_grad); + phi::StridedTensorFill( + *input_grad, 0, input_grad); })); DenseTensor tmp; tmp.set_meta(out_grad.meta()); @@ -48,8 +47,7 @@ void SliceGradStridedKernel(const Context& dev_ctx, decrease_axis, &tmp); PD_VISIT_ALL_TYPES(input.dtype(), "SliceGradStridedKernel", ([&] { - phi::StridedCopyKernel( - dev_ctx, + phi::StridedTensorCopy( out_grad, common::vectorize(tmp.dims()), common::vectorize(tmp.strides()), @@ -57,7 +55,8 @@ void SliceGradStridedKernel(const Context& dev_ctx, &tmp); })); } - } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - slice_grad, STRIDED, phi::SliceGradStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(slice_grad, + STRIDED, + phi::SliceGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/slice_kernel.cc b/paddle/phi/kernels/stride/slice_kernel.cc index 3e21360ce09d0..b5efcd49166fd 100644 --- a/paddle/phi/kernels/stride/slice_kernel.cc +++ b/paddle/phi/kernels/stride/slice_kernel.cc @@ -59,8 +59,7 @@ void SliceStridedKernel(const Context& ctx, std::vector decrease_flag(output_dims.size(), 0); if (!decrease_axis.empty()) { - for (int i = 0; i < static_cast(decrease_axis.size()); ++i) { - int64_t axis = decrease_axis[i]; + for (auto axis : decrease_axis) { decrease_flag[axis] = 1; } @@ -96,5 +95,7 @@ void SliceStridedKernel(const Context& ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - slice, STRIDED, phi::SliceStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(slice, + STRIDED, + phi::SliceStridedKernel) {} diff --git a/paddle/phi/kernels/stride/split_kernel.cc b/paddle/phi/kernels/stride/split_kernel.cc index b5d9d0af69628..d4155186bef2b 100644 --- a/paddle/phi/kernels/stride/split_kernel.cc +++ b/paddle/phi/kernels/stride/split_kernel.cc @@ -65,8 +65,11 @@ void SplitWithNumStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - split_strided, STRIDED, phi::SplitStridedKernel) {} -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - split_with_num_strided, STRIDED, phi::SplitWithNumStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(split_strided, + STRIDED, + phi::SplitStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(split_with_num_strided, + STRIDED, + phi::SplitWithNumStridedKernel) {} diff --git a/paddle/phi/kernels/stride/squeeze_grad_kernel.cc b/paddle/phi/kernels/stride/squeeze_grad_kernel.cc index 27361211e8fc0..bfb5dd508998b 100644 --- 
a/paddle/phi/kernels/stride/squeeze_grad_kernel.cc +++ b/paddle/phi/kernels/stride/squeeze_grad_kernel.cc @@ -31,5 +31,7 @@ void SqueezeGradStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - squeeze_grad, STRIDED, phi::SqueezeGradStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(squeeze_grad, + STRIDED, + phi::SqueezeGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/squeeze_kernel.cc b/paddle/phi/kernels/stride/squeeze_kernel.cc index b03652baee624..455afd608af91 100644 --- a/paddle/phi/kernels/stride/squeeze_kernel.cc +++ b/paddle/phi/kernels/stride/squeeze_kernel.cc @@ -124,8 +124,11 @@ void SqueezeStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - squeeze_infer, STRIDED, phi::SqueezeInferStridedKernel) {} -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - squeeze, STRIDED, phi::SqueezeStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(squeeze_infer, + STRIDED, + phi::SqueezeInferStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(squeeze, + STRIDED, + phi::SqueezeStridedKernel) {} diff --git a/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc b/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc index f0cd2d53bc823..2a48d804399f8 100644 --- a/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc +++ b/paddle/phi/kernels/stride/strided_slice_grad_kernel.cc @@ -15,8 +15,7 @@ #include "paddle/phi/kernels/strided_slice_grad_kernel.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/fill_kernel.h" -#include "paddle/phi/kernels/strided_copy_kernel.h" +#include "paddle/phi/kernels/funcs/strided_utils.h" #include "paddle/phi/kernels/strided_slice_kernel.h" namespace phi { @@ -34,8 +33,7 @@ void StridedSliceRawGradStridedKernel(const Context& dev_ctx, dev_ctx.Alloc(x_grad, x_grad->dtype()); x_grad->set_strides(DenseTensorMeta::calc_strides(x_grad->dims())); PD_VISIT_ALL_TYPES(x_grad->dtype(), "StridedSliceRawGradStridedKernel", ([&] { - phi::FillKernel( - dev_ctx, *x_grad, 0, x_grad); + phi::StridedTensorFill(*x_grad, 0, x_grad); })); DenseTensor tmp; tmp.set_layout(out_grad.layout()); @@ -53,8 +51,7 @@ void StridedSliceRawGradStridedKernel(const Context& dev_ctx, &tmp); PD_VISIT_ALL_TYPES( out_grad.dtype(), "StridedSliceRawGradStridedKernel", ([&] { - phi::StridedCopyKernel( - dev_ctx, + phi::StridedTensorCopy( out_grad, common::vectorize(tmp.dims()), common::vectorize(tmp.strides()), @@ -87,8 +84,10 @@ void StridedSliceGradStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE( strided_slice_raw_grad, STRIDED, phi::StridedSliceRawGradStridedKernel) {} -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - strided_slice_grad, STRIDED, phi::StridedSliceGradStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(strided_slice_grad, + STRIDED, + phi::StridedSliceGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/strided_slice_kernel.cc b/paddle/phi/kernels/stride/strided_slice_kernel.cc index f3b36565def3e..241a2ac17df74 100644 --- a/paddle/phi/kernels/stride/strided_slice_kernel.cc +++ b/paddle/phi/kernels/stride/strided_slice_kernel.cc @@ -93,8 +93,8 @@ void StridedSliceRawStridedKernel(const Context& dev_ctx, if (!decrease_axis.empty()) { std::vector new_out_shape; std::vector new_out_stride; - for (size_t i = 0; i 
< decrease_axis.size(); ++i) { - output_dims[decrease_axis[i]] = 0; + for (auto de_axis : decrease_axis) { + output_dims[de_axis] = 0; } for (size_t i = 0; i < output_dims.size(); ++i) { @@ -139,8 +139,11 @@ void StridedSliceStridedKernel(const Context& dev_ctx, dev_ctx, x, axes, starts, ends, strides, infer_flags, decrease_axis, out); } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - strided_slice_raw, STRIDED, phi::StridedSliceRawStridedKernel) {} -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - strided_slice, STRIDED, phi::StridedSliceStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(strided_slice_raw, + STRIDED, + phi::StridedSliceRawStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(strided_slice, + STRIDED, + phi::StridedSliceStridedKernel) {} diff --git a/paddle/phi/kernels/stride/tensor_unfold_grad_kernel.cc b/paddle/phi/kernels/stride/tensor_unfold_grad_kernel.cc index 7dc3e6e46361b..03cb979f38363 100644 --- a/paddle/phi/kernels/stride/tensor_unfold_grad_kernel.cc +++ b/paddle/phi/kernels/stride/tensor_unfold_grad_kernel.cc @@ -14,8 +14,7 @@ #include "paddle/phi/kernels/tensor_unfold_grad_kernel.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/fill_kernel.h" -#include "paddle/phi/kernels/strided_copy_kernel.h" +#include "paddle/phi/kernels/funcs/strided_utils.h" #include "paddle/phi/kernels/tensor_unfold_kernel.h" namespace phi { @@ -35,8 +34,8 @@ void TensorUnfoldGradKernel(const Context& dev_ctx, input_grad->set_strides(DenseTensorMeta::calc_strides(input_grad->dims())); if (out_grad.numel() < input.numel()) { PD_VISIT_ALL_TYPES(input_grad->dtype(), "TensorUnfoldGradKernel", ([&] { - phi::FillKernel( - dev_ctx, *input_grad, 0, input_grad); + phi::StridedTensorFill( + *input_grad, 0, input_grad); })); } DenseTensor tmp; @@ -47,8 +46,7 @@ void TensorUnfoldGradKernel(const Context& dev_ctx, TensorUnfoldKernel(dev_ctx, *input_grad, axis, size, step, &tmp); PD_VISIT_ALL_TYPES(out_grad.dtype(), "TensorUnfoldGradKernel", ([&] { - phi::StridedCopyKernel( - dev_ctx, + phi::StridedTensorCopy( out_grad, common::vectorize(tmp.dims()), common::vectorize(tmp.strides()), @@ -58,5 +56,7 @@ void TensorUnfoldGradKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - tensor_unfold_grad, STRIDED, phi::TensorUnfoldGradKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(tensor_unfold_grad, + STRIDED, + phi::TensorUnfoldGradKernel) {} diff --git a/paddle/phi/kernels/stride/tensor_unfold_kernel.cc b/paddle/phi/kernels/stride/tensor_unfold_kernel.cc index 79643ac3dc514..8c1751737efd8 100644 --- a/paddle/phi/kernels/stride/tensor_unfold_kernel.cc +++ b/paddle/phi/kernels/stride/tensor_unfold_kernel.cc @@ -71,5 +71,7 @@ void TensorUnfoldKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - tensor_unfold, STRIDED, phi::TensorUnfoldKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(tensor_unfold, + STRIDED, + phi::TensorUnfoldKernel) {} diff --git a/paddle/phi/kernels/stride/transpose_grad_kernel.cc b/paddle/phi/kernels/stride/transpose_grad_kernel.cc index 51295658393c4..b20340cb20817 100644 --- a/paddle/phi/kernels/stride/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/stride/transpose_grad_kernel.cc @@ -25,16 +25,16 @@ void TransposeGradStridedKernel(const Context& dev_ctx, const std::vector& axis, DenseTensor* x_grad) { size_t axis_size = axis.size(); 
- std::vector formated_axis = axis; + std::vector formatted_axis = axis; for (size_t i = 0; i < axis_size; i++) { if (axis[i] < 0) { - formated_axis[i] = static_cast(axis[i] + axis_size); + formatted_axis[i] = static_cast(axis[i] + axis_size); } } std::vector reversed_axis(axis); for (int i = 0; i < static_cast(axis_size); i++) { - reversed_axis[formated_axis[i]] = i; + reversed_axis[formatted_axis[i]] = i; } TransposeStridedKernel(dev_ctx, out_grad, reversed_axis, x_grad); @@ -42,5 +42,6 @@ void TransposeGradStridedKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - transpose_grad, STRIDED, phi::TransposeGradStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(transpose_grad, + STRIDED, + phi::TransposeGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/transpose_kernel.cc b/paddle/phi/kernels/stride/transpose_kernel.cc index acdc321ad0e8a..82e5e3096e959 100644 --- a/paddle/phi/kernels/stride/transpose_kernel.cc +++ b/paddle/phi/kernels/stride/transpose_kernel.cc @@ -24,18 +24,18 @@ void TransposeStridedKernel(const Context& ctx, const std::vector& axis, DenseTensor* out) { size_t x_rank = x.dims().size(); - std::vector formated_axis = axis; + std::vector formatted_axis = axis; for (size_t i = 0; i < axis.size(); i++) { if (axis[i] < 0) { - formated_axis[i] = static_cast(axis[i] + x_rank); + formatted_axis[i] = static_cast(axis[i] + x_rank); } } auto meta = out->meta(); auto in_stride = x.strides(); meta.strides = in_stride; - for (int i = 0; i < static_cast(formated_axis.size()); i++) { - meta.strides[i] = in_stride[formated_axis[i]]; + for (int i = 0; i < static_cast(formatted_axis.size()); i++) { + meta.strides[i] = in_stride[formatted_axis[i]]; } meta.offset = x.offset(); @@ -46,5 +46,6 @@ void TransposeStridedKernel(const Context& ctx, } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - transpose, STRIDED, phi::TransposeStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(transpose, + STRIDED, + phi::TransposeStridedKernel) {} diff --git a/paddle/phi/kernels/stride/unbind_kernel.cc b/paddle/phi/kernels/stride/unbind_kernel.cc index 4409fa7e786c7..6a0eb6043bb6d 100644 --- a/paddle/phi/kernels/stride/unbind_kernel.cc +++ b/paddle/phi/kernels/stride/unbind_kernel.cc @@ -43,5 +43,7 @@ void UnbindStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - unbind, STRIDED, phi::UnbindStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(unbind, + STRIDED, + phi::UnbindStridedKernel) {} diff --git a/paddle/phi/kernels/stride/unsqueeze_grad_kernel.cc b/paddle/phi/kernels/stride/unsqueeze_grad_kernel.cc index c6c5c117cd94e..d25e96115b7fc 100644 --- a/paddle/phi/kernels/stride/unsqueeze_grad_kernel.cc +++ b/paddle/phi/kernels/stride/unsqueeze_grad_kernel.cc @@ -30,5 +30,7 @@ void UnsqueezeGradStridedKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - unsqueeze_grad, STRIDED, phi::UnsqueezeGradStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(unsqueeze_grad, + STRIDED, + phi::UnsqueezeGradStridedKernel) {} diff --git a/paddle/phi/kernels/stride/unsqueeze_kernel.cc b/paddle/phi/kernels/stride/unsqueeze_kernel.cc index bd1a200ea0eaa..901cf10b569f0 100644 --- a/paddle/phi/kernels/stride/unsqueeze_kernel.cc +++ b/paddle/phi/kernels/stride/unsqueeze_kernel.cc @@ -85,8 +85,11 @@ void UnsqueezeStridedKernel(const Context& dev_ctx, } } // namespace phi 
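The strided transpose kernels above only shuffle metadata: negative axes are wrapped into [0, rank), the output strides are the input strides permuted by the formatted axis, and the gradient applies the inverse permutation. A minimal standalone sketch of that bookkeeping (plain std::vector stands in for the tensor metadata; the shape and axis values are made up):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int rank = 3;
  std::vector<int> axis = {0, -1, 1};            // user-supplied, may be negative
  std::vector<int64_t> in_strides = {12, 4, 1};  // strides of a contiguous 2x3x4 tensor

  // Wrap negative axes into [0, rank).
  std::vector<int> formatted_axis(axis);
  for (int i = 0; i < rank; ++i)
    if (axis[i] < 0) formatted_axis[i] = axis[i] + rank;

  // Forward transpose: permute strides, no data movement.
  std::vector<int64_t> out_strides(rank);
  for (int i = 0; i < rank; ++i) out_strides[i] = in_strides[formatted_axis[i]];

  // Backward pass: the gradient uses the inverse permutation.
  std::vector<int> reversed_axis(rank);
  for (int i = 0; i < rank; ++i) reversed_axis[formatted_axis[i]] = i;

  for (int i = 0; i < rank; ++i)
    std::printf("out_stride[%d]=%lld, inverse_axis[%d]=%d\n",
                i, static_cast<long long>(out_strides[i]), i, reversed_axis[i]);
  return 0;
}

Because only strides and offset change, the forward transpose is a zero-copy view that shares the input's data buffer.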
-PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - unsqueeze_infer, STRIDED, phi::UnsqueezeInferStridedKernel) {} -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - unsqueeze, STRIDED, phi::UnsqueezeStridedKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(unsqueeze_infer, + STRIDED, + phi::UnsqueezeInferStridedKernel) {} + +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(unsqueeze, + STRIDED, + phi::UnsqueezeStridedKernel) {} diff --git a/paddle/phi/kernels/stride/view_grad_kernel.cc b/paddle/phi/kernels/stride/view_grad_kernel.cc index 19674670b2707..44037c57ab794 100644 --- a/paddle/phi/kernels/stride/view_grad_kernel.cc +++ b/paddle/phi/kernels/stride/view_grad_kernel.cc @@ -38,8 +38,10 @@ void ViewDtypeGradKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - view_shape_grad, STRIDED, phi::ViewShapeGradKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(view_shape_grad, + STRIDED, + phi::ViewShapeGradKernel) {} -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM( - view_dtype_grad, STRIDED, phi::ViewDtypeGradKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(view_dtype_grad, + STRIDED, + phi::ViewDtypeGradKernel) {} diff --git a/paddle/phi/kernels/stride/view_kernel.cc b/paddle/phi/kernels/stride/view_kernel.cc index f4685902da29f..8b6ab5ecfd7ec 100644 --- a/paddle/phi/kernels/stride/view_kernel.cc +++ b/paddle/phi/kernels/stride/view_kernel.cc @@ -149,10 +149,10 @@ void ViewDtypeKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(view_shape, - STRIDED, - phi::ViewShapeKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(view_shape, + STRIDED, + phi::ViewShapeKernel) {} -PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM(view_dtype, - STRIDED, - phi::ViewDtypeKernel) {} +PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE(view_dtype, + STRIDED, + phi::ViewDtypeKernel) {} diff --git a/paddle/phi/kernels/strings/gpu/strings_lower_upper_kernel.cu b/paddle/phi/kernels/strings/gpu/strings_lower_upper_kernel.cu index 832d9bbf73c0b..2a238e8a49b4d 100644 --- a/paddle/phi/kernels/strings/gpu/strings_lower_upper_kernel.cu +++ b/paddle/phi/kernels/strings/gpu/strings_lower_upper_kernel.cu @@ -56,7 +56,7 @@ struct UTF8CaseConverter { pstring* out, size_t num) const { auto unicode_flag_map = GetGPUUniflagMap(); - auto cases_map = GetGPUCharcasesMap(); + auto cases_map = GetGPUCharCasesMap(); thrust::device_vector unicode_offsets(num + 1, 0); uint32_t* unicode_offsets_ptr = thrust::raw_pointer_cast(unicode_offsets.data()); diff --git a/paddle/phi/kernels/strings/strings_lower_upper_kernel.h b/paddle/phi/kernels/strings/strings_lower_upper_kernel.h index a8d7f2dda94f7..a7c1d4a0936fc 100644 --- a/paddle/phi/kernels/strings/strings_lower_upper_kernel.h +++ b/paddle/phi/kernels/strings/strings_lower_upper_kernel.h @@ -60,13 +60,13 @@ StringTensor StringUpper(const ContextT& dev_ctx, return string_out; } -template +template struct StringCaseConvertKernel { void operator()(const ContextT& dev_ctx, const StringTensor& x, bool use_utf8_encoding, StringTensor* out) { - AsciiCoverter ascii_converter; + AsciiConverter ascii_converter; UTF8Converter utf8_converter; const pstring* in_ptr = x.data(); pstring* out_ptr = dev_ctx.template Alloc(out); @@ -101,7 +101,7 @@ struct UTF8CaseConverter { pstring* out, size_t num) const { auto unicode_flag_map = GetUniFlagMap(); - auto cases_map = GetCharcasesMap(); + auto cases_map = GetCharCasesMap(); for (size_t i = 0; i < num; ++i) { uint32_t 
unicode_len = GetUnicodeStrLen(in[i].data(), in[i].size()); std::vector unicode_in(unicode_len, 0); diff --git a/paddle/phi/kernels/strings/unicode.cc b/paddle/phi/kernels/strings/unicode.cc index 292160e2b2db1..71d9ef36cd16d 100644 --- a/paddle/phi/kernels/strings/unicode.cc +++ b/paddle/phi/kernels/strings/unicode.cc @@ -23,7 +23,7 @@ namespace phi { namespace strings { static const void* utils_map[4] = {nullptr}; // NOLINT -static uint16_t CHARCASES_MAP[65536] = {0}; // NOLINT +static uint16_t CHAR_CASES_MAP[65536] = {0}; // NOLINT const uint8_t* GetUniFlagMap() { if (utils_map[1] == nullptr) { @@ -32,16 +32,16 @@ const uint8_t* GetUniFlagMap() { return reinterpret_cast(utils_map[1]); } -const uint16_t* GetCharcasesMap() { +const uint16_t* GetCharCasesMap() { if (utils_map[0] == nullptr) { for (uint32_t i = 0; i < 65536; ++i) { if (utf8proc_islower(static_cast(i))) { - CHARCASES_MAP[i] = utf8proc_toupper(static_cast(i)); + CHAR_CASES_MAP[i] = utf8proc_toupper(static_cast(i)); } else if (utf8proc_isupper(static_cast(i))) { - CHARCASES_MAP[i] = utf8proc_tolower(static_cast(i)); + CHAR_CASES_MAP[i] = utf8proc_tolower(static_cast(i)); } } - utils_map[0] = CHARCASES_MAP; + utils_map[0] = CHAR_CASES_MAP; } return reinterpret_cast(utils_map[0]); } @@ -67,21 +67,21 @@ const uint8_t* GetGPUUniflagMap() { return reinterpret_cast(utils_map[3]); } -const uint16_t* GetGPUCharcasesMap() { +const uint16_t* GetGPUCharCasesMap() { if (utils_map[2] == nullptr) { - const uint16_t* cpu_charcases = GetCharcasesMap(); - auto size = sizeof(CHARCASES_MAP); - uint16_t* gpu_charcases; + const uint16_t* cpu_char_cases = GetCharCasesMap(); + auto size = sizeof(CHAR_CASES_MAP); + uint16_t* gpu_char_cases; #ifdef PADDLE_WITH_HIP - hipMalloc(reinterpret_cast(&gpu_charcases), size); + hipMalloc(reinterpret_cast(&gpu_char_cases), size); phi::backends::gpu::GpuMemcpySync( - gpu_charcases, cpu_charcases, size, hipMemcpyHostToDevice); + gpu_char_cases, cpu_char_cases, size, hipMemcpyHostToDevice); #else - cudaMalloc(reinterpret_cast(&gpu_charcases), size); + cudaMalloc(reinterpret_cast(&gpu_char_cases), size); phi::backends::gpu::GpuMemcpySync( - gpu_charcases, cpu_charcases, size, cudaMemcpyHostToDevice); + gpu_char_cases, cpu_char_cases, size, cudaMemcpyHostToDevice); #endif - utils_map[2] = gpu_charcases; + utils_map[2] = gpu_char_cases; } return reinterpret_cast(utils_map[2]); } diff --git a/paddle/phi/kernels/strings/unicode.h b/paddle/phi/kernels/strings/unicode.h index 6dfb6aeb6ede6..48c07dbf8dd4f 100644 --- a/paddle/phi/kernels/strings/unicode.h +++ b/paddle/phi/kernels/strings/unicode.h @@ -169,7 +169,7 @@ HOSTDEVICE inline uint32_t GetUTF8StrLen(const uint32_t* unicode_str, // +1 means '\0' return utf8_str_count + 1; } -// Need to gurantee utf8_str has enough memory +// Need to guarantee utf8_str has enough memory HOSTDEVICE inline void GetUTF8Str(const uint32_t* unicode_str, char* utf8_str, @@ -186,12 +186,12 @@ HOSTDEVICE inline void GetUTF8Str(const uint32_t* unicode_str, } const uint8_t* GetUniFlagMap(); -const uint16_t* GetCharcasesMap(); +const uint16_t* GetCharCasesMap(); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) const uint8_t* GetGPUUniflagMap(); -const uint16_t* GetGPUCharcasesMap(); +const uint16_t* GetGPUCharCasesMap(); #endif } // namespace strings diff --git a/paddle/phi/kernels/transfer_layout_kernel.cc b/paddle/phi/kernels/transfer_layout_kernel.cc index 656b92dffbf30..569be5ce9781f 100644 --- a/paddle/phi/kernels/transfer_layout_kernel.cc +++ 
b/paddle/phi/kernels/transfer_layout_kernel.cc @@ -166,7 +166,7 @@ void TransferLayoutMKLDNN(const Context& dev_ctx, out->set_mem_desc(out_mem_desc); } else if (src_layout == DataLayout::ONEDNN && dst_layout != DataLayout::ONEDNN) { - // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel + // Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel // Do transform via MKLDNN lib funcs::TransDataLayoutFromOneDNN( src_layout, dst_layout, x, out, dev_ctx.GetPlace()); diff --git a/paddle/phi/kernels/xpu/adamw_kernel.cc b/paddle/phi/kernels/xpu/adamw_kernel.cc index ca39a9932a609..f60e02c61a323 100644 --- a/paddle/phi/kernels/xpu/adamw_kernel.cc +++ b/paddle/phi/kernels/xpu/adamw_kernel.cc @@ -140,6 +140,109 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, MPDType* master_out_data = multi_precision ? dev_ctx.template Alloc(master_param_outs) : nullptr; + + // check moment_dtype + auto moment1_dtype = moment1.dtype(); + auto moment2_dtype = moment2.dtype(); + PADDLE_ENFORCE_EQ(moment1_dtype, + moment1_out->dtype(), + errors::InvalidArgument( + "moment1.dtype does not match moment1_out->dtype")); + PADDLE_ENFORCE_EQ(moment2_dtype, + moment2_out->dtype(), + errors::InvalidArgument( + "moment2.dtype does not match moment2_out->dtype")); + PADDLE_ENFORCE_EQ( + moment1_dtype, + moment2_dtype, + errors::InvalidArgument("moment1.dtype does not match moment2.dtype")); + + bool moment_in_fp16 = false; + if (moment1_dtype == phi::DataType::FLOAT16) { + moment_in_fp16 = true; + } else { + PADDLE_ENFORCE_EQ( + moment1_dtype, + phi::DataType::FLOAT32, + errors::InvalidArgument("moment1.dtype is neither fp32 nor fp16")); + } + + float* moment1_input_for_xdnn = nullptr; + float* moment2_input_for_xdnn = nullptr; + float* moment1_output_for_xdnn = nullptr; + float* moment2_output_for_xdnn = nullptr; + + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + if (moment_in_fp16) { + // allocate temp buffer on XPU + moment1_input_for_xdnn = RAII_GUARD.alloc_l3_or_gm(moment1.numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(moment1_input_for_xdnn); + moment2_input_for_xdnn = RAII_GUARD.alloc_l3_or_gm(moment2.numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(moment2_input_for_xdnn); + moment1_output_for_xdnn = + RAII_GUARD.alloc_l3_or_gm(moment1_out->numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(moment1_output_for_xdnn); + moment2_output_for_xdnn = + RAII_GUARD.alloc_l3_or_gm(moment2_out->numel()); + PADDLE_ENFORCE_XDNN_NOT_NULL(moment2_output_for_xdnn); + + int r = 0; + using XPUType16 = typename XPUTypeTrait::Type; + + // cast moment1 and moment2, from fp16 to fp32 + // int cast(Context* ctx, const TX* x, TY* y, int64_t len); + r = xpu::cast( + dev_ctx.x_context(), + reinterpret_cast( + moment1.template data()), + moment1_input_for_xdnn, + moment1.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment1 from fp16 to float"); + r = xpu::cast( + dev_ctx.x_context(), + reinterpret_cast( + moment2.template data()), + moment2_input_for_xdnn, + moment2.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment2 from fp16 to float"); + + // acquire xpu_scale_value + float moment1_scale_value = XPUStorageProperties::default_xpu_scale_value; + if (moment1.storage_properties_initialized()) { + moment1_scale_value = + moment1.storage_properties().xpu_scale_value; + } + float moment2_scale_value = XPUStorageProperties::default_xpu_scale_value; + if (moment2.storage_properties_initialized()) { + moment2_scale_value = + moment2.storage_properties().xpu_scale_value; + } + + // de-scale using scale_value + // int scale(Context* ctx, const T* 
x, T* y, int64_t len, bool + // bias_after_scale, float _scale, float _bias); + if (moment1_scale_value > 0) { + r = xpu::scale(dev_ctx.x_context(), + moment1_input_for_xdnn, + moment1_input_for_xdnn, + moment1.numel(), + false, + 1.0f / moment1_scale_value, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "de-scale for moment1"); + } + if (moment2_scale_value > 0) { + r = xpu::scale(dev_ctx.x_context(), + moment2_input_for_xdnn, + moment2_input_for_xdnn, + moment2.numel(), + false, + 1.0f / moment2_scale_value, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "de-scale for moment2"); + } + } + // template DLL_EXPORT int // adamw_v2(Context* ctx, MT beta1, MT beta2, MT epsilon, MT coeff, MT // lr_ratio, const MT* beta1_pow, MT* beta1_pow_out, const MT* beta2_pow, MT* @@ -168,10 +271,14 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, nullptr, beta2_pow_ptr, nullptr, - moment1.data(), - dev_ctx.template Alloc(moment1_out), - moment2.data(), - dev_ctx.template Alloc(moment2_out), + moment_in_fp16 ? moment1_input_for_xdnn + : moment1.template data(), + moment_in_fp16 ? moment1_output_for_xdnn + : dev_ctx.template Alloc(moment1_out), + moment_in_fp16 ? moment2_input_for_xdnn + : moment2.template data(), + moment_in_fp16 ? moment2_output_for_xdnn + : dev_ctx.template Alloc(moment2_out), learning_rate.data(), grad.data(), reinterpret_cast(param.data()), @@ -179,7 +286,7 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, master_in_data, master_out_data, param.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw"); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2"); } else { int r = xpu::adamw_v2( dev_ctx.x_context(), @@ -192,10 +299,14 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, nullptr, beta2_pow_ptr, nullptr, - moment1.data(), - dev_ctx.template Alloc(moment1_out), - moment2.data(), - dev_ctx.template Alloc(moment2_out), + moment_in_fp16 ? moment1_input_for_xdnn + : moment1.template data(), + moment_in_fp16 ? moment1_output_for_xdnn + : dev_ctx.template Alloc(moment1_out), + moment_in_fp16 ? moment2_input_for_xdnn + : moment2.template data(), + moment_in_fp16 ? moment2_output_for_xdnn + : dev_ctx.template Alloc(moment2_out), learning_rate.data(), reinterpret_cast(grad.data()), reinterpret_cast(param.data()), @@ -203,7 +314,7 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, master_in_data, master_out_data, param.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw"); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2"); } if (!use_global_beta_pow) { // Cpu update @@ -230,13 +341,17 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, coeff_, lr_ratio_, beta1_pow.data(), - beta1_pow_out_ptr, + nullptr, // beta1_pow_out_ptr, beta2_pow.data(), - beta2_pow_out_ptr, - moment1.data(), - dev_ctx.template Alloc(moment1_out), - moment2.data(), - dev_ctx.template Alloc(moment2_out), + nullptr, // beta2_pow_out_ptr, + moment_in_fp16 ? moment1_input_for_xdnn + : moment1.template data(), + moment_in_fp16 ? moment1_output_for_xdnn + : dev_ctx.template Alloc(moment1_out), + moment_in_fp16 ? moment2_input_for_xdnn + : moment2.template data(), + moment_in_fp16 ? 
moment2_output_for_xdnn + : dev_ctx.template Alloc(moment2_out), learning_rate.data(), grad.data(), reinterpret_cast(param.data()), @@ -244,7 +359,7 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, master_in_data, master_out_data, param.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw"); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2"); } else { int r = xpu::adamw_v2( dev_ctx.x_context(), @@ -254,13 +369,17 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, coeff_, lr_ratio_, beta1_pow.data(), - beta1_pow_out_ptr, + nullptr, // beta1_pow_out_ptr, beta2_pow.data(), - beta2_pow_out_ptr, - moment1.data(), - dev_ctx.template Alloc(moment1_out), - moment2.data(), - dev_ctx.template Alloc(moment2_out), + nullptr, // beta2_pow_out_ptr, + moment_in_fp16 ? moment1_input_for_xdnn + : moment1.template data(), + moment_in_fp16 ? moment1_output_for_xdnn + : dev_ctx.template Alloc(moment1_out), + moment_in_fp16 ? moment2_input_for_xdnn + : moment2.template data(), + moment_in_fp16 ? moment2_output_for_xdnn + : dev_ctx.template Alloc(moment2_out), learning_rate.data(), reinterpret_cast(grad.data()), reinterpret_cast(param.data()), @@ -268,9 +387,98 @@ void AdamwDenseKernelKL3(const Context& dev_ctx, master_in_data, master_out_data, param.numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw"); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adamw_v2"); + } + if (!use_global_beta_pow) { + // update beta1_pow and beta2_pow + int r = xpu::scale(dev_ctx.x_context(), + beta1_pow.data(), + beta1_pow_out_ptr, + beta1_pow.numel(), + false, + beta1_, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); + r = xpu::scale(dev_ctx.x_context(), + beta2_pow.data(), + beta2_pow_out_ptr, + beta2_pow.numel(), + false, + beta2_, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); } } + + if (moment_in_fp16) { + int r = 0; + using XPUType16 = typename XPUTypeTrait::Type; + + // findmax and calculate scale_value for moment1 and moment2 + int max_ptr_size = phi::backends::xpu::get_xpu_max_ptr_size(-1); + float* buffer_for_findmax = RAII_GUARD.alloc_l3_or_gm(max_ptr_size); + + // for moment1 + float moment1_max = GetAbsMax(dev_ctx, + moment1_output_for_xdnn, + buffer_for_findmax, + moment1_out->numel()); + float moment1_scale_value = 65504.0f / moment1_max / 2.0f; + // int scale(Context* ctx, const T* x, T* y, int64_t len, bool + // bias_after_scale, float _scale, float _bias); + r = xpu::scale(dev_ctx.x_context(), + moment1_output_for_xdnn, + moment1_output_for_xdnn, + moment1_out->numel(), + false, + moment1_scale_value, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS( + r, "scale before convert to fp16, for moment1_output_for_xdnn"); + // write to moment1_out + std::unique_ptr moment1_out_sp = + std::make_unique(moment1_scale_value); + moment1_out->set_storage_properties(std::move(moment1_out_sp)); + + // for moment2 + float moment2_max = GetAbsMax(dev_ctx, + moment2_output_for_xdnn, + buffer_for_findmax, + moment2_out->numel()); + float moment2_scale_value = 65504.0f / moment2_max / 2.0f; + // int scale(Context* ctx, const T* x, T* y, int64_t len, bool + // bias_after_scale, float _scale, float _bias); + r = xpu::scale(dev_ctx.x_context(), + moment2_output_for_xdnn, + moment2_output_for_xdnn, + moment2_out->numel(), + false, + moment2_scale_value, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS( + r, "scale before convert to fp16, for moment2_output_for_xdnn"); + // write to moment2_out + std::unique_ptr moment2_out_sp = + std::make_unique(moment2_scale_value); + moment2_out->set_storage_properties(std::move(moment2_out_sp)); + + // cast moment1 and moment2 
output, from fp32 to fp16 + // int cast(Context* ctx, const TX* x, TY* y, int64_t len); + r = xpu::cast( + dev_ctx.x_context(), + moment1_output_for_xdnn, + reinterpret_cast( + dev_ctx.template Alloc(moment1_out)), + moment1.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment1_out from float to fp16"); + r = xpu::cast( + dev_ctx.x_context(), + moment2_output_for_xdnn, + reinterpret_cast( + dev_ctx.template Alloc(moment2_out)), + moment2.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast moment2_out from float to fp16"); + } return; } diff --git a/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc index 454141ff4c3ea..7579d4f922d64 100644 --- a/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_grad_kernel.cc @@ -96,7 +96,7 @@ void BatchNormGradKernel(const Context &dev_ctx, true, phi::errors::InvalidArgument( "The 'data_layout' attribute must be NCHW or NHWC. " - "But recevived 'data_layout' is [%s].", + "But received 'data_layout' is [%s].", data_layout)); const auto data_layout_val = common::StringToDataLayout(data_layout); @@ -120,7 +120,7 @@ void BatchNormGradKernel(const Context &dev_ctx, x_dims.size() >= 2 && x_dims.size() <= 5, true, phi::errors::InvalidArgument( - "The size of input's dimensions should be between 2 and 5" + "The size of input's dimensions should be between 2 and 5. " "But received: the size of input's dimensions is [%d]", x_dims.size())); @@ -192,7 +192,7 @@ void BatchNormGradKernel(const Context &dev_ctx, const auto *global_mean = mean.get_ptr(); const auto *global_var = variance.get_ptr(); - // TODO(guozibin): hadle the situation case of N * H * W = 1 + // TODO(guozibin): handle the situation case of N * H * W = 1 int r = 0; if (is_inplace) { float *global_inv_std_data = nullptr; diff --git a/paddle/phi/kernels/xpu/batch_norm_kernel.cc b/paddle/phi/kernels/xpu/batch_norm_kernel.cc index 8427c49b43d42..81dd253460337 100644 --- a/paddle/phi/kernels/xpu/batch_norm_kernel.cc +++ b/paddle/phi/kernels/xpu/batch_norm_kernel.cc @@ -48,7 +48,7 @@ void BatchNormKernel(const Context& dev_ctx, true, phi::errors::InvalidArgument( "The 'data_layout' attribute must be NCHW or NHWC. " - "But recevived 'data_layout' is [%s].", + "But received 'data_layout' is [%s].", data_layout_str)); const auto& x_dims = x.dims(); @@ -104,7 +104,7 @@ void BatchNormKernel(const Context& dev_ctx, 5, phi::errors::InvalidArgument( "The size of input X's dimensions should be less than 6." - "But received: the size of input X's dimensionss is [%d]", + "But received: the size of input X's dimensions is [%d]", x_dims.size())); bool is_nchw = data_layout_str == "NCHW"; diff --git a/paddle/phi/kernels/xpu/bitwise.cc b/paddle/phi/kernels/xpu/bitwise.cc index dee96be39e185..c9eb0d93a66f0 100644 --- a/paddle/phi/kernels/xpu/bitwise.cc +++ b/paddle/phi/kernels/xpu/bitwise.cc @@ -39,7 +39,7 @@ void BitwiseAndKernel(const Context& ctx, const DenseTensor& y, DenseTensor* out) { // XPU api do not support bitwise operation now. - // However, because biwise and logical operation is identical for bool type, + // However, because bitwise and logical operation is identical for bool type, // we can implement bitwise_and_bool kernel by calling their logical // counterpart. Need to be changed when adding support to other types. 
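Returning to the adamw change above: when the moments are stored in fp16, the kernel keeps a per-tensor scale (xpu_scale_value), de-scales to fp32 before the update, then derives a fresh scale from the post-update abs-max (65504 is the largest finite fp16 value, halved for headroom) before casting back. A CPU-only sketch of that round trip, with ScaledMoment/DeScale/ReScale invented for the example and plain float standing in for fp16 storage:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

struct ScaledMoment {
  std::vector<float> fp16_storage;  // stands in for the real fp16 buffer
  float scale;                      // analogue of xpu_scale_value
};

// Undo the stored scaling so the optimizer update can run in fp32.
std::vector<float> DeScale(const ScaledMoment& m) {
  std::vector<float> fp32(m.fp16_storage);
  if (m.scale > 0.f) {
    for (float& v : fp32) v /= m.scale;
  }
  return fp32;
}

// Pick a new scale from the abs-max and write the moment back.
void ReScale(const std::vector<float>& fp32, ScaledMoment* m) {
  float max_abs = 0.f;
  for (float v : fp32) max_abs = std::max(max_abs, std::fabs(v));
  // 65504 is the largest finite fp16 value; halve it for headroom.
  m->scale = (max_abs > 0.f) ? 65504.0f / max_abs / 2.0f : 1.0f;
  m->fp16_storage.resize(fp32.size());
  for (std::size_t i = 0; i < fp32.size(); ++i)
    m->fp16_storage[i] = fp32[i] * m->scale;  // would be a cast to fp16 on XPU
}

int main() {
  ScaledMoment m{{0.5f, -2.0f, 8.0f}, 1.0f};
  std::vector<float> fp32 = DeScale(m);  // the adamw update runs in fp32 here
  ReScale(fp32, &m);                     // write back with a freshly chosen scale
  std::printf("new scale = %g\n", m.scale);
  return 0;
}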
LogicalAndKernel(ctx, x, y, out); diff --git a/paddle/phi/kernels/xpu/bmm_grad_kernel.cc b/paddle/phi/kernels/xpu/bmm_grad_kernel.cc index cbc98dd7ad9ac..e2fdbb610d2a2 100644 --- a/paddle/phi/kernels/xpu/bmm_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/bmm_grad_kernel.cc @@ -25,17 +25,17 @@ void MatMul(const Context& dev_ctx, const DenseTensor& b, bool trans_b, DenseTensor* out) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; dev_ctx.template Alloc(out); xpu::Context* xpu_ctx = dev_ctx.x_context(); - int fccal_type = FCCalcType(); - if (fccal_type == XPUFCCalcType::FC_INT32) { + int fc_calc_type = FCCalcType(); + if (fc_calc_type == XPUFCCalcType::FC_INT32) { MatMulXPUFunction(a, b, out, trans_a, trans_b, xpu_ctx); - } else if (fccal_type == XPUFCCalcType::FC_FLOAT) { + } else if (fc_calc_type == XPUFCCalcType::FC_FLOAT) { MatMulXPUFunction(a, b, out, trans_a, trans_b, xpu_ctx); - } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) { + } else if (fc_calc_type == XPUFCCalcType::FC_INT32_WITH_LL) { MatMulXPUFunction(a, b, out, trans_a, trans_b, xpu_ctx); - } else if (fccal_type == XPUFCCalcType::FC_FLOAT16) { + } else if (fc_calc_type == XPUFCCalcType::FC_FLOAT16) { MatMulXPUFunction(a, b, out, trans_a, trans_b, xpu_ctx); } else { MatMulXPUFunction(a, b, out, trans_a, trans_b, xpu_ctx); diff --git a/paddle/phi/kernels/xpu/bmm_kernel.cc b/paddle/phi/kernels/xpu/bmm_kernel.cc index ae80f12747ac1..3ce7d6578dfad 100644 --- a/paddle/phi/kernels/xpu/bmm_kernel.cc +++ b/paddle/phi/kernels/xpu/bmm_kernel.cc @@ -20,7 +20,7 @@ void BmmKernel(const Context& dev_ctx, const DenseTensor& x, const DenseTensor& y, DenseTensor* out) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; dev_ctx.template Alloc(out); if (x.numel() == 0 || y.numel() == 0) { return; @@ -63,14 +63,14 @@ void BmmKernel(const Context& dev_ctx, y_dims[1])); xpu::Context* xpu_ctx = dev_ctx.x_context(); - int fccal_type = FCCalcType(); - if (fccal_type == XPUFCCalcType::FC_INT32) { + int fc_calc_type = FCCalcType(); + if (fc_calc_type == XPUFCCalcType::FC_INT32) { MatMulXPUFunction(x, y, out, trans_x, trans_y, xpu_ctx); - } else if (fccal_type == XPUFCCalcType::FC_FLOAT) { + } else if (fc_calc_type == XPUFCCalcType::FC_FLOAT) { MatMulXPUFunction(x, y, out, trans_x, trans_y, xpu_ctx); - } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) { + } else if (fc_calc_type == XPUFCCalcType::FC_INT32_WITH_LL) { MatMulXPUFunction(x, y, out, trans_x, trans_y, xpu_ctx); - } else if (fccal_type == XPUFCCalcType::FC_FLOAT16) { + } else if (fc_calc_type == XPUFCCalcType::FC_FLOAT16) { MatMulXPUFunction(x, y, out, trans_x, trans_y, xpu_ctx); } else { MatMulXPUFunction(x, y, out, trans_x, trans_y, xpu_ctx); diff --git a/paddle/phi/kernels/xpu/bmm_xpu_utils.h b/paddle/phi/kernels/xpu/bmm_xpu_utils.h index 90d5b51973957..c7c6bfe2bed64 100644 --- a/paddle/phi/kernels/xpu/bmm_xpu_utils.h +++ b/paddle/phi/kernels/xpu/bmm_xpu_utils.h @@ -40,7 +40,7 @@ static void MatMulXPUFunction(const DenseTensor& x, int k = mat_dim_a.width_; int batch_size = mat_dim_a.batch_size_; // batch matmul - int fccal_type = FCCalcType(); + int fc_calc_type = FCCalcType(); decltype(&xblas_fc_batch_wrapper) xblas_fc_batch_api_list[6] = { &xblas_fc_batch_wrapper, @@ -51,8 +51,8 @@ static void MatMulXPUFunction(const DenseTensor& x, &xblas_fc_batch_wrapper, }; - auto xblas_fc_batch_api = xblas_fc_batch_api_list[fccal_type]; - if (fccal_type == XPUFCCalcType::FC_FLOAT16 && + auto 
xblas_fc_batch_api = xblas_fc_batch_api_list[fc_calc_type]; + if (fc_calc_type == XPUFCCalcType::FC_FLOAT16 && std::getenv("XPU_PADDLE_FC_FLOAT16") != nullptr) { xblas_fc_batch_api = &xblas_fc_batch_wrapper; diff --git a/paddle/phi/kernels/xpu/concat_and_split_functor.cc b/paddle/phi/kernels/xpu/concat_and_split_functor.cc index a1335f33b6700..08d2832107d70 100644 --- a/paddle/phi/kernels/xpu/concat_and_split_functor.cc +++ b/paddle/phi/kernels/xpu/concat_and_split_functor.cc @@ -139,6 +139,7 @@ class SplitFunctor { DEFINE_XPU_FUNCTOR(float) DEFINE_XPU_FUNCTOR(phi::dtype::float16) +DEFINE_XPU_FUNCTOR(phi::dtype::bfloat16) DEFINE_XPU_FUNCTOR(int32_t) DEFINE_XPU_FUNCTOR(int64_t) DEFINE_XPU_FUNCTOR(uint8_t) diff --git a/paddle/phi/kernels/xpu/conv_grad_kernel.cc b/paddle/phi/kernels/xpu/conv_grad_kernel.cc index 03276ebd53b5f..cf5162a71e108 100644 --- a/paddle/phi/kernels/xpu/conv_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_grad_kernel.cc @@ -34,7 +34,7 @@ void ConvGradKernel(const Context& dev_ctx, const std::string& data_format, DenseTensor* input_grad, DenseTensor* filter_grad) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; std::vector paddings = paddings_t; std::vector dilations = dilations_t; // The filter and filter_grad will be reshaped in the calculations, @@ -69,153 +69,157 @@ void ConvGradKernel(const Context& dev_ctx, is_nchw = false; } - const XPUT* input_data = reinterpret_cast(input.data()); - const XPUT* filter_data = reinterpret_cast(filter.data()); - const XPUT* output_grad_data = - reinterpret_cast(out_grad.data()); - XPUT* input_grad_data = nullptr; + const XPUType* input_data = reinterpret_cast(input.data()); + const XPUType* filter_data = + reinterpret_cast(filter.data()); + const XPUType* output_grad_data = + reinterpret_cast(out_grad.data()); + XPUType* input_grad_data = nullptr; if (input_grad) { dev_ctx.template Alloc(input_grad); - input_grad_data = reinterpret_cast(input_grad->data()); + input_grad_data = reinterpret_cast(input_grad->data()); } - XPUT* filter_grad_data = nullptr; + XPUType* filter_grad_data = nullptr; if (filter_grad) { dev_ctx.template Alloc(filter_grad); - filter_grad_data = reinterpret_cast(filter_grad->data()); + filter_grad_data = reinterpret_cast(filter_grad->data()); } xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - XPUT* filter_data_tmp; - XPUT* filter_grad_data_tmp; - const XPUT* filter_data_ptr = filter_data; - XPUT* filter_grad_data_ptr = filter_grad_data; + XPUType* filter_data_tmp; + XPUType* filter_grad_data_tmp; + const XPUType* filter_data_ptr = filter_data; + XPUType* filter_grad_data_ptr = filter_grad_data; if (data_format == "NHWC") { - filter_data_tmp = RAII_GUARD.alloc(filter.numel()); + filter_data_tmp = RAII_GUARD.alloc(filter.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp); - int r = xpu::transpose(dev_ctx.x_context(), - filter_data, - filter_data_tmp, - filter_shape, - {0, 2, 3, 1}); + int r = xpu::transpose(dev_ctx.x_context(), + filter_data, + filter_data_tmp, + filter_shape, + {0, 2, 3, 1}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - filter_data_ptr = reinterpret_cast(filter_data_tmp); + filter_data_ptr = reinterpret_cast(filter_data_tmp); if (filter_grad_data != nullptr) { - filter_grad_data_tmp = RAII_GUARD.alloc(filter.numel()); + filter_grad_data_tmp = RAII_GUARD.alloc(filter.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(filter_grad_data_tmp); filter_grad_data_ptr = filter_grad_data_tmp; } } - int fccal_type = FCCalcType(); - if (fccal_type == 
XPUFCCalcType::FC_INT32) { - int r = xpu::conv2d_grad(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_grad_data, - input_grad_data, - filter_grad_data_ptr, - batch_size, - img_c, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - is_nchw); + int fc_calc_type = FCCalcType(); + if (fc_calc_type == XPUFCCalcType::FC_INT32) { + int r = + xpu::conv2d_grad(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_grad_data, + input_grad_data, + filter_grad_data_ptr, + batch_size, + img_c, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad"); - } else if (fccal_type == XPUFCCalcType::FC_FLOAT) { - int r = xpu::conv2d_grad(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_grad_data, - input_grad_data, - filter_grad_data_ptr, - batch_size, - img_c, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - is_nchw); + } else if (fc_calc_type == XPUFCCalcType::FC_FLOAT) { + int r = + xpu::conv2d_grad(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_grad_data, + input_grad_data, + filter_grad_data_ptr, + batch_size, + img_c, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad"); - } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) { - int r = - xpu::conv2d_grad(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_grad_data, - input_grad_data, - filter_grad_data_ptr, - batch_size, - img_c, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - is_nchw); + } else if (fc_calc_type == XPUFCCalcType::FC_INT32_WITH_LL) { + int r = xpu::conv2d_grad( + dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_grad_data, + input_grad_data, + filter_grad_data_ptr, + batch_size, + img_c, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad"); } else { - int r = xpu::conv2d_grad(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_grad_data, - input_grad_data, - filter_grad_data_ptr, - batch_size, - img_c, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - is_nchw); + int r = xpu::conv2d_grad( + dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_grad_data, + input_grad_data, + filter_grad_data_ptr, + batch_size, + img_c, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_grad"); } if ((filter_grad_data_ptr != nullptr) && (data_format == "NHWC")) { std::vector filter_shape_fhwc = { filter_shape[0], filter_shape[2], filter_shape[3], filter_shape[1]}; - int r = xpu::transpose(dev_ctx.x_context(), - filter_grad_data_ptr, - filter_grad_data, - filter_shape_fhwc, - {0, 3, 1, 2}); + int r = xpu::transpose(dev_ctx.x_context(), + filter_grad_data_ptr, + filter_grad_data, + filter_shape_fhwc, + {0, 3, 1, 2}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); } } @@ -260,7 +264,7 @@ void 
Conv3DGradKernel(const Context& dev_ctx, const std::string& data_format, DenseTensor* input_grad, DenseTensor* filter_grad) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; std::vector paddings = paddings_t; std::vector dilations = dilations_t; // The filter and filter_grad will be reshaped in the calculations, @@ -292,144 +296,148 @@ void Conv3DGradKernel(const Context& dev_ctx, is_ncdhw = false; } - const XPUT* input_data = reinterpret_cast(input.data()); - const XPUT* filter_data = reinterpret_cast(filter.data()); - const XPUT* output_grad_data = - reinterpret_cast(out_grad.data()); - XPUT* input_grad_data = nullptr; + const XPUType* input_data = reinterpret_cast(input.data()); + const XPUType* filter_data = + reinterpret_cast(filter.data()); + const XPUType* output_grad_data = + reinterpret_cast(out_grad.data()); + XPUType* input_grad_data = nullptr; if (input_grad) { dev_ctx.template Alloc(input_grad); - input_grad_data = reinterpret_cast(input_grad->data()); + input_grad_data = reinterpret_cast(input_grad->data()); } - XPUT* filter_grad_data = nullptr; + XPUType* filter_grad_data = nullptr; if (filter_grad) { dev_ctx.template Alloc(filter_grad); - filter_grad_data = reinterpret_cast(filter_grad->data()); + filter_grad_data = reinterpret_cast(filter_grad->data()); } xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - XPUT* filter_data_tmp; - XPUT* filter_grad_data_tmp; - const XPUT* filter_data_ptr = filter_data; - XPUT* filter_grad_data_ptr = filter_grad_data; + XPUType* filter_data_tmp; + XPUType* filter_grad_data_tmp; + const XPUType* filter_data_ptr = filter_data; + XPUType* filter_grad_data_ptr = filter_grad_data; if (data_format == "NDHWC") { - filter_data_tmp = RAII_GUARD.alloc(filter.numel()); + filter_data_tmp = RAII_GUARD.alloc(filter.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp); - int r = xpu::transpose(dev_ctx.x_context(), - filter_data, - filter_data_tmp, - filter_shape, - {0, 2, 3, 4, 1}); + int r = xpu::transpose(dev_ctx.x_context(), + filter_data, + filter_data_tmp, + filter_shape, + {0, 2, 3, 4, 1}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - filter_data_ptr = reinterpret_cast(filter_data_tmp); + filter_data_ptr = reinterpret_cast(filter_data_tmp); if (filter_grad_data != nullptr) { - filter_grad_data_tmp = RAII_GUARD.alloc(filter.numel()); + filter_grad_data_tmp = RAII_GUARD.alloc(filter.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(filter_grad_data_tmp); filter_grad_data_ptr = filter_grad_data_tmp; } } - int fccal_type = FCCalcType(); - if (fccal_type == XPUFCCalcType::FC_INT32) { - int r = xpu::conv3d_grad(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_grad_data, - input_grad_data, - filter_grad_data_ptr, - batch_size, - img_c, - img_d, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - is_ncdhw); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad"); - } else if (fccal_type == XPUFCCalcType::FC_FLOAT) { - int r = xpu::conv3d_grad(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_grad_data, - input_grad_data, - filter_grad_data_ptr, - batch_size, - img_c, - img_d, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - is_ncdhw); + int fc_calc_type = FCCalcType(); + if (fc_calc_type == XPUFCCalcType::FC_INT32) { + int r = + xpu::conv3d_grad(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_grad_data, + input_grad_data, + 
filter_grad_data_ptr, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + is_ncdhw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad"); - } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) { + } else if (fc_calc_type == XPUFCCalcType::FC_FLOAT) { int r = - xpu::conv3d_grad(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_grad_data, - input_grad_data, - filter_grad_data_ptr, - batch_size, - img_c, - img_d, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - is_ncdhw); + xpu::conv3d_grad(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_grad_data, + input_grad_data, + filter_grad_data_ptr, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + is_ncdhw); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad"); + } else if (fc_calc_type == XPUFCCalcType::FC_INT32_WITH_LL) { + int r = xpu::conv3d_grad( + dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_grad_data, + input_grad_data, + filter_grad_data_ptr, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + is_ncdhw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad"); } else { - int r = xpu::conv3d_grad(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_grad_data, - input_grad_data, - filter_grad_data_ptr, - batch_size, - img_c, - img_d, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - is_ncdhw); + int r = xpu::conv3d_grad( + dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_grad_data, + input_grad_data, + filter_grad_data_ptr, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + is_ncdhw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d_grad"); } @@ -439,11 +447,11 @@ void Conv3DGradKernel(const Context& dev_ctx, filter_shape[3], filter_shape[4], filter_shape[1]}; - int r = xpu::transpose(dev_ctx.x_context(), - filter_grad_data_ptr, - filter_grad_data, - filter_shape_fhwc, - {0, 4, 1, 2, 3}); + int r = xpu::transpose(dev_ctx.x_context(), + filter_grad_data_ptr, + filter_grad_data, + filter_shape_fhwc, + {0, 4, 1, 2, 3}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); } } diff --git a/paddle/phi/kernels/xpu/conv_kernel.cc b/paddle/phi/kernels/xpu/conv_kernel.cc index 0dc93d676186b..c0cfe2db83034 100644 --- a/paddle/phi/kernels/xpu/conv_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_kernel.cc @@ -32,7 +32,7 @@ void ConvKernel(const Context& dev_ctx, int groups, const std::string& data_format, DenseTensor* out) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; std::vector paddings = paddings_t; std::vector dilations = dilations_t; // The filter will be reshaped in the calculations, @@ -67,107 +67,109 @@ void ConvKernel(const Context& dev_ctx, is_nchw = false; } - const XPUT* input_data = reinterpret_cast(input.data()); - const XPUT* filter_data = reinterpret_cast(filter.data()); - XPUT* output_data = reinterpret_cast(out->data()); + const XPUType* input_data = reinterpret_cast(input.data()); + const XPUType* filter_data = + 
reinterpret_cast(filter.data()); + XPUType* output_data = reinterpret_cast(out->data()); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - XPUT* filter_data_tmp; - const XPUT* filter_data_ptr = filter_data; + XPUType* filter_data_tmp; + const XPUType* filter_data_ptr = filter_data; if (data_format == "NHWC") { - filter_data_tmp = RAII_GUARD.alloc(filter.numel()); + filter_data_tmp = RAII_GUARD.alloc(filter.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp); std::vector filter_shape = common::vectorize(filter.dims()); - int r = xpu::transpose(dev_ctx.x_context(), - filter_data, - filter_data_tmp, - filter_shape, - {0, 2, 3, 1}); + int r = xpu::transpose(dev_ctx.x_context(), + filter_data, + filter_data_tmp, + filter_shape, + {0, 2, 3, 1}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - filter_data_ptr = reinterpret_cast(filter_data_tmp); + filter_data_ptr = reinterpret_cast(filter_data_tmp); } - int fccal_type = FCCalcType(); - if (fccal_type == XPUFCCalcType::FC_INT32) { - int r = xpu::conv2d(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_data, - batch_size, - img_c, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - is_nchw); + int fc_calc_type = FCCalcType(); + if (fc_calc_type == XPUFCCalcType::FC_INT32) { + int r = xpu::conv2d(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_data, + batch_size, + img_c, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d"); - } else if (fccal_type == XPUFCCalcType::FC_FLOAT) { - int r = xpu::conv2d(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_data, - batch_size, - img_c, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - is_nchw); + } else if (fc_calc_type == XPUFCCalcType::FC_FLOAT) { + int r = xpu::conv2d(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_data, + batch_size, + img_c, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d"); - } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) { - int r = xpu::conv2d(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_data, - batch_size, - img_c, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - is_nchw); + } else if (fc_calc_type == XPUFCCalcType::FC_INT32_WITH_LL) { + int r = xpu::conv2d( + dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_data, + batch_size, + img_c, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d"); } else { - int r = xpu::conv2d(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_data, - batch_size, - img_c, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - is_nchw); + int r = xpu::conv2d(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_data, + batch_size, + img_c, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + is_nchw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d"); } } @@ -206,7 +208,7 @@ void Conv3DKernel(const Context& dev_ctx, const std::vector& dilations_t, const std::string& data_format, DenseTensor* out) { - using XPUT = 
typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; std::vector paddings = paddings_t; std::vector dilations = dilations_t; // The filter will be reshaped in the calculations, @@ -237,112 +239,114 @@ void Conv3DKernel(const Context& dev_ctx, is_ncdhw = false; } - XPUT* output_data = reinterpret_cast(out->data()); - const XPUT* filter_data = reinterpret_cast(filter.data()); - const XPUT* input_data = reinterpret_cast(input.data()); + XPUType* output_data = reinterpret_cast(out->data()); + const XPUType* filter_data = + reinterpret_cast(filter.data()); + const XPUType* input_data = reinterpret_cast(input.data()); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - XPUT* filter_data_tmp; - const XPUT* filter_data_ptr = filter_data; + XPUType* filter_data_tmp; + const XPUType* filter_data_ptr = filter_data; if (data_format == "NDHWC") { - filter_data_tmp = RAII_GUARD.alloc(filter.numel()); + filter_data_tmp = RAII_GUARD.alloc(filter.numel()); PADDLE_ENFORCE_XDNN_NOT_NULL(filter_data_tmp); std::vector filter_shape = common::vectorize(filter.dims()); - int r = xpu::transpose(dev_ctx.x_context(), - filter_data, - filter_data_tmp, - filter_shape, - {0, 2, 3, 4, 1}); + int r = xpu::transpose(dev_ctx.x_context(), + filter_data, + filter_data_tmp, + filter_shape, + {0, 2, 3, 4, 1}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); - filter_data_ptr = reinterpret_cast(filter_data_tmp); + filter_data_ptr = reinterpret_cast(filter_data_tmp); } - int fccal_type = FCCalcType(); - if (fccal_type == XPUFCCalcType::FC_INT32) { - int r = xpu::conv3d(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_data, - batch_size, - img_c, - img_d, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - is_ncdhw); + int fc_calc_type = FCCalcType(); + if (fc_calc_type == XPUFCCalcType::FC_INT32) { + int r = xpu::conv3d(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_data, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + is_ncdhw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d"); - } else if (fccal_type == XPUFCCalcType::FC_FLOAT) { - int r = xpu::conv3d(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_data, - batch_size, - img_c, - img_d, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - is_ncdhw); + } else if (fc_calc_type == XPUFCCalcType::FC_FLOAT) { + int r = xpu::conv3d(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_data, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + is_ncdhw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d"); - } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) { - int r = xpu::conv3d(dev_ctx.x_context(), - input_data, - filter_data_ptr, - output_data, - batch_size, - img_c, - img_d, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - is_ncdhw); + } else if (fc_calc_type == XPUFCCalcType::FC_INT32_WITH_LL) { + int r = xpu::conv3d( + dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_data, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + is_ncdhw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d"); } else { - int r = xpu::conv3d(dev_ctx.x_context(), - input_data, - 
filter_data_ptr, - output_data, - batch_size, - img_c, - img_d, - img_h, - img_w, - f, - ksize, - strides, - paddings, - dilations, - groups, - nullptr, - nullptr, - nullptr, - is_ncdhw); + int r = xpu::conv3d(dev_ctx.x_context(), + input_data, + filter_data_ptr, + output_data, + batch_size, + img_c, + img_d, + img_h, + img_w, + f, + ksize, + strides, + paddings, + dilations, + groups, + nullptr, + nullptr, + nullptr, + is_ncdhw); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv3d"); } } diff --git a/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc index 296e02c28016d..5c911475af25f 100644 --- a/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_transpose_grad_kernel.cc @@ -69,9 +69,9 @@ void Conv2dTransposeGradKernel(const Context& ctx, if (dfilter) { ctx.template Alloc(dfilter); } - int fccal_type = FCCalcType(); - if (fccal_type == XPUFCCalcType::FC_INT32 || - fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) { + int fc_calc_type = FCCalcType(); + if (fc_calc_type == XPUFCCalcType::FC_INT32 || + fc_calc_type == XPUFCCalcType::FC_INT32_WITH_LL) { // xpu api do not support int31 quantization now. int r = xpu::conv2d_transpose_grad( ctx.x_context(), diff --git a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc index 2a1195e48c1f0..d6685c998acec 100644 --- a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc @@ -51,7 +51,7 @@ void Conv2dTransposeKernel(const Context& ctx, const std::vector& dilations, const std::string& data_format, DenseTensor* out) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; ctx.template Alloc(out); @@ -76,8 +76,8 @@ void Conv2dTransposeKernel(const Context& ctx, const int img_xh = static_cast(out->dims()[2]); const int img_xw = static_cast(out->dims()[3]); - int fccal_type = FCCalcType(); - if (fccal_type == XPUFCCalcType::FC_INT32) { + int fc_calc_type = FCCalcType(); + if (fc_calc_type == XPUFCCalcType::FC_INT32) { int r = xpu::conv2d_transpose_v2( ctx.x_context(), x.data(), @@ -98,7 +98,7 @@ void Conv2dTransposeKernel(const Context& ctx, nullptr, true); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2"); - } else if (fccal_type == XPUFCCalcType::FC_FLOAT) { + } else if (fc_calc_type == XPUFCCalcType::FC_FLOAT) { int r = xpu::conv2d_transpose_v2( ctx.x_context(), x.data(), @@ -119,7 +119,7 @@ void Conv2dTransposeKernel(const Context& ctx, nullptr, true); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose_v2"); - } else if (fccal_type == XPUFCCalcType::FC_INT32_WITH_LL) { + } else if (fc_calc_type == XPUFCCalcType::FC_INT32_WITH_LL) { if (output_size.size()) { VLOG(4) << "int_with_ll quantization is not supported when output_size " "is specified, " @@ -171,11 +171,11 @@ void Conv2dTransposeKernel(const Context& ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_transpose"); } } else { - int r = xpu::conv2d_transpose_v2( + int r = xpu::conv2d_transpose_v2( ctx.x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(filter.data()), - reinterpret_cast(out->data()), + reinterpret_cast(x.data()), + reinterpret_cast(filter.data()), + reinterpret_cast(out->data()), batch_size, img_yc, img_xh, diff --git a/paddle/phi/kernels/xpu/dropout_kernel.cc b/paddle/phi/kernels/xpu/dropout_kernel.cc index fbd071b868701..a166b860ab2ec 100644 --- a/paddle/phi/kernels/xpu/dropout_kernel.cc +++ b/paddle/phi/kernels/xpu/dropout_kernel.cc @@ -34,15 +34,18 @@ void 
DropoutRawKernel(const Context& dev_ctx, bool fix_seed, DenseTensor* out, DenseTensor* mask) { + bool is_upscale = (mode == "upscale_in_train"); + dev_ctx.template Alloc(out); + if (mask) { + dev_ctx.template Alloc(mask); + } + using XPUType = typename XPUTypeTrait::Type; - auto* y = out; const auto* x_data = x.data(); - auto* y_data = dev_ctx.template Alloc(y); + auto* y_data = out->data(); float dropout_prob = p.to(); - int is_upscale = (mode == "upscale_in_train"); - - if (!is_test) { + if (!is_test && mask) { int seed_data = 0; if (seed_tensor.get_ptr() != nullptr) { if ((seed_tensor->place()).GetType() == phi::AllocationType::XPU) { @@ -54,7 +57,6 @@ void DropoutRawKernel(const Context& dev_ctx, } else { seed_data = *(seed_tensor->data()); } - } else { seed_data = fix_seed ? seed : 0; } @@ -62,7 +64,7 @@ void DropoutRawKernel(const Context& dev_ctx, seed_data = dev_ctx.GetGenerator()->Random64(); } - auto* mask_data = dev_ctx.template Alloc(mask); + auto* mask_data = mask->data(); xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); auto dev_version = phi::backends::xpu::get_xpu_version(dev_ctx.GetPlace().GetDeviceId()); @@ -70,7 +72,7 @@ void DropoutRawKernel(const Context& dev_ctx, if (dropout_prob == 1.0f) { int r = xpu::constant(dev_ctx.x_context(), reinterpret_cast(y_data), - y->numel(), + out->numel(), XPUType(0)); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); r = xpu::constant( @@ -79,21 +81,25 @@ void DropoutRawKernel(const Context& dev_ctx, return; } if (dev_version == phi::backends::xpu::XPUVersion::XPU3) { - int r = xpu::dropout_v2(dev_ctx.x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(y->data()), - mask->data(), + // int dropout_v3(Context* ctx, const T* input, T* res, uint8_t* mask, + // unsigned int seed, int64_t n, bool is_upscale, float dropout_prob); + int r = xpu::dropout_v3(dev_ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + mask_data, seed_data, mask->numel(), is_upscale, dropout_prob); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_v2"); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_v3"); } else { XPUType* mask_tmp_data = RAII_GUARD.alloc_l3_or_gm(mask->numel()); + // int dropout(Context* ctx, const T* input, T* res, T* mask, unsigned int + // seed, int64_t n, bool is_upscale, float dropout_prob); int r = xpu::dropout(dev_ctx.x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(y->data()), + reinterpret_cast(x_data), + reinterpret_cast(y_data), mask_tmp_data, seed_data, mask->numel(), @@ -105,16 +111,23 @@ void DropoutRawKernel(const Context& dev_ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); } } else { - float scale = - (is_upscale) ? 
(1.0) : (static_cast(1.0f - dropout_prob)); - int r = xpu::scale(dev_ctx.x_context(), - reinterpret_cast(x_data), - reinterpret_cast(y_data), - x.numel(), - false, - scale, - 0.0f); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); + if (is_upscale) { + // y = x + int ret = xpu::copy(dev_ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + x.numel() * phi::SizeOf(x.dtype())); + PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy"); + } else { + int r = xpu::scale(dev_ctx.x_context(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + x.numel(), + false, + 1.0f - dropout_prob, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); + } } } @@ -126,5 +139,6 @@ PD_REGISTER_KERNEL(dropout, phi::DropoutRawKernel, float, phi::dtype::float16) { + kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->OutputAt(1).SetDataType(phi::DataType::UINT8); } diff --git a/paddle/phi/kernels/xpu/embedding_grad_kernel.cc b/paddle/phi/kernels/xpu/embedding_grad_kernel.cc index 3d0d0355b635f..2089bbd6dd8e4 100644 --- a/paddle/phi/kernels/xpu/embedding_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/embedding_grad_kernel.cc @@ -28,7 +28,7 @@ void EmbeddingGradKernel(const Context& ctx, const DenseTensor& out_grad, int64_t padding_idx, DenseTensor* weight_grad) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; DDim table_dim; table_dim = weight.dims(); @@ -36,6 +36,10 @@ void EmbeddingGradKernel(const Context& ctx, auto d_output_t = &out_grad; auto d_table_t = weight_grad; + if (std::getenv("XPU_CDNN_CLUSTER_PARALLEL") != nullptr) { + ctx.Wait(); + } + int64_t ids_numel = ids_t->numel(); PADDLE_ENFORCE_EQ( ids_numel <= std::numeric_limits::max(), @@ -63,11 +67,11 @@ void EmbeddingGradKernel(const Context& ctx, int ym = static_cast(ids_numel); int n = d_table_t->dims()[1]; - int r = xpu::embedding_grad( + int r = xpu::embedding_grad( dev_ctx.x_context(), - reinterpret_cast(d_output_data), + reinterpret_cast(d_output_data), ids_data, - reinterpret_cast(d_table_data), + reinterpret_cast(d_table_data), xm, n, ym, @@ -109,7 +113,7 @@ void EmbeddingSparseGradKernel(const Context& ctx, ids = CopyIdsToVector(ids_cpu); } else { PADDLE_THROW(phi::errors::Unimplemented( - "emebdding input only support int32 and int64")); + "embedding input only support int32 and int64")); } auto ids_num = static_cast(input.numel()); diff --git a/paddle/phi/kernels/xpu/expand_as_kernel.cc b/paddle/phi/kernels/xpu/expand_as_kernel.cc index 0701294217f41..45d0515a0b822 100644 --- a/paddle/phi/kernels/xpu/expand_as_kernel.cc +++ b/paddle/phi/kernels/xpu/expand_as_kernel.cc @@ -17,7 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace phi { diff --git a/paddle/phi/kernels/xpu/flash_attn_kernel.cc b/paddle/phi/kernels/xpu/flash_attn_kernel.cc index f040ef383c539..9ea712c410d1d 100644 --- a/paddle/phi/kernels/xpu/flash_attn_kernel.cc +++ b/paddle/phi/kernels/xpu/flash_attn_kernel.cc @@ -23,6 +23,161 @@ namespace phi { +template +void FlashAttnUnpaddedKernel( + const Context& ctx, + const DenseTensor& q, + const DenseTensor& k, + const DenseTensor& v, + const DenseTensor& cu_seqlens_q, + const DenseTensor& cu_seqlens_k, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + float scale, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + DenseTensor* 
out, + DenseTensor* softmax, + DenseTensor* softmax_lse, + DenseTensor* seed_offset) { +#ifdef PADDLE_WITH_XPU_XHPC + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + // q, k, v [batch_size * seq_len, num_heads, head_dim] + std::vector dims = common::vectorize(q.dims()); + + const int batch_size = cu_seqlens_q.numel() - 1; + const int num_heads = dims[1]; + const int head_size = dims[2]; + const int num_heads_k = k.dims()[1]; + + // lod info, only support qlod == klod + std::vector qlod_vec(batch_size + 1, 0); + int r = xpu_wait(ctx.x_context()->xpu_stream); + PADDLE_ENFORCE_EQ(r, 0, "xpu_wait failed."); + r = xpu_memcpy(qlod_vec.data(), + cu_seqlens_q.data(), + sizeof(int32_t) * (batch_size + 1), + XPUMemcpyKind::XPU_DEVICE_TO_HOST); + PADDLE_ENFORCE_EQ(r, 0, "xpu_memcpy failed."); + std::vector klod_vec(batch_size + 1, 0); + r = xpu_wait(ctx.x_context()->xpu_stream); + PADDLE_ENFORCE_EQ(r, 0, "xpu_wait failed."); + r = xpu_memcpy(klod_vec.data(), + cu_seqlens_k.data(), + sizeof(int32_t) * (batch_size + 1), + XPUMemcpyKind::XPU_DEVICE_TO_HOST); + PADDLE_ENFORCE_EQ(r, 0, "xpu_memcpy failed."); + // output: softmax_lse, 训练参数,给反向用于反向重计算的L + bool is_cross_attn = false; + for (int i = 0; i < batch_size + 1; ++i) { + if (qlod_vec[i] != klod_vec[i]) { + is_cross_attn = true; + break; + } + } + + using XPUType = typename XPUTypeTrait::Type; + auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + const XPUType* q_data = reinterpret_cast(q.data()); + const XPUType* k_data = reinterpret_cast(k.data()); + const XPUType* v_data = reinterpret_cast(v.data()); + if (!is_cross_attn) { + xpu::VectorParam lods{ + qlod_vec.data(), (int32_t)(qlod_vec.size()), nullptr}; + xpu::QKVAttnParam qkv_attn_param( + lods, // only support qlods == kvlods + num_heads, // head_nums + head_size, // head_dim + xpu::Activation_t::RELU, // Activation_t + -1, // last_slice_seq(unused param) + false, // do_fc_qkv_fusion(unused param) + -1, // pad_seqlen(unused param) + -1, // hidden_dim(unused param) + false, // is_pre_norm(unused param) + false, // is_perchannel(unused param) + 0, // qkv_shape + {}, // z_shape + AttnMacMaxPtrType_t::ATTN_WHOLE_BATCH, // max_ptr_type + -1, // ldz(unused param) + {}, // sqlod(unused param) + scale); // alpha + qkv_attn_param.triangle_mask_autogen = causal; + qkv_attn_param.key_value_head_num = num_heads_k; + r = xpu::qkv_attention(ctx.x_context(), + q_data, // q + k_data, // k + v_data, // v + out_data, // out + nullptr, // max_q + nullptr, // max_k + nullptr, // max_v + nullptr, // max_ctx + qkv_attn_param, + nullptr, + nullptr, + nullptr); + PADDLE_ENFORCE_EQ(r, 0, "xpu::qkv_attention failed."); + } else { + std::vector lod; + lod.reserve(2 * batch_size + 2); + int real_max_len = 0; + for (int i = 0; i < batch_size + 1; i++) { + lod.push_back(qlod_vec[i]); + if (i) + real_max_len = std::max(qlod_vec[i] - qlod_vec[i - 1], real_max_len); + } + for (int i = 0; i < batch_size + 1; i++) { + lod.push_back(klod_vec[i]); + if (i) + real_max_len = std::max(klod_vec[i] - klod_vec[i - 1], real_max_len); + } + xpu::DifSeqAttnParam dis_api_attn_param( + {lod.data(), 2 * batch_size + 2, nullptr}, num_heads, head_size); + XPUType* qk_buf = RAII_GUARD.alloc_l3_or_gm( + batch_size * num_heads * real_max_len * real_max_len); + float* qk_max_buf = RAII_GUARD.alloc_l3_or_gm(6); + r = xpu::qk_attention( + ctx.x_context(), + q_data, + k_data, + qk_buf, + nullptr, + nullptr, + qk_max_buf, + dis_api_attn_param, + nullptr); + PADDLE_ENFORCE_EQ(r, 0, "xpu::qk_attention failed."); + r = xpu::qk_v_attention( + 
ctx.x_context(), + qk_buf, + v_data, + out_data, + qk_max_buf, + nullptr, + nullptr, + dis_api_attn_param, + nullptr); + PADDLE_ENFORCE_EQ(r, 0, "xpu::qk_v_attention failed."); + } +#else + PADDLE_THROW(phi::errors::PreconditionNotMet( + "re-compile using -DWITH_XPU_XHPC=ON to use FlashAttnKernel")); +#endif +} + template void FlashAttnKernel(const Context& ctx, const DenseTensor& q, @@ -127,6 +282,16 @@ void FlashAttnKernel(const Context& ctx, } // namespace phi +PD_REGISTER_KERNEL(flash_attn_unpadded, + XPU, + ALL_LAYOUT, + phi::FlashAttnUnpaddedKernel, + float, + phi::dtype::float16) { + kernel->InputAt(5).SetBackend( + phi::Backend::ALL_BACKEND); // fixed_seed_offset +} + PD_REGISTER_KERNEL(flash_attn, XPU, ALL_LAYOUT, diff --git a/paddle/phi/kernels/xpu/flip_kernel.cc b/paddle/phi/kernels/xpu/flip_kernel.cc index 56a31197e56c7..aa44e3083b7c2 100644 --- a/paddle/phi/kernels/xpu/flip_kernel.cc +++ b/paddle/phi/kernels/xpu/flip_kernel.cc @@ -26,17 +26,17 @@ void FlipKernel(const Context& dev_ctx, DenseTensor* out) { using XPUInTDType = typename XPUTypeTrait::Type; int x_rank = x.dims().size(); - std::vector formated_axis(std::begin(axis), std::end(axis)); + std::vector formatted_axis(std::begin(axis), std::end(axis)); for (size_t i = 0; i < axis.size(); i++) { if (axis[i] < 0) { - formated_axis[i] = static_cast(axis[i] + x_rank); + formatted_axis[i] = static_cast(axis[i] + x_rank); } } dev_ctx.template Alloc(out); if (out->numel() == 0) { return; } - if (formated_axis.size() == 0) { + if (formatted_axis.size() == 0) { phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); return; } @@ -52,7 +52,7 @@ void FlipKernel(const Context& dev_ctx, /* const T* x */ x_data, /* T* y */ out_data, /* const std::vector& xshape */ x_shape, - /* const std::vector& axis */ formated_axis); + /* const std::vector& axis */ formatted_axis); PADDLE_ENFORCE_XDNN_SUCCESS(r, "flip"); } diff --git a/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc b/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc index c4432f82d9b26..fe989318cbcb4 100644 --- a/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/fused_attention_grad_kernel.cc @@ -224,9 +224,9 @@ void FusedAttentionGradKernel( XPUTypeT *d_dropout_grad_ptr = NULL; // dx5 [batch_size, seq_len, hidden] XPUTypeT *d_fmha_out_ptr = - NULL; // d_fmha_out [batch_size, seq_len, num_heads, head_dims] - XPUTypeT *d_fmha_out_transpos_tmp_ptr = - NULL; // d_fmha_out_transpos [batch_size, seq_len, num_heads, + NULL; // d_fmha_out [batch_size, seq_len, num_heads, head_dims] + XPUTypeT *d_fmha_out_transpose_tmp_ptr = + NULL; // d_fmha_out_transpose [batch_size, seq_len, num_heads, // head_dims] XPUTypeT *d_qk_ptr = @@ -235,7 +235,7 @@ void FusedAttentionGradKernel( XPUTypeT *d_combination_qkv_ptr = NULL; // d_combination_qkv_ptr[3, batch_size, num_heads, seq_len, // head_dims] - XPUTypeT *d_transpos_qkv_ptr = + XPUTypeT *d_transpose_qkv_ptr = NULL; // dx2 [batch_size, seq_len, 3, num_heads, head_dims] XPUTypeT *d_last_layernorm_grad_ptr = @@ -250,9 +250,9 @@ void FusedAttentionGradKernel( num_heads * head_dims); d_combination_qkv_ptr = RAII_GUARD.alloc(batch_size * seq_len * embed_dims * 3); - d_transpos_qkv_ptr = RAII_GUARD.alloc_l3_or_gm( + d_transpose_qkv_ptr = RAII_GUARD.alloc_l3_or_gm( batch_size * seq_len * embed_dims * 3); - d_fmha_out_transpos_tmp_ptr = + d_fmha_out_transpose_tmp_ptr = RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * embed_dims); d_qk_ptr = RAII_GUARD.alloc_l3_or_gm(batch_size * seq_len * seq_len * num_heads); @@ 
-343,7 +343,7 @@ void FusedAttentionGradKernel( XPUTypeT *d_v_out_ptr = d_k_out_ptr + qkv_size; r = xpu::transpose(xpu_ctx, d_fmha_out_ptr, - d_fmha_out_transpos_tmp_ptr, + d_fmha_out_transpose_tmp_ptr, {batch_size, seq_len, num_heads, head_dims}, {0, 2, 1, 3}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); @@ -381,7 +381,7 @@ void FusedAttentionGradKernel( false, attn_dropout_out_ptr, v_out_ptr, - d_fmha_out_transpos_tmp_ptr); + d_fmha_out_transpose_tmp_ptr); std::tie(info_d_qk, info_d_v, a_1, b_1, a_2, b_2) = fc_info; phi::MatMulXPUFunction( @@ -452,7 +452,7 @@ void FusedAttentionGradKernel( // r = xpu::transpose(xpu_ctx, d_combination_qkv_ptr, - d_transpos_qkv_ptr, + d_transpose_qkv_ptr, {3, batch_size, num_heads, seq_len, head_dims}, {1, 3, 0, 2, 4}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); @@ -487,7 +487,7 @@ void FusedAttentionGradKernel( true, use_calc_input_x_ptr, qkv_weight_ptr, - d_transpos_qkv_ptr); + d_transpose_qkv_ptr); std::tie(info_d_x, info_d_qkv_w, a_1, b_1, a_2, b_2) = fc_info; phi::MatMulXPUFunction( @@ -497,7 +497,7 @@ void FusedAttentionGradKernel( // d_qkv_bias r = xpu::reduce_sum(xpu_ctx, - d_transpos_qkv_ptr, + d_transpose_qkv_ptr, d_qkv_bias_ptr, {batch_size * seq_len, 3 * embed_dims}, {0}); diff --git a/paddle/phi/kernels/xpu/fused_attention_kernel.cc b/paddle/phi/kernels/xpu/fused_attention_kernel.cc index d18dda47866ef..b7a1c8a638648 100644 --- a/paddle/phi/kernels/xpu/fused_attention_kernel.cc +++ b/paddle/phi/kernels/xpu/fused_attention_kernel.cc @@ -199,7 +199,7 @@ void FusedAttentionKernel(const Context &dev_ctx, int l3_total_size = xpu_ctx->_l3_mgr.get_size(); - XPUTypeT *qkv_before_transpos_ptr = + XPUTypeT *qkv_before_transpose_ptr = NULL; // x2[batch_size, seq_len, 3, num_heads,head_dims] XPUTypeT *qk_ptr = NULL; // qk [batch_size, num_heads, seq_len, seq_len] XPUTypeT *qkv_ptr = NULL; // qkv[batch_size, num_heads, seq_len, head_dims] @@ -215,7 +215,7 @@ void FusedAttentionKernel(const Context &dev_ctx, std::sort(temp_vec.begin(), temp_vec.end(), std::greater()); XPUTypeT *max_gm_ptr = RAII_GUARD.alloc(temp_vec[0]); PADDLE_ENFORCE_XDNN_NOT_NULL(max_gm_ptr); - qkv_before_transpos_ptr = max_gm_ptr; + qkv_before_transpose_ptr = max_gm_ptr; qk_ptr = max_gm_ptr; qkv_ptr = max_gm_ptr; linear_out_ptr = max_gm_ptr; @@ -223,7 +223,7 @@ void FusedAttentionKernel(const Context &dev_ctx, for (size_t i = 0; i < temp_vec.size(); ++i) { if (l3_total_size >= temp_vec[i] * sizeof_t) { XPUTypeT *l3_ptr = RAII_GUARD.alloc_l3(temp_vec[i]); - qkv_before_transpos_ptr = + qkv_before_transpose_ptr = (temp_size_1 <= temp_vec[i]) ? l3_ptr : max_gm_ptr; qk_ptr = (temp_size_2 <= temp_vec[i]) ? l3_ptr : max_gm_ptr; qkv_ptr = (temp_size_3 <= temp_vec[i]) ? 
l3_ptr : max_gm_ptr; @@ -264,22 +264,22 @@ void FusedAttentionKernel(const Context &dev_ctx, phi::MatMulXPUFunction(xpu_ctx, x_cacl_ptr, qkv_weight_ptr, - qkv_before_transpos_ptr, + qkv_before_transpose_ptr, qkv_fc_info, 1.0f); // bias r = xpu::broadcast_add(xpu_ctx, - qkv_before_transpos_ptr, + qkv_before_transpose_ptr, qkv_bias_ptr, - qkv_before_transpos_ptr, + qkv_before_transpose_ptr, {batch_size * seq_len, 3 * num_heads * head_dims}, {3 * num_heads * head_dims}); PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_add"); // transpose r = xpu::transpose(xpu_ctx, - qkv_before_transpos_ptr, + qkv_before_transpose_ptr, qkv_transpose_out_ptr, {batch_size, seq_len, 3, num_heads, head_dims}, {2, 0, 3, 1, 4}); diff --git a/paddle/phi/kernels/xpu/index_put_kernel.cc b/paddle/phi/kernels/xpu/index_put_kernel.cc index 60c91a8e5c83c..0a86bc6cef536 100644 --- a/paddle/phi/kernels/xpu/index_put_kernel.cc +++ b/paddle/phi/kernels/xpu/index_put_kernel.cc @@ -104,7 +104,7 @@ void IndexPutKernel(const Context& dev_ctx, return; } - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; auto out_data = dev_ctx.template Alloc(out); auto bd_dims = funcs::BroadCastTensorsDims(int_indices_v); DenseTensor res_indices(DataType::INT64); @@ -133,15 +133,15 @@ void IndexPutKernel(const Context& dev_ctx, value_data = value_bd.data(); } - int r = - xpu::index_put(dev_ctx.x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(value_data), - res_indices.data(), - reinterpret_cast(out_data), - x_shape, - index_shape, - accumulate); + int r = xpu::index_put( + dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(value_data), + res_indices.data(), + reinterpret_cast(out_data), + x_shape, + index_shape, + accumulate); PADDLE_ENFORCE_XDNN_SUCCESS(r, "index_put"); if (dev_ctx.x_context()->xpu_stream) { dev_ctx.Wait(); diff --git a/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc b/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc index dba0e2ccfd765..f1a217ed81ad3 100644 --- a/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/instance_norm_grad_kernel.cc @@ -39,7 +39,7 @@ void InstanceNormGradKernel(const Context& dev_ctx, true, phi::errors::InvalidArgument( "The size of input's dimensions should be less equal than 5", - "and the dimension of D should be eaual to 1", + "and the dimension of D should be equal to 1", "But received: the size of input's dimensions is [%d]", x_dims.size())); diff --git a/paddle/phi/kernels/xpu/inverse_kernel.cc b/paddle/phi/kernels/xpu/inverse_kernel.cc index a48baa508ade0..82d54653eb03c 100644 --- a/paddle/phi/kernels/xpu/inverse_kernel.cc +++ b/paddle/phi/kernels/xpu/inverse_kernel.cc @@ -24,7 +24,7 @@ template void InverseKernel(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; auto out_data = dev_ctx.template Alloc(out); int64_t x_dims_len = x.dims().size(); @@ -41,17 +41,17 @@ void InverseKernel(const Context& dev_ctx, 8192, phi::errors::InvalidArgument( "The size of a single matrix (%d bytes) exceeds the " - "maxinum numbers of bytes xpu supports (8192).", + "maximum numbers of bytes xpu supports (8192).", n * n * sizeof(T))); auto RAII_GUARD = xpu::ctx_guard(dev_ctx.x_context()); auto* info_xpu = RAII_GUARD.alloc_l3_or_gm(batch); // Xpu inverse api has check for singularity itself. 
- int r = xpu::inverse(dev_ctx.x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(out_data), - info_xpu, - batch, - n); + int r = xpu::inverse(dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out_data), + info_xpu, + batch, + n); PADDLE_ENFORCE_XDNN_SUCCESS(r, "inverse"); } diff --git a/paddle/phi/kernels/xpu/multiclass_nms3_kernel.cc b/paddle/phi/kernels/xpu/multiclass_nms3_kernel.cc index 17746e4eeff0a..6e1c20a366d23 100644 --- a/paddle/phi/kernels/xpu/multiclass_nms3_kernel.cc +++ b/paddle/phi/kernels/xpu/multiclass_nms3_kernel.cc @@ -38,10 +38,12 @@ void MultiClassNMSKernel(const Context& ctx, DenseTensor* out, DenseTensor* index, DenseTensor* nms_rois_num) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; - const XPUT* bboxes_data = reinterpret_cast(bboxes.data()); - const XPUT* scores_data = reinterpret_cast(scores.data()); + const XPUType* bboxes_data = + reinterpret_cast(bboxes.data()); + const XPUType* scores_data = + reinterpret_cast(scores.data()); bool return_index = index != nullptr; bool has_rois_num = rois_num.get_ptr() != nullptr; @@ -90,7 +92,7 @@ void MultiClassNMSKernel(const Context& ctx, PADDLE_ENFORCE_EQ( boxes_count == score_dims[0], true, - phi::errors::InvalidArgument("boxes_count shuold equal score_dims[0].", + phi::errors::InvalidArgument("boxes_count should equal score_dims[0].", "But received: (%d) and (%d)", boxes_count, score_dims[0])); diff --git a/paddle/phi/kernels/xpu/prelu_grad_kernel.cc b/paddle/phi/kernels/xpu/prelu_grad_kernel.cc index fa43c90883766..b7c2157d55f43 100644 --- a/paddle/phi/kernels/xpu/prelu_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/prelu_grad_kernel.cc @@ -60,9 +60,9 @@ void PReluGradKernel(const Context& dev_ctx, } } - // mode = 0: channel_nchw, slope_shape = {c}, default. meanwhile, xhsape = {n, + // mode = 0: channel_nchw, slope_shape = {c}, default. 
meanwhile, xshape = {n, // c, h, w} - // mode = 1, channel_nhwc, slope_shape = {c}, meanwhile, xhsape = {n, h, w, c} + // mode = 1, channel_nhwc, slope_shape = {c}, meanwhile, xshape = {n, h, w, c} // mode = 2, elementwise, slope_shape = {c*h*w} // mode = 3, single slope, slope_shape = {1} diff --git a/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc index 846250c067740..aa8736d84b71f 100644 --- a/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/reduce_max_grad_kernel.cc @@ -60,23 +60,23 @@ void ReduceMaxGradKernel(const Context& dev_ctx, } } - T* brocast1 = nullptr; - T* brocast2 = nullptr; + T* broadcast1 = nullptr; + T* broadcast2 = nullptr; bool* equal = nullptr; xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - brocast1 = RAII_GUARD.alloc_l3_or_gm(x.numel()); + broadcast1 = RAII_GUARD.alloc_l3_or_gm(x.numel()); PADDLE_ENFORCE_NOT_NULL( - brocast1, errors::ResourceExhausted("XPU has no enough memory")); + broadcast1, errors::ResourceExhausted("XPU has no enough memory")); equal = RAII_GUARD.alloc_l3_or_gm(x.numel()); PADDLE_ENFORCE_NOT_NULL( equal, errors::ResourceExhausted("XPU has no enough memory")); - brocast2 = RAII_GUARD.alloc_l3_or_gm(x.numel()); + broadcast2 = RAII_GUARD.alloc_l3_or_gm(x.numel()); PADDLE_ENFORCE_NOT_NULL( - brocast2, errors::ResourceExhausted("XPU has no enough memory")); + broadcast2, errors::ResourceExhausted("XPU has no enough memory")); // use [1] to replace [], because xpu not support [] if (xdims.size() == 0) { @@ -86,25 +86,25 @@ void ReduceMaxGradKernel(const Context& dev_ctx, ydims = std::vector({1}); } - // step 1. brocast out and out_grad - int r = - xpu::broadcast(dev_ctx.x_context(), out_data, brocast1, ydims, xdims); + // step 1. broadcast out and out_grad + int r = xpu::broadcast( + dev_ctx.x_context(), out_data, broadcast1, ydims, xdims); PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast"); r = xpu::broadcast( - dev_ctx.x_context(), out_grad_data, brocast2, ydims, xdims); + dev_ctx.x_context(), out_grad_data, broadcast2, ydims, xdims); PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast"); - // step 2. comparse out_brocast and x - r = xpu::equal(dev_ctx.x_context(), x_data, brocast1, equal, x.numel()); + // step 2. compare out_broadcast and x + r = xpu::equal(dev_ctx.x_context(), x_data, broadcast1, equal, x.numel()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "equal"); // step 3. 
get x_grad - r = xpu::constant(dev_ctx.x_context(), brocast1, x.numel(), 0); + r = xpu::constant(dev_ctx.x_context(), broadcast1, x.numel(), 0); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); r = xpu::select(dev_ctx.x_context(), equal, - brocast2, - brocast1, + broadcast2, + broadcast1, x_grad_data, xdims, xdims); diff --git a/paddle/phi/kernels/xpu/reduce_min_grad_kernel.cc b/paddle/phi/kernels/xpu/reduce_min_grad_kernel.cc index 9019cb0834d72..aefcc74b45091 100644 --- a/paddle/phi/kernels/xpu/reduce_min_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/reduce_min_grad_kernel.cc @@ -60,23 +60,23 @@ void ReduceMinGradKernel(const Context& dev_ctx, } } - T* brocast1 = nullptr; - T* brocast2 = nullptr; + T* broadcast1 = nullptr; + T* broadcast2 = nullptr; bool* equal = nullptr; xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - brocast1 = RAII_GUARD.alloc_l3_or_gm(x.numel()); + broadcast1 = RAII_GUARD.alloc_l3_or_gm(x.numel()); PADDLE_ENFORCE_NOT_NULL( - brocast1, errors::ResourceExhausted("XPU has no enough memory")); + broadcast1, errors::ResourceExhausted("XPU has no enough memory")); equal = RAII_GUARD.alloc_l3_or_gm(x.numel()); PADDLE_ENFORCE_NOT_NULL( equal, errors::ResourceExhausted("XPU has no enough memory")); - brocast2 = RAII_GUARD.alloc_l3_or_gm(x.numel()); + broadcast2 = RAII_GUARD.alloc_l3_or_gm(x.numel()); PADDLE_ENFORCE_NOT_NULL( - brocast2, errors::ResourceExhausted("XPU has no enough memory")); + broadcast2, errors::ResourceExhausted("XPU has no enough memory")); // use [1] to replace [], because xpu not support [] if (xdims.size() == 0) { @@ -86,25 +86,25 @@ void ReduceMinGradKernel(const Context& dev_ctx, ydims = std::vector({1}); } - // step 1. brocast out and out_grad - int r = - xpu::broadcast(dev_ctx.x_context(), out_data, brocast1, ydims, xdims); + // step 1. broadcast out and out_grad + int r = xpu::broadcast( + dev_ctx.x_context(), out_data, broadcast1, ydims, xdims); PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast"); r = xpu::broadcast( - dev_ctx.x_context(), out_grad_data, brocast2, ydims, xdims); + dev_ctx.x_context(), out_grad_data, broadcast2, ydims, xdims); PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast"); - // step 2. comparse out_brocast and x - r = xpu::equal(dev_ctx.x_context(), x_data, brocast1, equal, x.numel()); + // step 2. compare out_broadcast and x + r = xpu::equal(dev_ctx.x_context(), x_data, broadcast1, equal, x.numel()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "equal"); // step 3. get x_grad - r = xpu::constant(dev_ctx.x_context(), brocast1, x.numel(), 0); + r = xpu::constant(dev_ctx.x_context(), broadcast1, x.numel(), 0); PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); r = xpu::select(dev_ctx.x_context(), equal, - brocast2, - brocast1, + broadcast2, + broadcast1, x_grad_data, xdims, xdims); diff --git a/paddle/phi/kernels/xpu/rnn_util.h b/paddle/phi/kernels/xpu/rnn_util.h index 5310b35e64dc3..7948bb2defa0c 100644 --- a/paddle/phi/kernels/xpu/rnn_util.h +++ b/paddle/phi/kernels/xpu/rnn_util.h @@ -23,7 +23,7 @@ void ResetParameterVector(const std::vector& raw_params_vec, const int& num_layers, const bool& is_bidirec, std::vector>* params_vec) { - // the parameter raw seuquence is [FWhi, FWhh, BWhi, BWhh] * num_layers + // the parameter raw sequence is [FWhi, FWhh, BWhi, BWhh] * num_layers // + [FBhi, FBhh, BBhi, BBhh] * num_layers, we will reset the parameter to // ([FWhi, FWhh, FBhi, FBhh] + [BWhi, BWhh, BBhi, BBhh]) * num_layers const int& direction_num = is_bidirec ? 
2 : 1; diff --git a/paddle/phi/kernels/xpu/scale_kernel.cc b/paddle/phi/kernels/xpu/scale_kernel.cc index 6fe127af3d6ef..e63787a93c84c 100644 --- a/paddle/phi/kernels/xpu/scale_kernel.cc +++ b/paddle/phi/kernels/xpu/scale_kernel.cc @@ -23,7 +23,7 @@ template void ScaleKernel(const Context& dev_ctx, const DenseTensor& x, const Scalar& scale, - float bias, + const Scalar& bias, bool bias_after_scale, DenseTensor* out) { dev_ctx.template Alloc(out); @@ -45,7 +45,7 @@ void ScaleKernel(const Context& dev_ctx, x.numel(), bias_after_scale, scale.to(), - bias); + bias.to()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale"); } diff --git a/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc b/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc index 37e6e91ea779e..bc08afbb7f6da 100644 --- a/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/scatter_nd_add_grad_kernel.cc @@ -25,15 +25,15 @@ void ScatterNdAddGradKernel(const Context &ctx, const DenseTensor &out_grad, DenseTensor *x_grad, DenseTensor *updates_grad) { - using XPUT = typename XPUTypeTrait::Type; + using XPUType = typename XPUTypeTrait::Type; int ret = xpu::SUCCESS; const T *out_grad_data = out_grad.data(); if (x_grad) { auto *x_grad_data = ctx.template Alloc(x_grad); - ret = xpu::copy(ctx.x_context(), - reinterpret_cast(out_grad_data), - reinterpret_cast(x_grad_data), - out_grad.numel()); + ret = xpu::copy(ctx.x_context(), + reinterpret_cast(out_grad_data), + reinterpret_cast(x_grad_data), + out_grad.numel()); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy"); } @@ -64,11 +64,12 @@ void ScatterNdAddGradKernel(const Context &ctx, out_grad_numel, remain_numel, updates_grad_numel)); - ret = xpu::broadcast(ctx.x_context(), - reinterpret_cast(out_grad_data), - reinterpret_cast(updates_grad_data), - {1, out_grad_numel}, - {remain_numel, out_grad_numel}); + ret = xpu::broadcast( + ctx.x_context(), + reinterpret_cast(out_grad_data), + reinterpret_cast(updates_grad_data), + {1, out_grad_numel}, + {remain_numel, out_grad_numel}); PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast"); return; } @@ -84,19 +85,19 @@ void ScatterNdAddGradKernel(const Context &ctx, nullptr}; if (index.dtype() == DataType::INT32) { - ret = xpu::gather_nd( + ret = xpu::gather_nd( ctx.x_context(), - reinterpret_cast(out_grad_data), + reinterpret_cast(out_grad_data), index.data(), - reinterpret_cast(updates_grad_data), + reinterpret_cast(updates_grad_data), out_grad_shape_param, index_shape_vec); } else { - ret = xpu::gather_nd( + ret = xpu::gather_nd( ctx.x_context(), - reinterpret_cast(out_grad_data), + reinterpret_cast(out_grad_data), index.data(), - reinterpret_cast(updates_grad_data), + reinterpret_cast(updates_grad_data), out_grad_shape_param, index_shape_vec); } diff --git a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc index c5d33ae4ac8d0..227d6b39c9f28 100644 --- a/paddle/phi/kernels/xpu/set_value_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/set_value_grad_kernel.cc @@ -203,7 +203,7 @@ void SetValueGradImpl(const Context& dev_ctx, auto value_grad_dims = value_grad->dims(); auto fake_value_grad_dims = out_dims; - // Create an extented shape according to the rules of broadcast. + // Create an extended shape according to the rules of broadcast. 
auto value_grad_dims_size = value_grad_dims.size(); int num_decrease = 0; diff --git a/paddle/phi/kernels/xpu/set_value_kernel.cc b/paddle/phi/kernels/xpu/set_value_kernel.cc index c457a6d21fd8a..60b0fff7d9d7c 100644 --- a/paddle/phi/kernels/xpu/set_value_kernel.cc +++ b/paddle/phi/kernels/xpu/set_value_kernel.cc @@ -263,7 +263,7 @@ void SetValueKernelImpl(const Context& dev_ctx, const std::vector& decrease_axes, const std::vector& none_axes, DenseTensor* out) { - // rank是xtensor的维度信息 + // rank是x tensor的维度信息 const int rank = x.dims().size(); switch (rank) { diff --git a/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc index 709eeaac49546..e54de257ead10 100644 --- a/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/stride_slice_grad_kernel.cc @@ -51,11 +51,6 @@ void StridedSliceRawGradKernel(const Context& dev_ctx, int num = axes.size(); for (int i = 0; i < num; ++i) { - PADDLE_ENFORCE_EQ( - strides_[i] > 0, - true, - errors::InvalidArgument("Currently, XPU strided slice kernel does not", - "support reverse strided slice")); int cur_axe = axes[i]; int st = starts_[i]; if (st > xshape[cur_axe]) { @@ -71,7 +66,12 @@ void StridedSliceRawGradKernel(const Context& dev_ctx, end = xshape[cur_axe]; } if (end < 0) { - end += xshape[cur_axe]; + if (!(end == -1 && strides_[i] < 0)) { + end = end + xshape[cur_axe]; + if (end < 0) { + end = 0; + } + } } ends_in[cur_axe] = end; diff --git a/paddle/phi/kernels/xpu/stride_slice_kernel.cc b/paddle/phi/kernels/xpu/stride_slice_kernel.cc index 5aee59729b52e..1a10ba1e8fae4 100644 --- a/paddle/phi/kernels/xpu/stride_slice_kernel.cc +++ b/paddle/phi/kernels/xpu/stride_slice_kernel.cc @@ -66,15 +66,10 @@ void StridedSliceRawKernel(const Context& dev_ctx, int num = axes.size(); for (int i = 0; i < num; ++i) { - PADDLE_ENFORCE_EQ( - strides_[i] > 0, - true, - errors::InvalidArgument("Currently, XPU strided slice kernel does not ", - "support reverse strided slice.")); int cur_axe = axes[i]; int st = starts_[i]; if (st > xshape[cur_axe]) { - st = xshape[cur_axe]; + st = xshape[cur_axe] - 1; } if (st < 0) { st += xshape[cur_axe]; @@ -86,17 +81,15 @@ void StridedSliceRawKernel(const Context& dev_ctx, end = xshape[cur_axe]; } if (end < 0) { - end += xshape[cur_axe]; + if (!(end == -1 && strides_[i] < 0)) { + end = end + xshape[cur_axe]; + if (end < 0) { + end = 0; + } + } } ends_in[cur_axe] = end; - PADDLE_ENFORCE_EQ( - st < end, - true, - errors::InvalidArgument("End index should be larger than", - "start Index, this OP does not support", - "reverse operator.")); - strides_in[cur_axe] = strides_[i]; } diff --git a/paddle/phi/kernels/xpu/take_along_axis_kernel.cc b/paddle/phi/kernels/xpu/take_along_axis_kernel.cc index e55604e768b9a..bff4204b65801 100644 --- a/paddle/phi/kernels/xpu/take_along_axis_kernel.cc +++ b/paddle/phi/kernels/xpu/take_along_axis_kernel.cc @@ -33,45 +33,45 @@ void TakeAlongAxisKernel(const Context& dev_ctx, if (x.numel() == 0 || index.numel() == 0) return; - const auto& index_type = index.dtype(); - bool index_type_match = - index_type == DataType::INT32 || index_type == DataType::INT64; - PADDLE_ENFORCE_EQ(index_type_match, + const auto& index_dtype = index.dtype(); + bool index_dtype_match = + index_dtype == DataType::INT32 || index_dtype == DataType::INT64; + PADDLE_ENFORCE_EQ(index_dtype_match, true, errors::InvalidArgument( "Input(Index) holds the wrong type, it holds %s, but " "desires to be %s or %s", - DataTypeToString(index_type), + 
DataTypeToString(index_dtype), DataTypeToString(DataType::INT32), DataTypeToString(DataType::INT64))); - std::vector xshape(x.dims().size()); + std::vector x_shape(x.dims().size()); for (int i = 0; i < x.dims().size(); ++i) { - xshape[i] = x.dims()[i]; + x_shape[i] = x.dims()[i]; } - std::vector idxshape(index.dims().size()); + std::vector index_shape(index.dims().size()); for (int i = 0; i < index.dims().size(); ++i) { - idxshape[i] = index.dims()[i]; + index_shape[i] = index.dims()[i]; } - if (xshape.size() <= 1 && idxshape.size() <= 1) { - for (int i = xshape.size(); i < 2; ++i) { - xshape.push_back(1); - idxshape.push_back(1); + if (x_shape.size() <= 1 && index_shape.size() <= 1) { + for (int i = x_shape.size(); i < 2; ++i) { + x_shape.push_back(1); + index_shape.push_back(1); } } using XPUType = typename XPUTypeTrait::Type; int r = XPU_SUCCESS; #ifndef PADDLE_WITH_XPU_PLUGIN - if (index_type == DataType::INT32) { + if (index_dtype == DataType::INT32) { r = xpu::gather_element( dev_ctx.x_context(), reinterpret_cast(x.data()), index.data(), reinterpret_cast(out->data()), - xshape, - idxshape, + x_shape, + index_shape, axis); } else { r = xpu::gather_element( @@ -79,20 +79,20 @@ void TakeAlongAxisKernel(const Context& dev_ctx, reinterpret_cast(x.data()), index.data(), reinterpret_cast(out->data()), - xshape, - idxshape, + x_shape, + index_shape, axis); } PADDLE_ENFORCE_XDNN_SUCCESS(r, "gather_element"); #else - if (index_type == DataType::INT32) { + if (index_dtype == DataType::INT32) { r = xpu::plugin::take_along_axis( dev_ctx.x_context(), reinterpret_cast(x.data()), index.data(), reinterpret_cast(out->data()), - xshape, - idxshape, + x_shape, + index_shape, axis); } else { r = xpu::plugin::take_along_axis( @@ -100,8 +100,8 @@ void TakeAlongAxisKernel(const Context& dev_ctx, reinterpret_cast(x.data()), index.data(), reinterpret_cast(out->data()), - xshape, - idxshape, + x_shape, + index_shape, axis); } PADDLE_ENFORCE_XDNN_SUCCESS(r, "take_along_axis"); diff --git a/paddle/phi/kernels/xpu/tile_grad_kernel.cc b/paddle/phi/kernels/xpu/tile_grad_kernel.cc index b131c16854960..b47d8fa5a115c 100644 --- a/paddle/phi/kernels/xpu/tile_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/tile_grad_kernel.cc @@ -83,8 +83,8 @@ void TileGradKernel(const Context& dev_ctx, using XPUType = typename XPUTypeTrait::Type; // int reduce_sum(Context* ctx, const T* x, T* y, const std::vector& // xshape, const std::vector& rdims) - const auto* out_data = out_grad.data(); - auto* x_grad_data = x_grad->data(); + const auto* out_data = reinterpret_cast(out_grad.data()); + auto* x_grad_data = reinterpret_cast(x_grad->data()); int r = xpu::reduce_sum(dev_ctx.x_context(), out_data, x_grad_data, @@ -96,4 +96,9 @@ void TileGradKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(tile_grad, XPU, ALL_LAYOUT, phi::TileGradKernel, float) {} +PD_REGISTER_KERNEL(tile_grad, + XPU, + ALL_LAYOUT, + phi::TileGradKernel, + float, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/tile_kernel.cc b/paddle/phi/kernels/xpu/tile_kernel.cc index d90232b6767e7..63d316f547554 100644 --- a/paddle/phi/kernels/xpu/tile_kernel.cc +++ b/paddle/phi/kernels/xpu/tile_kernel.cc @@ -29,6 +29,7 @@ void TileKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& repeat_times_arr, DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; auto rank = x.dims().size(); std::vector repeat_times = repeat_times_arr.GetData(); int repeat_times_size = repeat_times.size(); @@ -123,17 +124,23 @@ void TileKernel(const 
Context& dev_ctx, vec_out_dims); } else { - ret = xpu::broadcast(dev_ctx.x_context(), - x.data(), - out->data(), - vec_in_dims, - vec_out_dims); + const auto* x_data = reinterpret_cast(x.data()); + auto* out_data = reinterpret_cast(out->data()); + ret = xpu::broadcast( + dev_ctx.x_context(), x_data, out_data, vec_in_dims, vec_out_dims); } PADDLE_ENFORCE_XDNN_SUCCESS(ret, "broadcast"); } } // namespace phi -PD_REGISTER_KERNEL( - tile, XPU, ALL_LAYOUT, phi::TileKernel, bool, float, double, int, int64_t) { -} +PD_REGISTER_KERNEL(tile, + XPU, + ALL_LAYOUT, + phi::TileKernel, + bool, + float, + double, + int, + int64_t, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/xpu/transpose_grad_kernel.cc b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc index ab6be8c3347ca..a461b0dcb1b58 100644 --- a/paddle/phi/kernels/xpu/transpose_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc @@ -36,16 +36,16 @@ void TransposeGradKernel(const Context& dev_ctx, return; } - std::vector formated_axis = axis; + std::vector formatted_axis = axis; for (size_t i = 0; i < axis_size; i++) { if (axis[i] < 0) { - formated_axis[i] = axis[i] + axis_size; + formatted_axis[i] = axis[i] + axis_size; } } std::vector reversed_axis(axis); for (size_t i = 0; i < axis_size; i++) { - reversed_axis[formated_axis[i]] = i; + reversed_axis[formatted_axis[i]] = i; } std::vector out_grad_dim_vec = common::vectorize(out_grad.dims()); diff --git a/paddle/phi/kernels/xpu/transpose_kernel.cc b/paddle/phi/kernels/xpu/transpose_kernel.cc index f88e06b18e88d..4fda5e3912645 100644 --- a/paddle/phi/kernels/xpu/transpose_kernel.cc +++ b/paddle/phi/kernels/xpu/transpose_kernel.cc @@ -25,10 +25,10 @@ void TransposeKernel(const Context& dev_ctx, const std::vector& axis, DenseTensor* out) { size_t x_rank = x.dims().size(); - std::vector formated_axis = axis; + std::vector formatted_axis = axis; for (size_t i = 0; i < axis.size(); i++) { if (axis[i] < 0) { - formated_axis[i] = axis[i] + x_rank; + formatted_axis[i] = axis[i] + x_rank; } } @@ -38,7 +38,7 @@ void TransposeKernel(const Context& dev_ctx, if (out->numel() == 0) { return; } - if (formated_axis.size() == 0) { + if (formatted_axis.size() == 0) { phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); return; } @@ -48,7 +48,7 @@ void TransposeKernel(const Context& dev_ctx, reinterpret_cast(x.data()), reinterpret_cast(out->data()), x_dim_vec, - formated_axis); + formatted_axis); PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose"); } diff --git a/paddle/phi/kernels/xpu/xpu_api_wrapper.h b/paddle/phi/kernels/xpu/xpu_api_wrapper.h index 5d6006b7a69bd..c6560622eaaf6 100644 --- a/paddle/phi/kernels/xpu/xpu_api_wrapper.h +++ b/paddle/phi/kernels/xpu/xpu_api_wrapper.h @@ -54,8 +54,10 @@ XPUFCCalcType FCCalcType() { return XPUFCCalcType::FC_FLOAT; } else if (std::getenv("XPU_PADDLE_FC_INT32_WITH_LL") != nullptr) { return XPUFCCalcType::FC_INT32_WITH_LL; - } else if (std::is_same::value || - std::is_same::value) { + } else if ((std::is_same::value || + std::is_same::value) || + (std::is_same::value && + std::getenv("XPU_PADDLE_FC_TF32") != nullptr)) { return XPUFCCalcType::FC_TF32; } return XPUFCCalcType::FC_INT16; @@ -309,7 +311,7 @@ static void xblas_fc_wrapper(xpu::Context* ctx, } } -#define DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUType, FCT) \ +#define DECLARE_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUType, FCT) \ template <> \ void xblas_fc_wrapper(xpu::Context * ctx, \ const XPUType* x, \ @@ -338,12 +340,12 @@ static void xblas_fc_wrapper(xpu::Context* ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, 
"xblas_fc_wrapper"); \ } -DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, int_with_ll_t) -DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, int16_t) -DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, int32_t) -DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, XPUTypeFP16) -DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeFP16, int32_t) -DECLEAR_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeFP16, tfloat32) +DECLARE_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, int_with_ll_t) +DECLARE_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, int16_t) +DECLARE_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, int32_t) +DECLARE_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeBF16, XPUTypeFP16) +DECLARE_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeFP16, int32_t) +DECLARE_UNSUPPORTED_XBLAS_FC_WRAPPER(XPUTypeFP16, tfloat32) template static void xblas_fc_batch_wrapper(xpu::Context* xpu_ctx, @@ -384,7 +386,7 @@ static void xblas_fc_batch_wrapper(xpu::Context* xpu_ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_batch_wrapper"); } -#define DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUType, FCT, TGEMM_OUT) \ +#define DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUType, FCT, TGEMM_OUT) \ template <> \ void xblas_fc_batch_wrapper( \ xpu::Context * xpu_ctx, \ @@ -408,23 +410,23 @@ static void xblas_fc_batch_wrapper(xpu::Context* xpu_ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "xblas_fc_batched"); \ } -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int_with_ll_t, XPUTypeFP16) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, tfloat32, XPUTypeFP16) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, float, XPUTypeFP16) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, tfloat32, XPUTypeFP16) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, float, XPUTypeFP16) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, XPUTypeFP16, XPUTypeFP16) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int32_t, XPUTypeFP16) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int16_t, XPUTypeFP16) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, int32_t, XPUTypeFP16) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int_with_ll_t, float) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, XPUTypeFP16, float) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, tfloat32, float) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int32_t, float) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int16_t, float) -DECLEAR_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, int32_t, float) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int32_t, XPUTypeFP16) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int16_t, XPUTypeFP16) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, int32_t, XPUTypeFP16) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int_with_ll_t, float) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, XPUTypeFP16, float) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, tfloat32, float) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int32_t, float) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeBF16, int16_t, float) +DECLARE_UNSUPPORTED_XBLAS_FC_BATCH_WRAPPER(XPUTypeFP16, int32_t, float) template static void MatMulXPUFunction( @@ -437,7 +439,7 @@ static void MatMulXPUFunction( bool is_grad = false, xpu::Activation_t act = xpu::Activation_t::LINEAR) { using XPUType = typename XPUTypeTrait::Type; - int fccal_type = FCCalcType(); + int 
fc_calc_type = FCCalcType(); decltype(&xblas_fc_wrapper) xblas_fc_api_list[6] = { &xblas_fc_wrapper, @@ -458,16 +460,16 @@ static void MatMulXPUFunction( &xblas_fc_batch_wrapper, }; - auto xblas_fc_api = xblas_fc_api_list[fccal_type]; + auto xblas_fc_api = xblas_fc_api_list[fc_calc_type]; if (std::getenv("XPU_PADDLE_FC_GRAD_LOCAL") != nullptr) { if (is_grad) { xblas_fc_api = xblas_fc_api_list[2]; } } - auto xblas_fc_batch_api = xblas_fc_batch_api_list[fccal_type]; + auto xblas_fc_batch_api = xblas_fc_batch_api_list[fc_calc_type]; - if (fccal_type == XPUFCCalcType::FC_FLOAT16 && + if (fc_calc_type == XPUFCCalcType::FC_FLOAT16 && std::getenv("XPU_PADDLE_FC_FLOAT16") != nullptr) { xblas_fc_batch_api = &xblas_fc_batch_wrapper; diff --git a/paddle/pir/include/core/attribute.h b/paddle/pir/include/core/attribute.h index 9571440679b8c..53b0d92a4e6b5 100644 --- a/paddle/pir/include/core/attribute.h +++ b/paddle/pir/include/core/attribute.h @@ -15,10 +15,12 @@ #pragma once #include "paddle/pir/include/core/cast_utils.h" +#include "paddle/pir/include/core/storage_manager_support.h" #include "paddle/pir/include/core/type_id.h" constexpr char kAttrStopGradients[] = "stop_gradient"; -constexpr char kAttrIsPersistable[] = "is_persistable"; +constexpr char kAttrIsPersistable[] = "persistable"; +constexpr char kAttrOpDistAttr[] = "op_dist_attr"; namespace pir { class AttributeStorage; @@ -87,6 +89,8 @@ class IR_API Attribute { return pir::dyn_cast(*this); } + std::size_t hash() const { return std::hash()(storage_); } + protected: const Storage *storage_{nullptr}; }; @@ -97,8 +101,6 @@ IR_API std::ostream &operator<<(std::ostream &os, Attribute attr); namespace std { template <> struct hash { - std::size_t operator()(const pir::Attribute &obj) const { - return std::hash()(obj); - } + std::size_t operator()(const pir::Attribute &obj) const { return obj.hash(); } }; } // namespace std diff --git a/paddle/pir/include/core/attribute_base.h b/paddle/pir/include/core/attribute_base.h index d6c75f2e5d8ce..0f459f23e9f99 100644 --- a/paddle/pir/include/core/attribute_base.h +++ b/paddle/pir/include/core/attribute_base.h @@ -16,8 +16,8 @@ #include "paddle/pir/include/core/ir_context.h" #include "paddle/pir/include/core/storage_manager.h" +#include "paddle/pir/include/core/storage_manager_support.h" #include "paddle/pir/include/core/type_id.h" - namespace pir { class Dialect; @@ -239,6 +239,16 @@ struct IR_API AttributeManager { } }; +template +using AttrBase = detail::StorageHelperBase; + /// /// \brief Add some necessary functions to the custom Attribute class. 
/// diff --git a/paddle/pir/include/core/block.h b/paddle/pir/include/core/block.h index a9d68d0969473..25b4afe9bfc47 100644 --- a/paddle/pir/include/core/block.h +++ b/paddle/pir/include/core/block.h @@ -61,6 +61,7 @@ class IR_API Block { ConstReverseIterator rend() const { return ops_.rend(); } ReverseIterator rbegin() { return ops_.rbegin(); } ReverseIterator rend() { return ops_.rend(); } + const OpListType &ops() const { return ops_; } Operation &back() { return *ops_.back(); } Operation &front() { return *ops_.front(); } diff --git a/paddle/pir/include/core/block_argument.h b/paddle/pir/include/core/block_argument.h index 3ddf7847fd8a2..b3b8c78660c34 100644 --- a/paddle/pir/include/core/block_argument.h +++ b/paddle/pir/include/core/block_argument.h @@ -16,6 +16,7 @@ #include "paddle/pir/include/core/operation_utils.h" #include "paddle/pir/include/core/value.h" + namespace pir { class Block; diff --git a/paddle/pir/include/core/builder.h b/paddle/pir/include/core/builder.h index 5278eed2a5af9..fa431d38a6fd0 100644 --- a/paddle/pir/include/core/builder.h +++ b/paddle/pir/include/core/builder.h @@ -107,7 +107,8 @@ class Builder { /// Set the insertion point to the end of the specified block. void SetInsertionPointToBlockEnd(Block *block) { - IR_ENFORCE(block != nullptr, "argument of block is nullptr"); + PADDLE_ENFORCE_NOT_NULL( + block, phi::errors::PreconditionNotMet("argument of block is nullptr")); set_insertion_point(block, block->end()); } @@ -126,6 +127,8 @@ class Builder { const std::vector &output_types, pir::OpInfo op_info); + Operation *Insert(Operation *op); + /// Create an operation of specific op type at the current insertion point. template OpTy Build(Args &&...args); @@ -157,8 +160,6 @@ class Builder { IR_API Complex128Attribute complex128_attr(phi::dtype::complex value); private: - Operation *Insert(Operation *op); - IrContext *context_; InsertionPoint insertion_point_; diff --git a/paddle/pir/include/core/builtin_attribute.h b/paddle/pir/include/core/builtin_attribute.h index b2eba7c423555..e9c0e39239ca8 100644 --- a/paddle/pir/include/core/builtin_attribute.h +++ b/paddle/pir/include/core/builtin_attribute.h @@ -26,6 +26,7 @@ class IR_API BoolAttribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(BoolAttribute, BoolAttributeStorage); + static std::string name() { return "a_bool"; } bool data() const; }; @@ -36,6 +37,7 @@ class IR_API Complex64Attribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Complex64Attribute, Complex64AttributeStorage); + static std::string name() { return "a_c64"; } phi::dtype::complex data() const; }; @@ -46,6 +48,7 @@ class IR_API Complex128Attribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Complex128Attribute, Complex128AttributeStorage); + static std::string name() { return "a_c128"; } phi::dtype::complex data() const; }; @@ -55,6 +58,7 @@ class IR_API FloatAttribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(FloatAttribute, FloatAttributeStorage); + static std::string name() { return "a_f32"; } float data() const; }; @@ -64,6 +68,7 @@ class IR_API DoubleAttribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(DoubleAttribute, DoubleAttributeStorage); + static std::string name() { return "a_f64"; } double data() const; }; @@ -73,6 +78,7 @@ class IR_API Int32Attribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int32Attribute, Int32AttributeStorage); + static std::string name() { return "a_i32"; } int32_t data() const; }; @@ -82,6 +88,7 @@ class IR_API IndexAttribute : public 
Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(IndexAttribute, IndexAttributeStorage); + static std::string name() { return "a_index"; } int64_t data() const; }; @@ -91,6 +98,7 @@ class IR_API Int64Attribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(Int64Attribute, Int64AttributeStorage); + static std::string name() { return "a_i64"; } int64_t data() const; }; @@ -100,6 +108,7 @@ class IR_API PointerAttribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage); + static std::string name() { return "a_pointer"; } void* data() const; }; @@ -109,6 +118,7 @@ class IR_API TypeAttribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage); + static std::string name() { return "a_type"; } Type data() const; }; @@ -122,6 +132,7 @@ class IR_API StrAttribute : public Attribute { std::string AsString() const; + static std::string name() { return "a_str"; } size_t size() const; static StrAttribute get(IrContext* ctx, const std::string& value); @@ -134,6 +145,7 @@ class IR_API ArrayAttribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ArrayAttribute, ArrayAttributeStorage); std::vector AsVector() const; + static std::string name() { return "a_array"; } size_t size() const; @@ -156,7 +168,7 @@ class IR_API TensorNameAttribute : public Attribute { DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TensorNameAttribute, StrAttributeStorage); bool operator<(const TensorNameAttribute& right) const; - + static std::string name() { return "a_tensorname"; } std::string data() const; size_t size() const; diff --git a/paddle/pir/include/core/builtin_attribute_storage.h b/paddle/pir/include/core/builtin_attribute_storage.h index 0e7041abb73eb..8df489ce46a60 100644 --- a/paddle/pir/include/core/builtin_attribute_storage.h +++ b/paddle/pir/include/core/builtin_attribute_storage.h @@ -138,10 +138,11 @@ struct ArrayAttributeStorage : public AttributeStorage { bool empty() const { return size_ == 0u; } Attribute at(size_t index) const { - IR_ENFORCE(index < size_, - "The index (%d) must be less than size (%d).", - index, - size_); + PADDLE_ENFORCE_LT( + index, + size_, + phi::errors::InvalidArgument( + "The index (%d) must be less than size (%d).", index, size_)); return data_[index]; } Attribute operator[](size_t index) const { return data_[index]; } diff --git a/paddle/pir/include/core/builtin_dialect.h b/paddle/pir/include/core/builtin_dialect.h index 1203cdec9d283..193141750283c 100644 --- a/paddle/pir/include/core/builtin_dialect.h +++ b/paddle/pir/include/core/builtin_dialect.h @@ -24,14 +24,17 @@ namespace pir { /// class IR_API BuiltinDialect : public pir::Dialect { public: - explicit BuiltinDialect(pir::IrContext *context); + explicit BuiltinDialect(pir::IrContext* context); /// /// \brief Each Dialect needs to provide a name function to return the name of /// the Dialect. /// /// \return The name of this Dialect. 
/// - static const char *name() { return "builtin"; } + static const char* name() { return "builtin"; } + + pir::Type ParseType(pir::IrParser& parser) override; // NOLINT + void PrintType(pir::Type type, std::ostream& os) const override; private: void initialize(); diff --git a/paddle/pir/include/core/builtin_op.h b/paddle/pir/include/core/builtin_op.h index add3e6a6a312d..f723eaa96b138 100644 --- a/paddle/pir/include/core/builtin_op.h +++ b/paddle/pir/include/core/builtin_op.h @@ -23,6 +23,8 @@ namespace pir { class Program; class Block; constexpr char kStopGradientAttrName[] = "stop_gradient"; +constexpr char kOutputDimExprs[] = "output_dim_exprs"; +constexpr char kSymbolBindings[] = "symbol_bindings"; /// /// \brief ModuleOp /// diff --git a/paddle/pir/include/core/builtin_type.h b/paddle/pir/include/core/builtin_type.h index 3218707277a7a..caef2ff332f4f 100644 --- a/paddle/pir/include/core/builtin_type.h +++ b/paddle/pir/include/core/builtin_type.h @@ -44,6 +44,7 @@ class IR_API VectorType using Base::Base; std::vector data() const; + static std::string name() { return "t_vec"; } size_t size() const { return data().size(); } @@ -66,6 +67,15 @@ class IR_API DenseTensorType : public Type::TypeBase { \ public: \ using Base::Base; \ static __name get(IrContext *context); \ + static std::string name() { return s_name; } \ }; #define FOREACH_BUILTIN_TYPE(__macro) \ - __macro(BFloat16Type); \ - __macro(Float16Type); \ - __macro(Float32Type); \ - __macro(Float64Type); \ - __macro(Int8Type); \ - __macro(UInt8Type); \ - __macro(Int16Type); \ - __macro(Int32Type); \ - __macro(Int64Type); \ - __macro(IndexType); \ - __macro(BoolType); \ - __macro(Complex64Type); \ - __macro(Complex128Type); - + __macro(BFloat16Type, "t_bf16"); \ + __macro(Float16Type, "t_f16"); \ + __macro(Float32Type, "t_f32"); \ + __macro(Float64Type, "t_f64"); \ + __macro(Int8Type, "t_i8"); \ + __macro(UInt8Type, "t_ui8"); \ + __macro(Int16Type, "t_i16"); \ + __macro(Int32Type, "t_i32"); \ + __macro(Int64Type, "t_i64"); \ + __macro(IndexType, "t_index"); \ + __macro(BoolType, "t_bool"); \ + __macro(Complex64Type, "t_c64"); \ + __macro(Complex128Type, "t_c128"); FOREACH_BUILTIN_TYPE(DECLARE_BUILTIN_TYPE) #undef FOREACH_BUILTIN_TYPE diff --git a/paddle/pir/include/core/builtin_type_interfaces.h b/paddle/pir/include/core/builtin_type_interfaces.h index d6425549fab1f..81ac76e8f48e9 100644 --- a/paddle/pir/include/core/builtin_type_interfaces.h +++ b/paddle/pir/include/core/builtin_type_interfaces.h @@ -80,7 +80,10 @@ class IR_API ShapedTypeInterface /// If this is a ranked type, return the rank. Otherwise, abort. /// int64_t GetRank() const { - IR_ENFORCE((*this).HasRank(), "Cannot query rank of unranked shaped type."); + PADDLE_ENFORCE_EQ((*this).HasRank(), + true, + phi::errors::InvalidArgument( + "Cannot query rank of unranked shaped type.")); return (*this).GetShape().size(); } @@ -110,7 +113,10 @@ class IR_API ShapedTypeInterface /// unranked types. /// bool IsDynamicDim(unsigned idx) const { - IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type."); + PADDLE_ENFORCE_LT( + idx, + GetRank(), + phi::errors::InvalidArgument("Invalid index for shaped type.")); return ShapedTypeInterface::IsDynamic((*this).GetShape()[idx]); } @@ -129,7 +135,10 @@ class IR_API ShapedTypeInterface /// for unranked types. 
/// int64_t GetDimSize(unsigned idx) const { - IR_ENFORCE(idx < GetRank(), "Invalid index for shaped type."); + PADDLE_ENFORCE_LT( + idx, + GetRank(), + phi::errors::InvalidArgument("Invalid index for shaped type.")); return (*this).GetShape()[idx]; } @@ -137,6 +146,31 @@ class IR_API ShapedTypeInterface Concept *impl_; }; +class IR_API WrapTypeInterface : public TypeInterfaceBase { + public: + struct Concept { + /// Defined these methods with the interface. + explicit Concept(Type (*prim_type)(Type)) : prim_type(prim_type) {} + Type (*prim_type)(Type); + }; + + template + struct Model : public Concept { + static Type prim_type(Type type) { + return pir::cast(type).prim_type(); + } + Model() : Concept(prim_type) {} + }; + + WrapTypeInterface(Type type, Concept *impl) + : TypeInterfaceBase(type), impl_(impl) {} + + Type prim_type() { return impl_->prim_type(*this); } + + private: + Concept *impl_; +}; } // namespace pir IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::ShapedTypeInterface) +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::WrapTypeInterface) diff --git a/paddle/pir/include/core/builtin_type_storage.h b/paddle/pir/include/core/builtin_type_storage.h index 03f06279a0dfd..f706e0c66277e 100644 --- a/paddle/pir/include/core/builtin_type_storage.h +++ b/paddle/pir/include/core/builtin_type_storage.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "paddle/common/ddim.h" #include "paddle/common/dim.h" #include "paddle/common/hash_funcs.h" diff --git a/paddle/pir/include/core/interface_support.h b/paddle/pir/include/core/interface_support.h index a035114e44bf2..9c9eea85f87c1 100644 --- a/paddle/pir/include/core/interface_support.h +++ b/paddle/pir/include/core/interface_support.h @@ -19,40 +19,42 @@ namespace pir { namespace detail { + template class ConstructInterfacesOrTraits { public: /// Construct method for interfaces. static void interface(InterfaceSet &interface_set) { // NOLINT (void)std::initializer_list{ - 0, (ConstrctInterface(interface_set), 0)...}; + 0, (ConstructInterface(interface_set), 0)...}; } /// Construct method for traits. static TypeId *trait(TypeId *p_trait) { (void)std::initializer_list{ - 0, (PlacementConstrctTrait(p_trait), 0)...}; + 0, (PlacementConstructTrait(p_trait), 0)...}; return p_trait; } private: /// Placement new interface. template - static void ConstrctInterface(InterfaceSet &interface_set) { // NOLINT + static void ConstructInterface(InterfaceSet &interface_set) { // NOLINT InterfaceValue val = InterfaceValue::Get>(); - auto suceess = interface_set.insert(std::move(val)).second; - IR_ENFORCE(suceess, - "Interface: id[%u] is already registered. inset failed", - TypeId::get()); - VLOG(10) << "New a interface: id[" << TypeId::get() << "]."; + auto success = interface_set.insert(std::move(val)).second; + PADDLE_ENFORCE_EQ( + success, + true, + phi::errors::PreconditionNotMet( + "Interface: id[%u] is already registered. inset failed", + TypeId::get())); } /// Placement new trait. template - static void PlacementConstrctTrait(pir::TypeId *&p_trait) { // NOLINT + static void PlacementConstructTrait(pir::TypeId *&p_trait) { // NOLINT *p_trait = TypeId::get(); - VLOG(10) << "New a trait: id[" << *p_trait << "]."; ++p_trait; } }; diff --git a/paddle/pir/include/core/interface_value.h b/paddle/pir/include/core/interface_value.h index 00f8cc289143f..64619a0e0f591 100644 --- a/paddle/pir/include/core/interface_value.h +++ b/paddle/pir/include/core/interface_value.h @@ -13,8 +13,10 @@ // limitations under the License. 
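As an illustrative aside on the WrapTypeInterface added above (simplified, hypothetical names; not the actual pir code), the interface follows the usual Concept/Model pattern: Concept stores a plain function pointer, and the templated Model installs a static forwarder into the concrete type. A minimal self-contained sketch:

#include <cassert>

struct DemoType {
  int prim() const { return 42; }
};

struct Concept {
  explicit Concept(int (*prim)(const DemoType&)) : prim(prim) {}
  int (*prim)(const DemoType&);  // type-erased entry point
};

template <typename ConcreteType>
struct Model : Concept {
  // Static forwarder into the concrete type; its address fills the Concept slot.
  static int prim(const DemoType& t) { return t.prim(); }
  Model() : Concept(prim) {}
};

int main() {
  Model<DemoType> model;
  Concept* impl = &model;
  DemoType t;
  assert(impl->prim(t) == 42);  // dispatch through the stored function pointer
  return 0;
}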
#pragma once + #include #include + #include "paddle/pir/include/core/type_id.h" #include "paddle/pir/include/core/utils.h" diff --git a/paddle/pir/include/core/ir_context.h b/paddle/pir/include/core/ir_context.h index dbf7ff4cdd73e..50ce178531673 100644 --- a/paddle/pir/include/core/ir_context.h +++ b/paddle/pir/include/core/ir_context.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include #include #include @@ -117,12 +118,12 @@ class IR_API IrContext { void (*verify_region)(Operation *)); /// - /// \brief Get registered operaiton infomation. + /// \brief Get registered operation infomation. /// OpInfo GetRegisteredOpInfo(const std::string &name); /// - /// \brief Get registered operaiton infomation map. + /// \brief Get registered operation infomation map. /// const OpInfoMap ®istered_op_info_map(); diff --git a/paddle/pir/include/core/ir_mapping.h b/paddle/pir/include/core/ir_mapping.h index 83994ea284570..2164c4a85c149 100644 --- a/paddle/pir/include/core/ir_mapping.h +++ b/paddle/pir/include/core/ir_mapping.h @@ -12,7 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. #pragma once + #include + #include "paddle/common/enforce.h" #include "paddle/pir/include/core/value.h" @@ -82,8 +84,10 @@ class IrMapping { template IrType Lookup(T from) const { if (!from) return static_cast>(nullptr); - IR_ENFORCE(GetMap>().count(from) > 0, - "Not found key in IRMapping."); + PADDLE_ENFORCE_GT( + GetMap>().count(from), + 0UL, + phi::errors::InvalidArgument("Not found key in IRMapping.")); return GetMap>().at(from); } diff --git a/paddle/pir/include/core/iterator.h b/paddle/pir/include/core/iterator.h index 8fbfae8cb4b2d..fc88d981c3661 100644 --- a/paddle/pir/include/core/iterator.h +++ b/paddle/pir/include/core/iterator.h @@ -13,9 +13,12 @@ // limitations under the License. #pragma once + #include #include + #include "paddle/common/macros.h" + namespace pir { class Operation; diff --git a/paddle/pir/include/core/op_base.h b/paddle/pir/include/core/op_base.h index 93e6939be8adf..84f4c33131920 100644 --- a/paddle/pir/include/core/op_base.h +++ b/paddle/pir/include/core/op_base.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include #include "paddle/common/enforce.h" @@ -31,7 +32,9 @@ class IR_API OpBase { explicit OpBase(Operation *operation = nullptr) : operation_(operation) {} Operation *operation() const { - IR_ENFORCE(operation_, "Can't use operation() in a null op."); + PADDLE_ENFORCE_NOT_NULL( + operation_, + phi::errors::InvalidArgument("Can't use operation() in a null op.")); return operation_; } diff --git a/paddle/pir/include/core/op_info.h b/paddle/pir/include/core/op_info.h index fbeb679463a4d..994aed189fc6f 100644 --- a/paddle/pir/include/core/op_info.h +++ b/paddle/pir/include/core/op_info.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once + #include #include @@ -31,7 +32,7 @@ typedef void (*VerifyPtr)(Operation *op); class IR_API OpInfo { public: - OpInfo() = default; + OpInfo(std::nullptr_t ptr = nullptr){}; // NOLINT OpInfo(const OpInfo &other) = default; diff --git a/paddle/pir/include/core/op_operand.h b/paddle/pir/include/core/op_operand.h index 5366ab390ffa0..4944c31fdb283 100644 --- a/paddle/pir/include/core/op_operand.h +++ b/paddle/pir/include/core/op_operand.h @@ -13,6 +13,7 @@ // limitations under the License. 
#pragma once + #include #include "paddle/pir/include/core/dll_decl.h" diff --git a/paddle/pir/include/core/op_result.h b/paddle/pir/include/core/op_result.h index 04ae0e848e511..89a7b6664230f 100644 --- a/paddle/pir/include/core/op_result.h +++ b/paddle/pir/include/core/op_result.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/pir/include/core/value.h" + namespace pir { namespace detail { @@ -37,6 +38,9 @@ class IR_API OpResult : public Value { Attribute attribute(const std::string &key) const; void set_attribute(const std::string &key, Attribute value); + void *property(const std::string &key) const; + void set_property(const std::string &key, const Property &value); + private: friend Operation; OpResult(detail::OpResultImpl *impl); // NOLINT diff --git a/paddle/pir/include/core/operation.h b/paddle/pir/include/core/operation.h index 66d5da9d0d8ab..7d279e50bff6e 100644 --- a/paddle/pir/include/core/operation.h +++ b/paddle/pir/include/core/operation.h @@ -34,7 +34,7 @@ class OpResult; namespace detail { class OpResultImpl; -class OpOperendImpl; +class OpOperandImpl; } // namespace detail class CloneOptions { @@ -117,6 +117,12 @@ class IR_API alignas(8) Operation final return attributes_.find(key) != attributes_.end(); } + void set_value_property(const std::string &key, + const Property &value, + size_t index); + + void *value_property(const std::string &key, size_t index) const; + /// /// \brief op ouput related public interfaces /// @@ -133,7 +139,7 @@ class IR_API alignas(8) Operation final /// uint32_t num_operands() const { return num_operands_; } OpOperand operand(uint32_t index) const { return op_operand_impl(index); } - std::vector operands(); + std::vector operands() const; Value operand_source(uint32_t index) const; std::vector operands_source() const; Type operand_type(uint32_t index) const { return operand(index).type(); } @@ -229,7 +235,7 @@ class IR_API alignas(8) Operation final void Verify(); - uint64_t id() { return id_; } + uint64_t id() const { return id_; } private: DISABLE_COPY_AND_ASSIGN(Operation); @@ -266,6 +272,9 @@ class IR_API alignas(8) Operation final AttributeMap attributes_; + // store data that user create by Python + std::vector value_properties_; + OpInfo info_; static uint64_t GenerateId() { diff --git a/paddle/pir/include/core/operation_utils.h b/paddle/pir/include/core/operation_utils.h index 4360af17e08a4..88ab019771fbe 100644 --- a/paddle/pir/include/core/operation_utils.h +++ b/paddle/pir/include/core/operation_utils.h @@ -16,6 +16,7 @@ #include #include + #include "paddle/pir/include/core/attribute.h" #include "paddle/pir/include/core/dll_decl.h" #include "paddle/pir/include/core/op_info.h" @@ -27,6 +28,7 @@ namespace pir { class Block; using AttributeMap = std::unordered_map; +using PropertyMap = std::unordered_map; //===----------------------------------------------------------------------===// // OperationArgument diff --git a/paddle/pir/include/core/parameter.h b/paddle/pir/include/core/parameter.h index cad6839ea8851..bfcbe17b3289c 100644 --- a/paddle/pir/include/core/parameter.h +++ b/paddle/pir/include/core/parameter.h @@ -14,6 +14,8 @@ #pragma once +#include + #include "paddle/pir/include/core/type.h" namespace pir { diff --git a/paddle/pir/include/core/region.h b/paddle/pir/include/core/region.h index c141611172f9b..6667aba5392ed 100644 --- a/paddle/pir/include/core/region.h +++ b/paddle/pir/include/core/region.h @@ -53,12 +53,12 @@ class IR_API Region { ReverseIterator rend() { return blocks_.rend(); } ConstReverseIterator rbegin() const 
{ return blocks_.rbegin(); } ConstReverseIterator rend() const { return blocks_.rend(); } + const std::list &blocks() const { return blocks_; } Block &front() { return *blocks_.front(); } Block &back() { return *blocks_.back(); } const Block &front() const { return *blocks_.front(); } const Block &back() const { return *blocks_.back(); } - void push_back(Block *block); Block &emplace_back(); void push_front(Block *block); diff --git a/paddle/pir/include/core/storage_manager.h b/paddle/pir/include/core/storage_manager.h index 8cacc3bd38bd0..7024e580e4a1f 100644 --- a/paddle/pir/include/core/storage_manager.h +++ b/paddle/pir/include/core/storage_manager.h @@ -74,7 +74,7 @@ class IR_API StorageManager { return static_cast(*existing) == param; }; auto constructor = [&]() { - auto *storage = Storage::Construct(param); + auto *storage = Storage::Construct(std::move(param)); if (init_func) init_func(storage); return storage; }; diff --git a/paddle/pir/include/core/storage_manager_support.h b/paddle/pir/include/core/storage_manager_support.h index 9952d2d144d66..614f3938c54e2 100644 --- a/paddle/pir/include/core/storage_manager_support.h +++ b/paddle/pir/include/core/storage_manager_support.h @@ -15,10 +15,9 @@ #pragma once #include + #include "paddle/pir/include/core/interface_support.h" #include "paddle/pir/include/core/ir_context.h" -#include "paddle/pir/include/core/type.h" -#include "paddle/pir/include/core/type_base.h" #include "paddle/pir/include/core/type_id.h" namespace pir { @@ -67,7 +66,7 @@ class StorageHelperBase : public BaseT { typename Filter>::Type; static ConcreteT dyn_cast_impl(BaseT type) { - if (type && type.abstract_type().type_id() == TypeId::get()) { + if (type && type.type_id() == TypeId::get()) { return ConcreteT(type.storage()); } return ConcreteT(nullptr); @@ -91,7 +90,7 @@ class StorageHelperBase : public BaseT { /// template static bool classof(T val) { - return val.type_id() == type_id(); + return val && val.type_id() == type_id(); } /// @@ -106,8 +105,8 @@ class StorageHelperBase : public BaseT { /// \brief Get or create a new ConcreteT instance within the ctx. /// template - static ConcreteT get(pir::IrContext *ctx, Args... 
args) { - return ManagerT::template get(ctx, args...); + static ConcreteT get(pir::IrContext *ctx, Args &&...args) { + return ManagerT::template get(ctx, std::forward(args)...); } /// diff --git a/paddle/pir/include/core/type.h b/paddle/pir/include/core/type.h index 98ef867bef49b..fcfe0a77a8ac5 100644 --- a/paddle/pir/include/core/type.h +++ b/paddle/pir/include/core/type.h @@ -18,7 +18,9 @@ #include "paddle/pir/include/core/cast_utils.h" #include "paddle/pir/include/core/storage_manager_support.h" +#include "paddle/pir/include/core/type_base.h" #include "paddle/pir/include/core/type_id.h" + namespace pir { class TypeStorage; class AbstractType; @@ -41,7 +43,6 @@ class IR_API Type { StorageType, TypeManager, TraitOrInterface...>; - using Storage = TypeStorage; using AbstractT = AbstractType; @@ -124,6 +125,8 @@ class IR_API Type { bool IsIntOrIndex() const; bool IsIndex() const; + std::size_t hash() const { return std::hash()(storage_); } + protected: const Storage *storage_{nullptr}; @@ -183,8 +186,6 @@ namespace std { /// template <> struct hash { - std::size_t operator()(const pir::Type &obj) const { - return std::hash()(obj); - } + std::size_t operator()(const pir::Type &obj) const { return obj.hash(); } }; } // namespace std diff --git a/paddle/pir/include/core/type_id.h b/paddle/pir/include/core/type_id.h index b6e107c777559..2bce5d92752d2 100644 --- a/paddle/pir/include/core/type_id.h +++ b/paddle/pir/include/core/type_id.h @@ -14,7 +14,6 @@ #pragma once -#include #include #include "paddle/pir/include/core/dll_decl.h" diff --git a/paddle/pir/include/core/value.h b/paddle/pir/include/core/value.h index 0e1a2989e8f37..2e0c46c882b28 100644 --- a/paddle/pir/include/core/value.h +++ b/paddle/pir/include/core/value.h @@ -21,6 +21,8 @@ namespace pir { class Operation; +using PropertiesDeleter = void (*)(void *); +using Property = std::pair; namespace detail { class ValueImpl; @@ -32,12 +34,14 @@ class ValueImpl; /// class IR_API Value { public: - Value() = default; + Value(std::nullptr_t ptr = nullptr){}; // NOLINT Value(detail::ValueImpl *impl) : impl_(impl) {} // NOLINT Value(const Value &other) = default; + Value &operator=(const Value &other) = default; + bool operator==(const Value &other) const; bool operator!=(const Value &other) const; @@ -66,7 +70,7 @@ class IR_API Value { template OpTy defining_op() const { - /// It is safety even if defining_op() return nullptr. + /// It is safe even if defining_op() returns nullptr. return OpTy::dyn_cast(defining_op()); } @@ -114,6 +118,10 @@ class IR_API Value { void set_attribute(const std::string &key, Attribute value); + void set_property(const std::string &key, const Property &value); + + void *property(const std::string &name) const; + protected: detail::ValueImpl *impl_{nullptr}; }; diff --git a/paddle/pir/include/core/visitors.h b/paddle/pir/include/core/visitors.h index c2cf137e44624..31f0262865127 100644 --- a/paddle/pir/include/core/visitors.h +++ b/paddle/pir/include/core/visitors.h @@ -14,6 +14,7 @@ #pragma once #include + #include "paddle/pir/include/core/dll_decl.h" namespace pir { diff --git a/paddle/pir/include/dialect/control_flow/ir/cf_op.h b/paddle/pir/include/dialect/control_flow/ir/cf_op.h index 0d6e60a017ab3..8d49f60e32617 100644 --- a/paddle/pir/include/dialect/control_flow/ir/cf_op.h +++ b/paddle/pir/include/dialect/control_flow/ir/cf_op.h @@ -13,7 +13,9 @@ // limitations under the License. 
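The get() change in storage_manager_support.h switches from by-value Args... to forwarding references, so constructing a storage object no longer copies every parameter. A standalone illustration of the difference (plain C++, not pir code):

    #include <string>
    #include <utility>

    // The old signature "get(IrContext *ctx, Args... args)" copied each argument;
    // "Args &&...args" plus std::forward preserves the caller's value category.
    struct Storage {
      std::string name;
      explicit Storage(std::string n) : name(std::move(n)) {}
    };

    struct Manager {
      template <typename T, typename... Args>
      static T get(Args &&...args) {
        return T(std::forward<Args>(args)...);
      }
    };

    template <typename... Args>
    Storage GetStorage(Args &&...args) {
      return Manager::get<Storage>(std::forward<Args>(args)...);
    }

    int main() {
      std::string heavy(1 << 20, 'x');
      Storage s = GetStorage(std::move(heavy));  // moved through, never copied
      return s.name.size() == (1u << 20) ? 0 : 1;
    }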
#pragma once + #include + #include "paddle/pir/include/core/builder.h" #include "paddle/pir/include/core/op_base.h" #include "paddle/pir/include/core/op_trait.h" @@ -82,6 +84,7 @@ class IR_API TuplePopOp : public Op { void VerifySig(); void VerifyRegion(); + bool has_container() { return outlet().defining_op(); } Value container() { return container_interface().container(); } Value inlet() { return container_interface().inlet(); } Value outlet() { return operand_source(0); } diff --git a/paddle/pir/include/dialect/shape/ir/shape_op.h b/paddle/pir/include/dialect/shape/ir/shape_op.h index 84440d64abc43..3bc7562eaf0e4 100644 --- a/paddle/pir/include/dialect/shape/ir/shape_op.h +++ b/paddle/pir/include/dialect/shape/ir/shape_op.h @@ -15,6 +15,7 @@ #pragma once #include + #include "paddle/pir/include/core/builder.h" #include "paddle/pir/include/core/builtin_type_interfaces.h" #include "paddle/pir/include/core/ir_printer.h" diff --git a/paddle/pir/include/dialect/shape/utils/dim_expr.h b/paddle/pir/include/dialect/shape/utils/dim_expr.h index ef141a3d3329c..2999858522d6d 100644 --- a/paddle/pir/include/dialect/shape/utils/dim_expr.h +++ b/paddle/pir/include/dialect/shape/utils/dim_expr.h @@ -28,7 +28,8 @@ namespace symbol { -#define SYMBOL_NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented" +#define SYMBOL_NOT_IMPLEMENTED \ + PADDLE_THROW(phi::errors::Unimplemented("Not Implemented")) template struct Overloaded : Ts... { @@ -225,8 +226,6 @@ class IR_API DimExpr : public DimExprBase { // | Broadcastable DimExpr using DimExprConstraint = std::variant, Broadcastable>; -// ShapeOrDataDimExprs = (tShape [DimExpr], tData (opt [DimExpr])) - IR_API std::string ToString(const DimExpr& dim_expr); IR_API std::ostream& operator<<(std::ostream&, const DimExpr& dim_expr); diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/pir/include/dialect/shape/utils/dim_expr_util.h similarity index 59% rename from paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h rename to paddle/pir/include/dialect/shape/utils/dim_expr_util.h index e63d58886d46f..8c10ef805875f 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h +++ b/paddle/pir/include/dialect/shape/utils/dim_expr_util.h @@ -14,17 +14,20 @@ #pragma once -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h" +#include +#include -// TODO(ljz) Automatic this process in cmake file. -namespace paddle { -namespace distributed { -namespace auto_parallel { +#include "paddle/pir/include/dialect/shape/utils/dim_expr.h" -// replicated rule -REGISTER_SPMD_RULE(replicated, ReplicatedSPMDRule); +namespace symbol { -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle +IR_API DimExpr SimplifyDimExpr(const DimExpr& dim_expr); + +IR_API DimExpr SubstituteDimExpr( + const DimExpr& dim_expr, + const std::unordered_map& pattern_to_replacement); + +IR_API std::unordered_set CollectDimExprSymbols( + const DimExpr& dim_expr); + +} // namespace symbol diff --git a/paddle/pir/include/dialect/shape/utils/shape_analysis.h b/paddle/pir/include/dialect/shape/utils/shape_analysis.h index 284487b7210c5..fd3a5b45fee05 100644 --- a/paddle/pir/include/dialect/shape/utils/shape_analysis.h +++ b/paddle/pir/include/dialect/shape/utils/shape_analysis.h @@ -28,8 +28,6 @@ namespace pir { // The implementation is based on shape constraint ir. 
class IR_API ShapeConstraintIRAnalysis { public: - explicit ShapeConstraintIRAnalysis(ModuleOp m); - void Init(); const std::string GetNextSymName(); @@ -41,7 +39,7 @@ class IR_API ShapeConstraintIRAnalysis { void SetShapeOrDataForValue(Value val, const symbol::ShapeOrDataDimExprs& shape_or_data); - symbol::DimExprBuilder CreateDimExprBuilder(); + symbol::DimExprBuilder DimExprBuilder(); // Used to debug void PrintShapeOrDatas() const; @@ -75,6 +73,9 @@ class IR_API ShapeConstraintIRAnalysis { pir::PrintHooks PrintHook() const; + symbol::DimExpr GetProductDimExpr(Value lhs, + const std::vector& lhs_dim_idxs) const; + private: ModuleOp m_; @@ -100,4 +101,8 @@ class IR_API ShapeAnalysisManager { std::unordered_map tables_; }; +#define OP_DECLARE_INFER_SYMBOLIC_SHAPE(name) \ + bool name##OpInferSymbolicShape( \ + pir::Operation* op, pir::ShapeConstraintIRAnalysis* shape_analysis); + } // namespace pir diff --git a/paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h b/paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h index b4a537a9a0d6b..bada3c93d5cc6 100644 --- a/paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h +++ b/paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h @@ -15,7 +15,7 @@ #pragma once #include "paddle/pir/include/dialect/shape/utils/dim_expr.h" -#include "paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" namespace symbol { @@ -26,22 +26,29 @@ class ShapeOrData { : shape_(shape), data_(std::nullopt) {} explicit ShapeOrData(const std::vector& shape, const std::vector& data) : shape_(shape), data_(data) { - // Vaild check + // Valid check if (shape.size() == 0) { - IR_ENFORCE(data.size() == 1, - "When shape is 0-D, size of data shoubld be 1, but got %d.", - data.size()); + PADDLE_ENFORCE_EQ( + data.size(), + 1UL, + phi::errors::InvalidArgument( + "When shape is 0-D, size of data should be 1, but got %d.", + data.size())); } else if (shape.size() == 1) { - IR_ENFORCE(shape[0].template Has(), - "When shape is 1-D, value of shape shoubld be int"); - IR_ENFORCE( + PADDLE_ENFORCE_EQ(shape[0].template Has(), + true, + phi::errors::InvalidArgument( + "When shape is 1-D, value of shape should be int")); + PADDLE_ENFORCE_EQ( shape[0].template Get() == static_cast(data.size()), - "When shape is 1-D, size of data shoubld be the same as " - "value[%d] of shape, but got [%d].", - shape[0].template Get(), - data.size()); + true, + phi::errors::InvalidArgument( + "When shape is 1-D, size of data should be the same as " + "value[%d] of shape, but got [%d].", + shape[0].template Get(), + data.size())); } else { - IR_THROW("Size of shape shoubld be 0 or 1, but got %d", shape.size()); + IR_THROW("Size of shape should be 0 or 1, but got %d", shape.size()); } } @@ -60,7 +67,7 @@ class ShapeOrData { bool operator==(const ShapeOrData& other) const { if (data_.has_value() && !other.data_.has_value()) return false; if (!data_.has_value() && other.data_.has_value()) return false; - if (shape_.size() != shape_.size()) return false; + if (shape_.size() != other.shape_.size()) return false; if (data_.has_value() && other.data_.has_value()) { if (data_.value().size() != other.data_.value().size()) return false; @@ -128,26 +135,32 @@ class ShapeOrDataDimExprs : public ShapeOrDataDimExprsBase { } const std::vector& shape() const { - IR_ENFORCE( + PADDLE_ENFORCE_EQ( std::holds_alternative(*this), - "Shape of ShapeOrData is not a vector, check wheather the value is a " - "tensor-list or not."); + true, 
+ phi::errors::PreconditionNotMet("Shape of ShapeOrData is not a vector, " + "check whether the value is a " + "tensor-list or not.")); return std::get(*this).shape(); } const std::optional>& data() const { - IR_ENFORCE( + PADDLE_ENFORCE_EQ( std::holds_alternative(*this), - "Data of ShapeOrData is not a vector, check wheather the value is a " - "tensor-list or not."); + true, + phi::errors::PreconditionNotMet( + "Data of ShapeOrData is not a vector, check whether the value is a " + "tensor-list or not.")); return std::get(*this).data(); } void SetData(const std::vector& data) { - IR_ENFORCE( + PADDLE_ENFORCE_EQ( std::holds_alternative(*this), - "Data of ShapeOrData is not a vector, check wheather the value is a " - "tensor-list or not."); + true, + phi::errors::PreconditionNotMet( + "Data of ShapeOrData is not a vector, check whether the value is a " + "tensor-list or not.")); std::get(*this).SetData(data); } diff --git a/paddle/pir/include/pass/pass.h b/paddle/pir/include/pass/pass.h index 3be04b71051f7..48fd795522cdf 100644 --- a/paddle/pir/include/pass/pass.h +++ b/paddle/pir/include/pass/pass.h @@ -23,6 +23,7 @@ #include "paddle/common/enforce.h" #include "paddle/pir/include/pass/analysis_manager.h" #include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h" +#include "paddle/pir/include/pattern_rewrite/pattern_rewrite_driver.h" namespace pir { @@ -70,12 +71,12 @@ struct PassInfo { } // namespace detail -static const char kParamScopeAttr[] = "__param_scope__"; -static const char kPlaceAttr[] = "__place__"; - /// We can access pass only from PassManager. class IR_API Pass { public: + inline static const char kParamScopeAttr[] = "__param_scope__"; + inline static const char kPlaceAttr[] = "__place__"; + explicit Pass(const std::string& name, uint8_t opt_level, const std::vector& dependents = {}) @@ -90,9 +91,10 @@ class IR_API Pass { // Get a reference to the attributed previously set. template AttrType& Get(const std::string& attr_name) const { - IR_ENFORCE(attrs_.find(attr_name) != attrs_.end(), - "Attribute %s not registered for pass.", - attr_name); + PADDLE_ENFORCE_EQ(attrs_.find(attr_name) != attrs_.end(), + true, + phi::errors::InvalidArgument( + "Attribute %s not registered for pass.", attr_name)); try { return *std::any_cast(attrs_.at(attr_name)); } catch (std::bad_any_cast&) { @@ -136,25 +138,21 @@ class IR_API Pass { // Set a pointer to the attribute. Pass takes ownership of the attribute. template void Set(const std::string& attr_name, AttrType* attr) { - VLOG(3) << "Setting the attribute " << attr_name << " for the pass " - << name(); if (Has(attr_name)) { Erase(attr_name); } attrs_[attr_name] = attr; - attr_dels_[attr_name] = [attr, attr_name]() { - VLOG(8) << "deleting " << attr_name; - delete attr; - }; + attr_dels_[attr_name] = [attr, attr_name]() { delete attr; }; } // Set a pointer to the attribute. Pass doesn't take ownership. Caller // should delete the attribute. 
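Set() above registers a per-name deleter so the pass frees the attribute itself, while the SetNotOwned() that follows stores only the raw pointer and leaves deletion to the caller. A standalone sketch of that ownership split (plain C++, simplified from the pass code):

    #include <any>
    #include <functional>
    #include <string>
    #include <unordered_map>

    class AttrHolder {
     public:
      // Takes ownership: the holder deletes the attribute on destruction.
      template <typename T>
      void Set(const std::string &name, T *attr) {
        attrs_[name] = attr;
        deleters_[name] = [attr] { delete attr; };
      }
      // Caller keeps ownership and must keep the attribute alive.
      template <typename T>
      void SetNotOwned(const std::string &name, T *attr) {
        attrs_[name] = attr;
      }
      template <typename T>
      T &Get(const std::string &name) const {
        return *std::any_cast<T *>(attrs_.at(name));
      }
      ~AttrHolder() {
        for (auto &kv : deleters_) kv.second();
      }

     private:
      std::unordered_map<std::string, std::any> attrs_;
      std::unordered_map<std::string, std::function<void()>> deleters_;
    };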
template void SetNotOwned(const std::string& attr_name, AttrType* attr) { - VLOG(3) << "Setting the attribute " << attr_name << " for the " << name(); - IR_ENFORCE( - !Has(attr_name), "Attribute %s already set in the pass.", attr_name); + PADDLE_ENFORCE_EQ(!Has(attr_name), + true, + phi::errors::InvalidArgument( + "Attribute %s already set in the pass.", attr_name)); attrs_[attr_name] = attr; } @@ -206,12 +204,16 @@ class IR_API PatternRewritePass : public Pass { protected: virtual RewritePatternSet InitializePatterns(IrContext* context) = 0; + virtual GreedyRewriteConfig InitializeConfig(); + bool Initialize(IrContext* context) final; void Run(Operation* op) override; private: FrozenRewritePatternSet patterns_; + + GreedyRewriteConfig config_; }; } // namespace pir diff --git a/paddle/pir/include/pass/pass_registry.h b/paddle/pir/include/pass/pass_registry.h index 9350a98ee616d..9fba4e09c5433 100644 --- a/paddle/pir/include/pass/pass_registry.h +++ b/paddle/pir/include/pass/pass_registry.h @@ -34,14 +34,18 @@ class PassRegistry { } void Insert(const std::string &pass_type, const PassCreator &pass_creator) { - IR_ENFORCE( - Has(pass_type) != true, "Pass %s has been registered.", pass_type); + PADDLE_ENFORCE_NE(Has(pass_type), + true, + phi::errors::InvalidArgument( + "Pass %s has been registered.", pass_type)); pass_map_.insert({pass_type, pass_creator}); } std::unique_ptr Get(const std::string &pass_type) const { - IR_ENFORCE( - Has(pass_type) == true, "Pass %s has not been registered.", pass_type); + PADDLE_ENFORCE_EQ(Has(pass_type), + true, + phi::errors::InvalidArgument( + "Pass %s has not been registered.", pass_type)); return pass_map_.at(pass_type)(); } diff --git a/paddle/pir/src/core/block.cc b/paddle/pir/src/core/block.cc index 258f681b303cb..1d9021a47b47b 100644 --- a/paddle/pir/src/core/block.cc +++ b/paddle/pir/src/core/block.cc @@ -14,6 +14,7 @@ #include "paddle/pir/include/core/block.h" +#include #include #include "paddle/common/enforce.h" @@ -23,7 +24,10 @@ namespace pir { Block::~Block() { if (!use_empty()) { - LOG(FATAL) << "Destroyed a block that is still in use."; + auto parent_op = GetParentOp(); + PADDLE_FATAL( + "Destroyed a block that is still in use.. The parent op is : %s", + parent_op ? parent_op->name() : std::string("nullptr")); } ClearOps(); ClearKwargs(); diff --git a/paddle/pir/src/core/block_argument.cc b/paddle/pir/src/core/block_argument.cc index 99a799e9f592e..85ed7e2fa6b77 100644 --- a/paddle/pir/src/core/block_argument.cc +++ b/paddle/pir/src/core/block_argument.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "paddle/pir/include/core/block_argument.h" #include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/operation_utils.h" @@ -73,7 +75,17 @@ class BlockArgumentImpl : public ValueImpl { BlockArgumentImpl::~BlockArgumentImpl() { if (!use_empty()) { - LOG(FATAL) << "Destroyed a block argument that is still in use."; + if (is_kwarg_) { + PADDLE_FATAL( + "Destroyed a keyword block argument that is still in use. The key is " + ": %s", + keyword_); + } else { + PADDLE_FATAL( + "Destroyed a position block argument that is still in use. 
The index " + "is : %u", + index_); + } } } diff --git a/paddle/pir/src/core/block_operand_impl.h b/paddle/pir/src/core/block_operand_impl.h index 8cd331d87ab7a..0293ea36d7ca8 100644 --- a/paddle/pir/src/core/block_operand_impl.h +++ b/paddle/pir/src/core/block_operand_impl.h @@ -44,8 +44,8 @@ class BlockOperandImpl { private: BlockOperandImpl(Block* source, Operation* owner); - // Insert self to the UD chain holded by source_; - // It is not safe. So set provate. + // Insert self to the UD chain held by source_; + // It is not safe. So set private. void InsertToUdChain(); BlockOperand next_use_ = nullptr; diff --git a/paddle/pir/src/core/builder.cc b/paddle/pir/src/core/builder.cc index 80147428922ba..2b6d000b8639e 100644 --- a/paddle/pir/src/core/builder.cc +++ b/paddle/pir/src/core/builder.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "paddle/pir/include/core/builder.h" #include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/builtin_type.h" diff --git a/paddle/pir/src/core/builtin_dialect.cc b/paddle/pir/src/core/builtin_dialect.cc index 8b450ffbc1d09..db4fc1808c300 100644 --- a/paddle/pir/src/core/builtin_dialect.cc +++ b/paddle/pir/src/core/builtin_dialect.cc @@ -13,12 +13,16 @@ // limitations under the License. #include "paddle/pir/include/core/builtin_dialect.h" + +#include "paddle/common/ddim.h" +#include "paddle/common/layout.h" #include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/builtin_op.h" #include "paddle/pir/include/core/builtin_type.h" +#include "paddle/pir/include/core/parser/ir_parser.h" namespace pir { -BuiltinDialect::BuiltinDialect(IrContext *context) +BuiltinDialect::BuiltinDialect(IrContext* context) : Dialect(name(), context, TypeId::get()) { initialize(); } @@ -38,7 +42,8 @@ void BuiltinDialect::initialize() { BoolType, Complex64Type, Complex128Type, - VectorType>(); + VectorType, + DenseTensorType>(); RegisterAttributes(); } +pir::Type BuiltinDialect::ParseType(pir::IrParser& parser) { // NOLINT + parser.ConsumeAToken("builtin.tensor"); + parser.ConsumeAToken("<"); + std::vector dim{}; + Token dim_token = parser.PeekToken(); + while (dim_token.token_type_ == DIGIT) { + dim_token = parser.ConsumeToken(); + dim.push_back(atoi(dim_token.val_.c_str())); + std::string peek_token_val = parser.PeekToken().val_; + if (peek_token_val[0] != 'x') { + break; + } + parser.ConsumeToken(); + parser.lexer->Unget(static_cast(peek_token_val.size() - 1)); + if (parser.PeekToken().token_type_ != DIGIT) { + break; + } + } + pir::DDim ddim = common::make_ddim(dim); + pir::Type dtype = parser.ParseType(); + std::vector> lod; + std::vector lodv; + lodv.push_back(0); + lod.push_back(lodv); + parser.ConsumeAToken(">"); + return DenseTensorType::get( + parser.ctx, dtype, ddim, pir::DataLayout::UNDEFINED, lod, 0); +} + +void BuiltinDialect::PrintType(pir::Type type, std::ostream& os) const { + os << type.dialect().name(); + os << '.'; + if (auto tensor_type = type.dyn_cast()) { + os << "tensor<"; + for (auto d : common::vectorize(tensor_type.dims())) { + os << d; + os << "x"; + } + tensor_type.dtype().Print(os); + os << ">"; + } +} + } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::BuiltinDialect) diff --git a/paddle/pir/src/core/builtin_op.cc b/paddle/pir/src/core/builtin_op.cc index 24b7624dafc63..fca2ebe63eea5 100644 --- a/paddle/pir/src/core/builtin_op.cc +++ b/paddle/pir/src/core/builtin_op.cc @@ -12,9 +12,11 @@ // 
See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/include/core/builtin_op.h" +#include + #include "paddle/common/enforce.h" #include "paddle/pir/include/core/builtin_attribute.h" +#include "paddle/pir/include/core/builtin_op.h" #include "paddle/pir/include/core/builtin_type.h" namespace pir { diff --git a/paddle/pir/src/core/builtin_type.cc b/paddle/pir/src/core/builtin_type.cc index 0da20a6b83bd1..6a1f5f9b26fd6 100644 --- a/paddle/pir/src/core/builtin_type.cc +++ b/paddle/pir/src/core/builtin_type.cc @@ -30,6 +30,27 @@ const DenseTensorType::LoD& DenseTensorType::lod() const { } size_t DenseTensorType::offset() const { return storage()->offset_; } + +bool DenseTensorType::classof(Type type) { + if (type) { + if (type.type_id() == type_id()) return true; + if (auto wrap_type = type.dyn_cast()) { + return classof(wrap_type.prim_type()); + } + } + return false; +} + +DenseTensorType DenseTensorType::dyn_cast_impl(Type type) { + if (type) { + if (type.type_id() == type_id()) return DenseTensorType(type.storage()); + if (auto wrap_type = type.dyn_cast()) { + return dyn_cast_impl(wrap_type.prim_type()); + } + } + return nullptr; +} + } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::UInt8Type) diff --git a/paddle/pir/src/core/builtin_type_interfaces.cc b/paddle/pir/src/core/builtin_type_interfaces.cc index de0538eacc0d9..25ec38c709bef 100644 --- a/paddle/pir/src/core/builtin_type_interfaces.cc +++ b/paddle/pir/src/core/builtin_type_interfaces.cc @@ -18,12 +18,13 @@ namespace pir { Type ShapedTypeInterface::GetElementType() const { - return impl_->get_element_type(*this); + return impl_->get_element_type(*this); // NOLINT } pir::DDim ShapedTypeInterface::GetShape() const { - return impl_->get_shape(*this); + return impl_->get_shape(*this); // NOLINT } } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::ShapedTypeInterface) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::WrapTypeInterface) diff --git a/paddle/pir/src/core/dialect.cc b/paddle/pir/src/core/dialect.cc index b09709da6b0db..668c56111d0ac 100644 --- a/paddle/pir/src/core/dialect.cc +++ b/paddle/pir/src/core/dialect.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
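With the classof/dyn_cast_impl added for DenseTensorType above, a query matches either directly by TypeId or by recursively unwrapping any type that implements WrapTypeInterface. A hedged sketch of what that enables at call sites (assuming the usual isa/dyn_cast members on pir::Type; the wrapper type itself is hypothetical):

    #include "paddle/pir/include/core/builtin_type.h"

    bool IsDenseTensorLike(pir::Type type) {
      // isa<> routes through DenseTensorType::classof, which now also unwraps
      // any WrapTypeInterface before giving up.
      return type && type.isa<pir::DenseTensorType>();
    }

    pir::DenseTensorType AsDenseTensor(pir::Type type) {
      // dyn_cast<> likewise sees through wrapper types via dyn_cast_impl.
      return type.dyn_cast<pir::DenseTensorType>();
    }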
+#include + #include "paddle/pir/include/core/dialect.h" namespace pir { diff --git a/paddle/pir/src/core/ir_context.cc b/paddle/pir/src/core/ir_context.cc index a4839bb2d4a34..90393fe4370b9 100644 --- a/paddle/pir/src/core/ir_context.cc +++ b/paddle/pir/src/core/ir_context.cc @@ -14,6 +14,7 @@ #include "paddle/pir/include/core/ir_context.h" +#include #include #include "paddle/pir/include/core/attribute_base.h" diff --git a/paddle/pir/src/core/ir_printer.cc b/paddle/pir/src/core/ir_printer.cc index de75d6d2fc603..e2bc7757f9de4 100644 --- a/paddle/pir/src/core/ir_printer.cc +++ b/paddle/pir/src/core/ir_printer.cc @@ -279,6 +279,10 @@ void IrPrinter::PrintAttributeMap(Operation* op) { AttributeMap attributes = op->attributes(); std::map> order_attributes( attributes.begin(), attributes.end()); + + // Filter out the callstack attribute + order_attributes.erase("op_callstack"); + os << " {"; pir::detail::PrintInterleave( diff --git a/paddle/pir/src/core/op_info_impl.cc b/paddle/pir/src/core/op_info_impl.cc index efbcedf42cc0f..f9d5295671113 100644 --- a/paddle/pir/src/core/op_info_impl.cc +++ b/paddle/pir/src/core/op_info_impl.cc @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/src/core/op_info_impl.h" +#include + #include "paddle/pir/include/core/dialect.h" #include "paddle/pir/include/core/interface_support.h" +#include "paddle/pir/src/core/op_info_impl.h" namespace pir { diff --git a/paddle/pir/src/core/op_operand.cc b/paddle/pir/src/core/op_operand.cc index 5c27cd4943ca6..06c0d79ed9ae0 100644 --- a/paddle/pir/src/core/op_operand.cc +++ b/paddle/pir/src/core/op_operand.cc @@ -22,8 +22,8 @@ "impl_ pointer is null when call func:" #func_name \ " , in class: " #class_name ".") -#define CHECK_OPOPEREND_NULL_IMPL(func_name) \ - CHECK_NULL_IMPL(OpOpernad, func_name) +#define CHECK_OP_OPERAND_NULL_IMPL(func_name) \ + CHECK_NULL_IMPL(OpOperand, func_name) namespace pir { OpOperand &OpOperand::operator=(const OpOperand &rhs) { // NOLINT @@ -37,34 +37,34 @@ OpOperand &OpOperand::operator=(const OpOperand &rhs) { // NOLINT OpOperand::operator bool() const { return impl_ && impl_->source(); } OpOperand OpOperand::next_use() const { - CHECK_OPOPEREND_NULL_IMPL(next_use); + CHECK_OP_OPERAND_NULL_IMPL(next_use); return impl_->next_use(); } Value OpOperand::source() const { - CHECK_OPOPEREND_NULL_IMPL(source); + CHECK_OP_OPERAND_NULL_IMPL(source); return impl_->source(); } Type OpOperand::type() const { return source().type(); } void OpOperand::set_source(Value value) { - CHECK_OPOPEREND_NULL_IMPL(set_source); + CHECK_OP_OPERAND_NULL_IMPL(set_source); impl_->set_source(value); } Operation *OpOperand::owner() const { - CHECK_OPOPEREND_NULL_IMPL(owner); + CHECK_OP_OPERAND_NULL_IMPL(owner); return impl_->owner(); } uint32_t OpOperand::index() const { - CHECK_OPOPEREND_NULL_IMPL(index); + CHECK_OP_OPERAND_NULL_IMPL(index); return impl_->index(); } void OpOperand::RemoveFromUdChain() { - CHECK_OPOPEREND_NULL_IMPL(RemoveFromUdChain); + CHECK_OP_OPERAND_NULL_IMPL(RemoveFromUdChain); return impl_->RemoveFromUdChain(); } diff --git a/paddle/pir/src/core/op_operand_impl.h b/paddle/pir/src/core/op_operand_impl.h index f83c54f58acfa..9dc3e29ce764e 100644 --- a/paddle/pir/src/core/op_operand_impl.h +++ b/paddle/pir/src/core/op_operand_impl.h @@ -46,7 +46,7 @@ class OpOperandImpl { private: OpOperandImpl(Value source, Operation *owner); - // Insert self to the UD chain holded by source_; + // Insert self to the UD chain held 
by source_; // It is not safe. So set private. void InsertToUdChain(); diff --git a/paddle/pir/src/core/op_result.cc b/paddle/pir/src/core/op_result.cc index 44b2e81ad953b..cd72b5b2800b7 100644 --- a/paddle/pir/src/core/op_result.cc +++ b/paddle/pir/src/core/op_result.cc @@ -57,6 +57,14 @@ void OpResult::set_attribute(const std::string &key, Attribute value) { return IMPL_->set_attribute(key, value); } +void *OpResult::property(const std::string &key) const { + return impl_ ? IMPL_->property(key) : nullptr; +} +void OpResult::set_property(const std::string &key, const Property &value) { + CHECK_OPRESULT_NULL_IMPL(set_property); + return IMPL_->set_property(key, value); +} + OpResult::OpResult(detail::OpResultImpl *impl) : Value(impl) {} } // namespace pir diff --git a/paddle/pir/src/core/op_result_impl.cc b/paddle/pir/src/core/op_result_impl.cc index 3bc9e5023b3b2..e03c4ad5b8292 100644 --- a/paddle/pir/src/core/op_result_impl.cc +++ b/paddle/pir/src/core/op_result_impl.cc @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/src/core/op_result_impl.h" +#include + +#include "paddle/common/enforce.h" #include "paddle/pir/include/core/builtin_attribute.h" #include "paddle/pir/include/core/operation.h" +#include "paddle/pir/src/core/op_result_impl.h" namespace pir { namespace detail { @@ -28,8 +31,9 @@ uint32_t OpResultImpl::index() const { OpResultImpl::~OpResultImpl() { if (!use_empty()) { - LOG(FATAL) << "Destroyed a op_result that is still in use. \n" - << "The owner op type is:" << owner()->name(); + PADDLE_FATAL( + "Destroyed a op_result that is still in use. The owner op type is: %s", + owner()->name()); } } @@ -71,11 +75,12 @@ Attribute OpResultImpl::attribute(const std::string &key) const { void OpResultImpl::set_attribute(const std::string &key, Attribute value) { auto owner = this->owner(); auto attr = owner->attribute(key); - if (attr && !attr.isa()) { - IR_THROW( - "The %s attribute has existed as operation attribute. Can't set it as " - "value attribute. "); - } + PADDLE_ENFORCE_EQ(attr && !attr.isa(), + false, + common::errors::PreconditionNotMet( + "The %s attribute has existed as operation attribute. " + "Can't set it as value attribute. 
", + key)); auto array_attr = attr.dyn_cast(); auto index = this->index(); std::vector vec; @@ -85,5 +90,24 @@ void OpResultImpl::set_attribute(const std::string &key, Attribute value) { owner->set_attribute(key, ArrayAttribute::get(owner->ir_context(), vec)); } +void *OpResultImpl::property(const std::string &key) const { + return owner()->value_property(key, index()); +} + +void OpResultImpl::set_property(const std::string &key, const Property &value) { + auto owner = this->owner(); + owner->set_value_property(key, value, index()); +} + +OpInlineResultImpl::OpInlineResultImpl(Type type, uint32_t result_index) + : OpResultImpl(type, result_index) { + PADDLE_ENFORCE_LE( + result_index, + MAX_INLINE_RESULT_IDX, + common::errors::PreconditionNotMet( + "Inline result index [%u] should not exceed MaxInlineResultIndex(5)", + result_index)); +} + } // namespace detail } // namespace pir diff --git a/paddle/pir/src/core/op_result_impl.h b/paddle/pir/src/core/op_result_impl.h index b50b2dd94a258..eb3bd46a1fd4a 100644 --- a/paddle/pir/src/core/op_result_impl.h +++ b/paddle/pir/src/core/op_result_impl.h @@ -42,7 +42,7 @@ class OpResultImpl : public ValueImpl { /// uint32_t index() const; - ~OpResultImpl(); + TEST_API ~OpResultImpl(); /// /// \brief attribute related public interfaces @@ -50,6 +50,9 @@ class OpResultImpl : public ValueImpl { Attribute attribute(const std::string &key) const; void set_attribute(const std::string &key, Attribute value); + void *property(const std::string &key) const; + void set_property(const std::string &key, const Property &value); + private: int32_t ComputeOperationOffset() const; }; @@ -60,12 +63,7 @@ class OpResultImpl : public ValueImpl { /// class OpInlineResultImpl : public OpResultImpl { public: - OpInlineResultImpl(Type type, uint32_t result_index) - : OpResultImpl(type, result_index) { - if (result_index > MAX_INLINE_RESULT_IDX) { - throw("Inline result index should not exceed MaxInlineResultIndex(5)"); - } - } + TEST_API OpInlineResultImpl(Type type, uint32_t result_index); static bool classof(const ValueImpl &value) { return value.kind() < OUTLINE_RESULT_IDX; diff --git a/paddle/pir/src/core/op_trait.cc b/paddle/pir/src/core/op_trait.cc index 4261dbcc8a457..39a0f6001da18 100644 --- a/paddle/pir/src/core/op_trait.cc +++ b/paddle/pir/src/core/op_trait.cc @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/include/core/op_trait.h" +#include + #include "paddle/common/enforce.h" +#include "paddle/pir/include/core/op_trait.h" #include "paddle/pir/include/core/type_utils.h" namespace { diff --git a/paddle/pir/src/core/operation.cc b/paddle/pir/src/core/operation.cc index e7dce069ebd81..b1b09c60344f6 100644 --- a/paddle/pir/src/core/operation.cc +++ b/paddle/pir/src/core/operation.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include #include #include @@ -198,10 +199,19 @@ void Operation::Destroy() { } } - // 3. Deconstruct Operation. + // 3. Deconstruct Properties. + for (auto &value_property : value_properties_) { + for (auto &property_map : value_property) { + if (property_map.second.second) { + property_map.second.second((property_map.second.first)); + } + } + } + + // 4. Deconstruct Operation. this->~Operation(); - // 4. Deconstruct OpOperand. + // 5. Deconstruct OpOperand. 
for (size_t idx = 0; idx < num_operands_; idx++) { detail::OpOperandImpl *op_operand_impl = operand(idx).impl_; if (op_operand_impl) { @@ -209,7 +219,7 @@ void Operation::Destroy() { } } - // 5. Deconstruct BlockOperand. + // 6. Deconstruct BlockOperand. for (size_t idx = 0; idx < num_successors_; idx++) { detail::BlockOperandImpl *block_operand_impl = block_operands_ + idx; if (block_operand_impl) { @@ -217,7 +227,7 @@ void Operation::Destroy() { } } - // 5. Free memory. + // 7. Free memory. size_t result_mem_size = num_results_ > OUTLINE_RESULT_IDX ? sizeof(detail::OpOutlineResultImpl) * @@ -263,7 +273,7 @@ std::vector Operation::results() const { /// /// \brief op input related public interfaces /// -std::vector Operation::operands() { +std::vector Operation::operands() const { std::vector res; for (uint32_t i = 0; i < num_operands(); ++i) { res.push_back(operand(i)); @@ -371,9 +381,13 @@ void Operation::Verify() { } int32_t Operation::ComputeOpResultOffset(uint32_t index) const { - if (index >= num_results_) { - LOG(FATAL) << "index exceeds OP op result range."; - } + PADDLE_ENFORCE_LT( + index, + num_results_, + common::errors::PreconditionNotMet( + "The op result index [%u] must less than results size[%u].", + index, + num_results_)); if (index < OUTLINE_RESULT_IDX) { return -static_cast((index + 1u) * sizeof(OpInlineResultImpl)); } @@ -383,13 +397,39 @@ int32_t Operation::ComputeOpResultOffset(uint32_t index) const { } int32_t Operation::ComputeOpOperandOffset(uint32_t index) const { - if (index >= num_operands_) { - LOG(FATAL) << "index exceeds OP op operand range."; - } + PADDLE_ENFORCE_LT( + index, + num_operands_, + common::errors::PreconditionNotMet( + "The op operand index [%u] must less than operands size[%u].", + index, + num_operands_)); return static_cast(index * sizeof(OpOperandImpl) + sizeof(Operation)); } +void Operation::set_value_property(const std::string &key, + const Property &value, + size_t index) { + if (value_properties_.size() < index + 1) { + value_properties_.resize(index + 1); + } + auto &property_map = value_properties_[index]; + if (property_map.count(key)) { + property_map[key].second(property_map[key].first); + } + property_map[key] = value; +} + +void *Operation::value_property(const std::string &key, size_t index) const { + if (value_properties_.size() < (index + 1)) { + return nullptr; + } + auto &property_map = value_properties_[index]; + auto iter = property_map.find(key); + return iter == property_map.end() ? 
nullptr : iter->second.first; +} + #define COMPONENT_IMPL(component_lower, component_upper) \ component_upper##Impl *Operation::component_lower##_impl(uint32_t index) \ const { \ diff --git a/paddle/pir/src/core/parser/ir_parser.cc b/paddle/pir/src/core/parser/ir_parser.cc index 3f45573509305..5fe0cc55320ec 100644 --- a/paddle/pir/src/core/parser/ir_parser.cc +++ b/paddle/pir/src/core/parser/ir_parser.cc @@ -211,7 +211,7 @@ Operation* IrParser::ParseOperation() { std::vector value_index = ParseValueList(); ConsumeAToken("="); - OpInfo opinfo = ParseOpInfo(); + OpInfo op_info = ParseOpInfo(); std::vector inputs = ParseOperandList(); @@ -226,7 +226,7 @@ Operation* IrParser::ParseOperation() { std::vector type_vector = ParseTypeList(); Operation* op = - Operation::Create(inputs, attributeMap, type_vector, opinfo, 0); + Operation::Create(inputs, attributeMap, type_vector, op_info, 0); for (uint32_t i = 0; i < op->num_results(); i++) { std::string key_t = value_index[i]; diff --git a/paddle/pir/src/core/parser/lexer.cc b/paddle/pir/src/core/parser/lexer.cc index 7914063d148c0..fa93033074094 100644 --- a/paddle/pir/src/core/parser/lexer.cc +++ b/paddle/pir/src/core/parser/lexer.cc @@ -144,13 +144,13 @@ std::unique_ptr Lexer::LexEndTagOrNullVal() { new Token{"<<" + token_null_val + ">>", NULL_}); return null_token; } else { - std::string token_attrnull = ""; + std::string token_attr_null = ""; while (is.peek() != '>') { - token_attrnull += GetChar(); + token_attr_null += GetChar(); } GetChar(); std::unique_ptr null_token( - new Token{"<" + token_attrnull + ">", NULL_}); + new Token{"<" + token_attr_null + ">", NULL_}); return null_token; } } diff --git a/paddle/pir/src/core/storage_manager.cc b/paddle/pir/src/core/storage_manager.cc index 6018917062d43..a6fb1621292a6 100644 --- a/paddle/pir/src/core/storage_manager.cc +++ b/paddle/pir/src/core/storage_manager.cc @@ -14,6 +14,7 @@ #include "paddle/pir/include/core/storage_manager.h" +#include #include #include diff --git a/paddle/pir/src/core/value.cc b/paddle/pir/src/core/value.cc index 43bdf200c381e..da587e27f9475 100644 --- a/paddle/pir/src/core/value.cc +++ b/paddle/pir/src/core/value.cc @@ -110,4 +110,22 @@ void Value::set_attribute(const std::string &key, Attribute value) { return dyn_cast().set_attribute(key, value); } +void Value::set_property(const std::string &key, const Property &value) { + auto op_result = dyn_cast(); + PADDLE_ENFORCE_NE(op_result, + nullptr, + common::errors::PreconditionNotMet( + "The Value is not an OpResult, we can set property " + "only for OpResult currently")); + return op_result.set_property(key, value); +} + +void *Value::property(const std::string &key) const { + auto op_result = dyn_cast(); + if (op_result) { + return op_result.property(key); + } else { + return nullptr; + } +} } // namespace pir diff --git a/paddle/pir/src/core/value_impl.cc b/paddle/pir/src/core/value_impl.cc index 37dcb48370b6e..b5b41374497cc 100644 --- a/paddle/pir/src/core/value_impl.cc +++ b/paddle/pir/src/core/value_impl.cc @@ -12,6 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. 
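The new property slots pair a raw pointer with its deleter (Property appears to be a std::pair of void * and PropertiesDeleter, judging by how Operation::Destroy releases them), and only OpResults accept writes. A hypothetical usage sketch (the payload type and key are made up for illustration):

    #include "paddle/pir/include/core/value.h"

    struct PyPayload {
      int tag = 0;
    };

    // Attach a heap-allocated payload to an OpResult; Operation::Destroy calls
    // the registered deleter when the owning op is destroyed.
    void TagResult(pir::Value value) {
      auto *payload = new PyPayload{7};
      pir::PropertiesDeleter deleter = [](void *p) {
        delete static_cast<PyPayload *>(p);
      };
      value.set_property("py_tag", pir::Property{payload, deleter});

      // Reading it back returns nullptr if the key (or the property map) is absent.
      auto *got = static_cast<PyPayload *>(value.property("py_tag"));
      (void)got;
    }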
+#include + +#include "paddle/common/enforce.h" #include "paddle/pir/src/core/value_impl.h" namespace { @@ -48,10 +51,12 @@ std::string ValueImpl::PrintUdChain() { return result.str(); } ValueImpl::ValueImpl(Type type, uint32_t kind) : id_(GenerateId()) { - if (kind > BLOCK_ARG_IDX) { - LOG(FATAL) << "The kind of value_impl(" << kind - << "), is bigger than BLOCK_ARG_IDX(7)"; - } + PADDLE_ENFORCE_LE( + kind, + BLOCK_ARG_IDX, + common::errors::PreconditionNotMet( + "The kind of value_impl[%u] must not bigger than BLOCK_ARG_IDX(7)", + kind)); type_ = type; first_use_offseted_by_kind_ = reinterpret_cast( reinterpret_cast(nullptr) + kind); diff --git a/paddle/pir/src/dialect/control_flow/ir/cf_op.cc b/paddle/pir/src/dialect/control_flow/ir/cf_op.cc index 3ead6991b272a..f7ad9b763f2cb 100644 --- a/paddle/pir/src/dialect/control_flow/ir/cf_op.cc +++ b/paddle/pir/src/dialect/control_flow/ir/cf_op.cc @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" +#include +#include "paddle/phi/core/enforce.h" + #include "paddle/pir/include/core/builtin_type.h" #include "paddle/pir/include/core/ir_printer.h" +#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h" #include "paddle/pir/include/dialect/control_flow/ir/cf_type.h" namespace pir { @@ -105,19 +108,29 @@ void TuplePopOp::VerifyRegion() { "The outlet value of cf.tuple_pop can only be used once."); // Verify stack validity: - auto pop_op = container_interface().tuple_pop_op(); - IR_ENFORCE(*this == pop_op, - "The pop_op of tuple_pop_op must be this tuple_pop_op self."); - - auto inlet_size = tuple_push_op().tuple_size(); - IR_ENFORCE(inlet_size == tuple_size(), - "The pop elements size must equal to push elements size."); - for (size_t index = 0; index < inlet_size; ++index) { - IR_ENFORCE(outlet_element(index).type() == inlet_element(index).type(), - "The %d element's push type (%s) isn't equal to pop type (%s)", - index, - outlet_element(index).type(), - inlet_element(index).type()); + if (has_container()) { + // can be verified only if TuplePopOp and TuplePushOp are in the same + // sub_program + auto pop_op = container_interface().tuple_pop_op(); + PADDLE_ENFORCE( + *this == pop_op, + phi::errors::InvalidArgument( + "The pop_op of tuple_pop_op must be this tuple_pop_op self.")); + + auto inlet_size = tuple_push_op().tuple_size(); + PADDLE_ENFORCE( + inlet_size == tuple_size(), + phi::errors::InvalidArgument( + "The pop elements size must equal to push elements size.")); + for (size_t index = 0; index < inlet_size; ++index) { + PADDLE_ENFORCE( + outlet_element(index).type() == inlet_element(index).type(), + phi::errors::InvalidArgument( + "The %d element's push type (%s) isn't equal to pop type (%s)", + index, + outlet_element(index).type(), + inlet_element(index).type())); + } } VLOG(4) << "End Verifying for TuplePopOp."; } diff --git a/paddle/pir/src/dialect/shape/utils/dim_expr.cc b/paddle/pir/src/dialect/shape/utils/dim_expr.cc index 618cb6914553c..cec9dab7f6e8e 100644 --- a/paddle/pir/src/dialect/shape/utils/dim_expr.cc +++ b/paddle/pir/src/dialect/shape/utils/dim_expr.cc @@ -14,6 +14,7 @@ #include "paddle/pir/include/dialect/shape/utils/dim_expr.h" #include "paddle/pir/include/core/utils.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" namespace symbol { @@ -21,7 +22,8 @@ DimExpr DimExpr::operator+(const DimExpr& other) const { if (this->isa() && other.isa()) { return this->dyn_cast() 
+ other.dyn_cast(); } - return Add{List{*this, other}}; + DimExpr add_expr = Add{List{*this, other}}; + return SimplifyDimExpr(add_expr); } DimExpr DimExpr::operator-(const DimExpr& other) const { @@ -29,14 +31,16 @@ DimExpr DimExpr::operator-(const DimExpr& other) const { return this->dyn_cast() - other.dyn_cast(); } const DimExpr& neg = Negative(other); - return Add{List{*this, neg}}; + DimExpr sub_expr = Add{List{*this, neg}}; + return SimplifyDimExpr(sub_expr); } DimExpr DimExpr::operator*(const DimExpr& other) const { if (this->isa() && other.isa()) { return this->dyn_cast() * other.dyn_cast(); } - return Mul{List{*this, other}}; + DimExpr mul_expr = Mul{List{*this, other}}; + return SimplifyDimExpr(mul_expr); } DimExpr DimExpr::operator/(const DimExpr& other) const { @@ -48,7 +52,8 @@ DimExpr DimExpr::operator/(const DimExpr& other) const { } } const DimExpr& reciprocal = Reciprocal(other); - return Mul{List{*this, reciprocal}}; + DimExpr div_expr = Mul{List{*this, reciprocal}}; + return SimplifyDimExpr(div_expr); } namespace { diff --git a/paddle/pir/src/dialect/shape/utils/dim_expr_builder.cc b/paddle/pir/src/dialect/shape/utils/dim_expr_builder.cc index cb49cdbf326fd..acdc65ebec24f 100644 --- a/paddle/pir/src/dialect/shape/utils/dim_expr_builder.cc +++ b/paddle/pir/src/dialect/shape/utils/dim_expr_builder.cc @@ -14,6 +14,7 @@ #include "paddle/pir/include/dialect/shape/utils/dim_expr_builder.h" #include "paddle/common/enforce.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" namespace symbol { @@ -44,15 +45,15 @@ DimExpr DimExprBuilder::Div(const DimExpr& lhs, const DimExpr& rhs) { } DimExpr DimExprBuilder::Max(const DimExpr& lhs, const DimExpr& rhs) { - return MaxDimExpr{List{lhs, rhs}}; + return SimplifyDimExpr(MaxDimExpr{List{lhs, rhs}}); } DimExpr DimExprBuilder::Min(const DimExpr& lhs, const DimExpr& rhs) { - return MinDimExpr{List{lhs, rhs}}; + return SimplifyDimExpr(MinDimExpr{List{lhs, rhs}}); } DimExpr DimExprBuilder::Broadcast(const DimExpr& lhs, const DimExpr& rhs) { - return BroadcastDimExpr{List{lhs, rhs}}; + return SimplifyDimExpr(BroadcastDimExpr{List{lhs, rhs}}); } std::vector DimExprBuilder::ConstShape( diff --git a/paddle/pir/src/dialect/shape/utils/dim_expr_simplify.cc b/paddle/pir/src/dialect/shape/utils/dim_expr_util.cc similarity index 73% rename from paddle/pir/src/dialect/shape/utils/dim_expr_simplify.cc rename to paddle/pir/src/dialect/shape/utils/dim_expr_util.cc index ca934941bcb72..9549d66893228 100644 --- a/paddle/pir/src/dialect/shape/utils/dim_expr_simplify.cc +++ b/paddle/pir/src/dialect/shape/utils/dim_expr_util.cc @@ -12,7 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
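With these overloads, DimExpr arithmetic folds constant operands immediately and passes every symbolic result through SimplifyDimExpr before returning it. A small usage sketch (constructor forms assumed from dim_expr.h):

    #include "paddle/pir/include/dialect/shape/utils/dim_expr.h"

    void DimExprArithmeticSketch() {
      symbol::DimExpr two(2), three(3);
      symbol::DimExpr five = two + three;  // both operands are int64_t: folds to 5

      symbol::DimExpr s0("S0");
      // Mixed operands build Add{S0, 5} and the result is already simplified;
      // the exact form depends on the passes registered in dim_expr_util.cc.
      symbol::DimExpr sum = s0 + five;
      (void)sum;
    }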
-#include "paddle/pir/include/dialect/shape/utils/dim_expr_simplify.h" +#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h" + #include namespace symbol { @@ -45,7 +46,7 @@ struct SimplifyOneOperand { } else { return Op{ret_operand}; } - LOG(FATAL) << "Dead code."; + PADDLE_THROW(phi::errors::Fatal("Dead code.")); } }; @@ -70,7 +71,28 @@ struct SimplifyUnitOneOperand { } else { return expr; } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code.")); + } +}; + +/* + * Simplify Example: + * Negative(Negative(dim_expr)) => dim_expr + * Negative(int) => -int + */ +struct SimplifyDoubleNeg { + using dim_expr_type = Negative; + + DimExpr Rewrite(const DimExpr& expr) { + const auto& inner_expr = expr.Get>()->data; + if (inner_expr.Has>()) { + const auto& ret_expr = inner_expr.Get>()->data; + return ret_expr; + } else if (inner_expr.Has()) { + return -inner_expr.Get(); + } else { + return expr; + } } }; @@ -104,7 +126,7 @@ struct SimplifyOperands { } else { return Op{mut_operands}; } - LOG(FATAL) << "Dead code."; + PADDLE_THROW(phi::errors::Fatal("Dead code.")); } }; @@ -369,7 +391,7 @@ struct GetInversed { template <> struct GetInversed { static DimExpr Call(const DimExpr& expr) { - LOG(FATAL) << "Broadcast is not a group in math."; + PADDLE_THROW(phi::errors::Fatal("Broadcast is not a group in math.")); } }; @@ -442,7 +464,7 @@ struct FoldUnitConstant { } else { return Op{ret_operands}; } - LOG(FATAL) << "Dead code."; + PADDLE_THROW(phi::errors::Fatal("Dead code.")); } }; @@ -481,7 +503,7 @@ struct FoldConstants { } else { return Op{ret_operands}; } - LOG(FATAL) << "Dead code."; + PADDLE_THROW(phi::errors::Fatal("Dead code.")); } }; @@ -538,7 +560,7 @@ ConstRational SimplifiedConstRational(int64_t num, int64_t dem) { template std::optional GetConstRationalImpl(const T& expr) { - LOG(FATAL) << "not supported."; + PADDLE_THROW(phi::errors::Fatal("not supported.")); return std::nullopt; } @@ -607,7 +629,10 @@ struct FoldOperandTrait { List* ret) { const auto& [num, dem] = value; (*ret)->emplace_back(num); - CHECK_NE(dem, 0); + PADDLE_ENFORCE_NE(dem, + 0, + phi::errors::InvalidArgument( + "The denominator of rational can not be zero.")); if (dem != 1) { (*ret)->emplace_back(Reciprocal{DimExpr{dem}}); } @@ -643,7 +668,13 @@ struct FoldOperandTrait { if (*value == 1) { *value = expr_value; } else if (expr_value != 1) { - CHECK_EQ(*value, expr_value); + PADDLE_ENFORCE_EQ( + *value, + expr_value, + phi::errors::InvalidArgument("The value (%d) should be equel to expr " + "(%d) when they are both not 1.", + *value, + expr_value)); } else { // do nothing. 
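The SimplifyDoubleNeg pass introduced above rewrites Negative(Negative(e)) to e and Negative(constant) to the negated constant, leaving everything else untouched. A standalone sketch of the same rule on a toy expression type (plain C++, not pir code):

    #include <cassert>
    #include <memory>
    #include <variant>

    struct Expr;
    struct Neg {
      std::shared_ptr<Expr> inner;
    };
    struct Expr {
      std::variant<int, Neg> node;
    };

    // Negative(Negative(e)) => e, Negative(constant) => -constant.
    Expr SimplifyDoubleNeg(const Expr &e) {
      if (auto *neg = std::get_if<Neg>(&e.node)) {
        const Expr &inner = *neg->inner;
        if (auto *inner_neg = std::get_if<Neg>(&inner.node)) return *inner_neg->inner;
        if (auto *c = std::get_if<int>(&inner.node)) return Expr{-*c};
      }
      return e;  // every other shape is left unchanged
    }

    int main() {
      Expr three{3};
      Expr double_neg{
          Neg{std::make_shared<Expr>(Expr{Neg{std::make_shared<Expr>(three)}})}};
      assert(std::get<int>(SimplifyDoubleNeg(double_neg).node) == 3);
    }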
} @@ -703,7 +734,7 @@ struct FoldInversedPairToUnit { } else { return Op{ret_operands}; } - LOG(FATAL) << "Dead code"; + PADDLE_THROW(phi::errors::Fatal("Dead code.")); } std::optional SearchInversedPair( @@ -757,7 +788,7 @@ struct FoldRedundantSymbolicBroadcast { } else { return Broadcast{ret_operands}; } - LOG(FATAL) << "Dead code."; + PADDLE_THROW(phi::errors::Fatal("Dead code.")); } std::optional SearchMaxInt64(const List& operands) { @@ -772,7 +803,15 @@ struct FoldRedundantSymbolicBroadcast { if (ret.has_value()) { if (int64_value > 1) { if (ret.value().value > 1) { - CHECK_EQ(ret.value().value, int64_value); + PADDLE_ENFORCE_EQ( + ret.value().value, + int64_value, + phi::errors::InvalidArgument( + "The value of return (%d) should be equel to expr (%d) of " + "operands at index (%d) when they are both > 1.", + ret.value().value, + int64_value, + i)); } ret = MaxInt64{int64_value, i}; } @@ -816,7 +855,7 @@ struct FoldRedundantBroadcast { } else { return Broadcast{ret_operands}; } - LOG(FATAL) << "Dead code."; + PADDLE_THROW(phi::errors::Fatal("Dead code.")); } std::optional SearchInversedPair( @@ -849,6 +888,7 @@ DimExpr Simplify(const DimExpr& expr) { DoPass>(&keep_rewrite, &ret); DoPass>(&keep_rewrite, &ret); DoPass>(&keep_rewrite, &ret); + DoPass(&keep_rewrite, &ret); DoPass>(&keep_rewrite, &ret); DoPass>(&keep_rewrite, &ret); DoPass>(&keep_rewrite, &ret); @@ -877,3 +917,197 @@ DimExpr Simplify(const DimExpr& expr) { DimExpr SimplifyDimExpr(const DimExpr& expr) { return Simplify(expr); } } // namespace symbol + +namespace symbol { + +namespace { + +class SubstituteDimExprHelper final { + public: + explicit SubstituteDimExprHelper( + const std::unordered_map& pattern_to_replacement) + : pattern_to_replacement_(pattern_to_replacement) {} + + std::optional Substitute(const DimExpr& dim_expr) { + auto iter = pattern_to_replacement_.find(dim_expr); + if (iter != pattern_to_replacement_.end()) return iter->second; + return std::visit([&](const auto& impl) { return SubstituteImpl(impl); }, + dim_expr.variant()); + } + + private: + std::optional SubstituteImpl(const std::int64_t& value) { + // `Substitute` has handled the case that `value` is matched. + return std::nullopt; + } + std::optional SubstituteImpl(const std::string& value) { + // `Substitute` has handled the case that `value` is matched. + return std::nullopt; + } + + std::optional SubstituteImpl(const Negative& dim_expr) { + return SubstituteUnary(dim_expr); + } + std::optional SubstituteImpl(const Reciprocal& dim_expr) { + return SubstituteUnary(dim_expr); + } + + template + std::optional SubstituteUnary(const T& dim_expr) { + const auto& operand = dim_expr->data; + const auto& substituted_operand = Substitute(operand); + if (!substituted_operand.has_value()) return std::nullopt; + return T{substituted_operand.value()}; + } + + std::optional SubstituteImpl(const Add& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Mul& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Max& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Min& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + std::optional SubstituteImpl(const Broadcast& dim_expr) { + return SubstituteVariadic(dim_expr); + } + + template
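SubstituteDimExprHelper above walks the expression variant and swaps in any sub-expression found in the replacement map; the public entry point is the SubstituteDimExpr declared in dim_expr_util.h. A hypothetical usage sketch (map keyed by DimExpr, relying on the std::hash specialization from dim_expr.h):

    #include <unordered_map>

    #include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h"

    void SubstituteSketch() {
      symbol::DimExpr s0("S0"), s1("S1");
      symbol::DimExpr expr = s0 * s1;  // Mul{S0, S1}, possibly pre-simplified

      // Replace the symbol S0 by the constant 4 everywhere it appears.
      std::unordered_map<symbol::DimExpr, symbol::DimExpr> replacement{
          {s0, symbol::DimExpr(4)}};
      symbol::DimExpr out = symbol::SubstituteDimExpr(expr, replacement);
      (void)out;  // expected to be equivalent to Mul{4, S1}
    }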