diff --git a/.bazelversion b/.bazelversion index 0b2eb36f5..fae6e3d04 100644 --- a/.bazelversion +++ b/.bazelversion @@ -1 +1 @@ -3.7.2 +4.2.1 diff --git a/.github/tools/release_linux.sh b/.github/tools/release_linux.sh index 6e26a1772..f1ee04254 100755 --- a/.github/tools/release_linux.sh +++ b/.github/tools/release_linux.sh @@ -9,7 +9,7 @@ bazel build :build_pip_pkg \ --copt=-mavx \ --distinct_host_configuration=false \ --verbose_failures \ - --crosstool_top=@org_tensorflow//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11.2:toolchain + --crosstool_top=//third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11:toolchain # Package Whl bazel-bin/build_pip_pkg artifacts diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9b6f06aaf..37dd6fd49 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -210,7 +210,7 @@ jobs: run: | docker run -e LCE_RELEASE_VERSION=${{ github.event.inputs.version }} \ -v ${PWD}:/compute-engine -w /compute-engine \ - tensorflow/build:latest-python${{ matrix.python-version }} \ + tensorflow/build:2.8-python${{ matrix.python-version }} \ .github/tools/release_linux.sh sudo apt-get -y -qq install patchelf --no-install-recommends @@ -228,7 +228,7 @@ jobs: windows-release-wheel: name: Build release wheels for Windows - runs-on: windows-latest + runs-on: windows-2019 strategy: matrix: python-version: [3.7, 3.8, 3.9] diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 6b4babaae..2cf449b92 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -25,8 +25,12 @@ jobs: run: pip install numpy --no-cache-dir - name: Run C++ Unit Tests run: bazelisk test larq_compute_engine/tests:cc_tests --copt=-O2 --distinct_host_configuration=false --test_output=all - - name: Build TF Lite Static Library with Make - run: larq_compute_engine/tflite/build_make/build_lce.sh --native + - name: Build TF Lite Static Library with 
CMake + run: | + mkdir build + cd build + cmake .. + make -j2 ARM: runs-on: ubuntu-latest @@ -83,7 +87,7 @@ jobs: if: github.ref != 'refs/heads/main' shell: bash - name: Install pip dependencies - run: pip install tensorflow-cpu~=2.6.2 larq~=0.11 larq_zoo~=2.0 pytest tensorflow_datasets~=4.2 flatbuffers tqdm --no-cache-dir + run: pip install tensorflow-cpu~=2.8.0 larq~=0.11 larq_zoo~=2.0 pytest tensorflow_datasets~=4.4 flatbuffers==1.12 tqdm --no-cache-dir - name: Run Interpreter test run: bazelisk test larq_compute_engine/tflite/tests:interpreter_test --test_output=all - name: Run FileCheck tests @@ -97,7 +101,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - tf-version: [1.14.0, 1.15.5, 2.0.4, 2.1.3, 2.2.2, 2.3.2, 2.4.1, 2.5.0, 2.6.2, 2.7.0] + tf-version: [1.14.0, 1.15.5, 2.0.4, 2.1.4, 2.2.3, 2.3.3, 2.4.4, 2.5.3, 2.6.3, 2.7.1, 2.8.0] if: "!contains(github.event.head_commit.message, 'ci-skip')" steps: - uses: actions/checkout@v3 @@ -105,7 +109,7 @@ jobs: with: python-version: 3.7 - name: Install dependencies - run: pip install tensorflow==${{matrix.tf-version}} larq~=0.11 larq_zoo~=2.0 tensorflow_datasets==1.3.2 packaging flatbuffers --no-cache-dir + run: pip install tensorflow==${{matrix.tf-version}} larq~=0.11 larq_zoo~=2.0 tensorflow_datasets==1.3.2 packaging flatbuffers==1.12 --no-cache-dir - name: Run Converter test run: PYTHONPATH=./ python larq_compute_engine/mlir/python/converter_test.py diff --git a/.gitignore b/.gitignore index 4e2c58656..3ea81c14a 100644 --- a/.gitignore +++ b/.gitignore @@ -16,7 +16,8 @@ node_modules __pycache__ *.swp .vscode/ -cmake_build/ +cmake_build* +cmake-build* tensorflow/contrib/cmake/_build/ .idea/** /build/ diff --git a/.tensorflow.bazelrc b/.tensorflow.bazelrc index 1dcbb836a..c9b24bc2d 100644 --- a/.tensorflow.bazelrc +++ b/.tensorflow.bazelrc @@ -136,6 +136,7 @@ build:elinux_aarch64 --config=elinux build:elinux_aarch64 --cpu=aarch64 build:elinux_armhf --config=elinux build:elinux_armhf --cpu=armhf +build:elinux_armhf 
--copt -mfp16-format=ieee # Address sanitizer # CC=clang bazel build --config asan diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..e0f5d38ee --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,96 @@ +cmake_minimum_required(VERSION 3.16) +project(larq_compute_engine C CXX) + +# Options and their default values +option(COMPILE_EXAMPLE "Enable compilation of the minimal example" ON) +option(COMPILE_BENCHMARK "Enable compilation of the benchmarking utility" ON) + +# TensorFlow dependency, see https://www.tensorflow.org/lite/guide/build_cmake +set(TENSORFLOW_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/third_party/tensorflow/") +set(TFLITE_SOURCE_DIR "${TENSORFLOW_SOURCE_DIR}/tensorflow/lite") +add_subdirectory("${TFLITE_SOURCE_DIR}" "${CMAKE_CURRENT_BINARY_DIR}/tensorflow-lite" EXCLUDE_FROM_ALL) + +# Generic compilation options and settings +set(CMAKE_CXX_STANDARD 14) +include_directories(${CMAKE_CURRENT_LIST_DIR}) + +# The LCE core files +set(LCE_SOURCE_DIR "${CMAKE_CURRENT_LIST_DIR}/larq_compute_engine") +set(LCE_CORE_SRCS + ${LCE_SOURCE_DIR}/tflite/kernels/bconv2d.cc + ${LCE_SOURCE_DIR}/tflite/kernels/bmaxpool.cc + ${LCE_SOURCE_DIR}/tflite/kernels/quantization.cc + ) +set(LCE_CORE_HDRS # such that they can be discovered by IDEs such as CLion Visual Studio + ${LCE_SOURCE_DIR}/core/indirect_bgemm/kernel.h + ${LCE_SOURCE_DIR}/core/indirect_bgemm/kernel_4x2_portable.h + ${LCE_SOURCE_DIR}/core/indirect_bgemm/kernel_8x4x4_aarch64.h + ${LCE_SOURCE_DIR}/core/indirect_bgemm/kernel_8x4x1_aarch64.h + ${LCE_SOURCE_DIR}/core/indirect_bgemm/select_kernel.h + ${LCE_SOURCE_DIR}/core/indirect_bgemm/kernel_8x4x2_aarch64.h + ${LCE_SOURCE_DIR}/core/bmaxpool.h + ${LCE_SOURCE_DIR}/core/bitpacking/utils.h + ${LCE_SOURCE_DIR}/core/bitpacking/bitpack.h + ${LCE_SOURCE_DIR}/core/bitpacking/bitpack_aarch64.h + ${LCE_SOURCE_DIR}/core/types.h + ${LCE_SOURCE_DIR}/core/bconv2d/optimized_indirect_bgemm.h + ${LCE_SOURCE_DIR}/core/bconv2d/reference.h + 
${LCE_SOURCE_DIR}/core/bconv2d/optimized_bgemm.h + ${LCE_SOURCE_DIR}/core/bconv2d/zero_padding_correction.h + ${LCE_SOURCE_DIR}/core/bconv2d/params.h + ${LCE_SOURCE_DIR}/core/bconv2d/output_transform.h + ${LCE_SOURCE_DIR}/core/bgemm/kernels_common.h + ${LCE_SOURCE_DIR}/core/bgemm/ruy_trmul_params.h + ${LCE_SOURCE_DIR}/core/bgemm/kernels_aarch64.h + ${LCE_SOURCE_DIR}/core/bgemm/kernels.h + ${LCE_SOURCE_DIR}/core/bgemm/ruy_pack.h + ${LCE_SOURCE_DIR}/core/bgemm/kernels_arm32.h + ${LCE_SOURCE_DIR}/core/bgemm/bgemm.h + ${LCE_SOURCE_DIR}/tflite/kernels/lce_ops_register.h + ${LCE_SOURCE_DIR}/tflite/kernels/utils.h + ) + +# The example application +if(COMPILE_EXAMPLE) + set(LCE_EXAMPLE_SRCS ${CMAKE_CURRENT_LIST_DIR}/examples/lce_minimal.cc) + add_executable(example ${LCE_CORE_SRCS} ${LCE_CORE_HDRS} ${LCE_EXAMPLE_SRCS}) + target_link_libraries(example tensorflow-lite) +endif() + +# The benchmarking binary +if(COMPILE_BENCHMARK) + set(LCE_BENCHMARK_SRCS + ${LCE_SOURCE_DIR}/tflite/benchmark/lce_benchmark_tflite_model.cc + ${LCE_SOURCE_DIR}/tflite/benchmark/lce_benchmark_main.cc + ) + set(LCE_BENCHMARK_HRDS + ${LCE_SOURCE_DIR}/tflite/benchmark/lce_benchmark_tflite_model.h + ${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_model.h + ) + set(TFLITE_BENCHMARK_SRCS # from ${TFLITE_SOURCE_DIR}/tools/benchmark/CMakeLists.txt + ${TENSORFLOW_SOURCE_DIR}/tensorflow/core/util/stats_calculator.cc + ${TFLITE_SOURCE_DIR}/kernels/internal/utils/sparsity_format_converter.cc + ${TFLITE_SOURCE_DIR}/profiling/memory_info.cc + ${TFLITE_SOURCE_DIR}/profiling/memory_usage_monitor.cc + ${TFLITE_SOURCE_DIR}/profiling/profile_summarizer.cc + ${TFLITE_SOURCE_DIR}/profiling/profile_summary_formatter.cc + ${TFLITE_SOURCE_DIR}/profiling/time.cc + ${TFLITE_SOURCE_DIR}/tools/command_line_flags.cc + ${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_model.cc + ${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_performance_options.cc + ${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_tflite_model.cc + 
${TFLITE_SOURCE_DIR}/tools/benchmark/benchmark_utils.cc + ${TFLITE_SOURCE_DIR}/tools/benchmark/profiling_listener.cc + ${TFLITE_SOURCE_DIR}/tools/delegates/default_execution_provider.cc + ${TFLITE_SOURCE_DIR}/tools/delegates/delegate_provider.cc + ${TFLITE_SOURCE_DIR}/tools/delegates/xnnpack_delegate_provider.cc + ${TFLITE_SOURCE_DIR}/tools/evaluation/utils.cc + ${TFLITE_SOURCE_DIR}/tools/tool_params.cc + ) + add_executable(lce_benchmark_model + ${TFLITE_BENCHMARK_SRCS} + ${LCE_CORE_SRCS} ${LCE_CORE_HDRS} + ${LCE_BENCHMARK_SRCS} ${LCE_BENCHMARK_HRDS} + ) + target_link_libraries(lce_benchmark_model tensorflow-lite) +endif() diff --git a/WORKSPACE b/WORKSPACE index 742bbd253..b51870e52 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -15,11 +15,12 @@ http_archive( patch_tool = "patch", patches = [ "//third_party/tensorflow_patches:disable_forced_mkl.patch", + "//third_party/tensorflow_patches:fix_armhf_xnnpack.patch", ], - sha256 = "e68c1d346fc3d529653530ca346b2c62f5b31bd4fcca7ffc9c65bb39ab2f6ed3", - strip_prefix = "tensorflow-2.6.2", + sha256 = "66b953ae7fba61fd78969a2e24e350b26ec116cf2e6a7eb93d02c63939c6f9f7", + strip_prefix = "tensorflow-2.8.0", urls = [ - "https://github.com/tensorflow/tensorflow/archive/v2.6.2.tar.gz", + "https://github.com/tensorflow/tensorflow/archive/v2.8.0.tar.gz", ], ) diff --git a/larq_compute_engine/mlir/BUILD b/larq_compute_engine/mlir/BUILD index 2de060bdc..cd1e4f6b3 100644 --- a/larq_compute_engine/mlir/BUILD +++ b/larq_compute_engine/mlir/BUILD @@ -1,5 +1,5 @@ -load("@org_tensorflow//third_party/mlir:tblgen.bzl", "gentbl", "td_library") load("@org_tensorflow//tensorflow:tensorflow.bzl", "pybind_extension", "tf_cc_binary") +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") package( default_visibility = ["//visibility:public"], @@ -8,169 +8,231 @@ package( td_library( name = "lce_ops_td_file", - srcs = [ - "ir/lce_ops.td", - ], + srcs = ["ir/lce_ops.td"], + includes = ["/external/org_tensorflow"], deps = [ 
"@llvm-project//mlir:SideEffectTdFiles", + "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", "@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", ], ) -gentbl( +td_library( + name = "fuse_padding_file", + srcs = ["transforms/fuse_padding.td"], + includes = ["/external/org_tensorflow"], + deps = [ + "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_inc_gen", + "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", + ], +) + +td_library( + name = "op_removal_patterns_file", + srcs = ["transforms/op_removal_patterns.td"], + includes = ["/external/org_tensorflow"], + deps = [ + "@llvm-project//mlir:StdOpsTdFiles", + "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + ], +) + +td_library( + name = "optimize_patterns_common_file", + srcs = ["transforms/optimize_patterns_common.td"], + includes = ["/external/org_tensorflow"], + deps = [ + ":lce_ops_td_file", + "@llvm-project//mlir:StdOpsTdFiles", + "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", + ], +) + +td_library( + name = "prepare_patterns_common_file", + srcs = ["transforms/prepare_patterns_common.td"], + includes = ["/external/org_tensorflow"], + deps = [ + ":lce_ops_td_file", + "@llvm-project//mlir:StdOpsTdFiles", + "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", + "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + ], +) + +gentbl_cc_library( name = "lce_ops_inc_gen", tbl_outs = [ - ("-gen-enum-decls", "ir/lce_enum.h.inc"), - ("-gen-enum-defs", "ir/lce_enum.cc.inc"), - ("-gen-op-decls", "ir/lce_ops.h.inc"), - ("-gen-op-defs", "ir/lce_ops.cc.inc"), - ("-gen-dialect-decls -dialect=lq", "ir/lce_dialect.h.inc"), - ("-gen-dialect-doc", "g3doc/lce_ops.md"), + ( + ["-gen-enum-decls"], + "ir/lce_enum.h.inc", + ), + ( + ["-gen-enum-defs"], + "ir/lce_enum.cc.inc", + ), + ( + ["-gen-op-decls"], + 
"ir/lce_ops.h.inc", + ), + ( + ["-gen-op-defs"], + "ir/lce_ops.cc.inc", + ), + ( + [ + "-gen-dialect-decls", + "-dialect", + "lq", + ], + "ir/lce_dialect.h.inc", + ), + ( + [ + "-gen-dialect-defs", + "-dialect", + "lq", + ], + "ir/lce_dialect.cc.inc", + ), + ( + ["-gen-dialect-doc"], + "g3doc/lce_ops.md", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/lce_ops.td", - td_includes = ["external/org_tensorflow"], - td_srcs = [ - ":lce_ops_td_file", - ], + deps = [":lce_ops_td_file"], ) -gentbl( +gentbl_cc_library( name = "op_removal_lce_inc_gen", tbl_outs = [ - ("-gen-rewriters", "transforms/generated_op_removal.inc"), + ( + ["-gen-rewriters"], + "transforms/generated_op_removal.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/op_removal_patterns.td", - td_includes = ["external/org_tensorflow"], - td_srcs = [ - "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "@llvm-project//mlir:StdOpsTdFiles", - ], + deps = [":op_removal_patterns_file"], ) -gentbl( +gentbl_cc_library( name = "prepare_lce_target_arm_inc_gen", tbl_outs = [ - ("-gen-rewriters", "transforms/generated_prepare_target_arm.inc"), + ( + ["-gen-rewriters"], + "transforms/generated_prepare_target_arm.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/prepare_patterns_target_arm.td", - td_includes = ["external/org_tensorflow"], - td_srcs = [ - ":lce_ops_td_file", - "transforms/op_removal_patterns.td", - "transforms/prepare_patterns_common.td", - "@llvm-project//mlir:StdOpsTdFiles", - "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", - ], + deps = [":prepare_patterns_common_file"], ) -gentbl( +gentbl_cc_library( name = "prepare_lce_target_other_inc_gen", tbl_outs = [ - ("-gen-rewriters", "transforms/generated_prepare_target_other.inc"), + ( + ["-gen-rewriters"], + 
"transforms/generated_prepare_target_other.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/prepare_patterns_common.td", - td_includes = ["external/org_tensorflow"], - td_srcs = [ - ":lce_ops_td_file", - "transforms/op_removal_patterns.td", - "@llvm-project//mlir:StdOpsTdFiles", - "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", - ], + deps = [":prepare_patterns_common_file"], ) -gentbl( +gentbl_cc_library( name = "optimize_lce_target_arm_inc_gen", tbl_outs = [ - ("-gen-rewriters", "transforms/generated_optimize_target_arm.inc"), + ( + ["-gen-rewriters"], + "transforms/generated_optimize_target_arm.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_patterns_target_arm.td", - td_includes = ["external/org_tensorflow"], - td_srcs = [ - ":lce_ops_td_file", - "transforms/optimize_patterns_common.td", - "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", - "@llvm-project//mlir:StdOpsTdFiles", - ], + deps = [":optimize_patterns_common_file"], ) -gentbl( +gentbl_cc_library( name = "optimize_lce_target_other_inc_gen", tbl_outs = [ - ("-gen-rewriters", "transforms/generated_optimize_target_other.inc"), + ( + ["-gen-rewriters"], + "transforms/generated_optimize_target_other.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_patterns_common.td", - td_includes = ["external/org_tensorflow"], - td_srcs = [ - ":lce_ops_td_file", - "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", - "@llvm-project//mlir:StdOpsTdFiles", - ], + deps = [":optimize_patterns_common_file"], ) -gentbl( +gentbl_cc_library( name = "bitpack_activations_inc_gen", tbl_outs = [ - ("-gen-rewriters", "transforms/generated_bitpack_activations.inc"), + ( + ["-gen-rewriters"], + "transforms/generated_bitpack_activations.inc", + ), ], tblgen = 
"@llvm-project//mlir:mlir-tblgen", td_file = "transforms/bitpack_activations_patterns.td", - td_includes = ["external/org_tensorflow"], - td_srcs = [ + deps = [ ":lce_ops_td_file", - "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", "@llvm-project//mlir:StdOpsTdFiles", + "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", ], ) -gentbl( +gentbl_cc_library( name = "bitpack_weights_lce_inc_gen", tbl_outs = [ - ("-gen-rewriters", "transforms/generated_bitpack_weights.inc"), + ( + ["-gen-rewriters"], + "transforms/generated_bitpack_weights.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/bitpack_weights_patterns.td", - td_includes = ["external/org_tensorflow"], - td_srcs = [ + deps = [ ":lce_ops_td_file", - "transforms/op_removal_patterns.td", - "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", "@llvm-project//mlir:StdOpsTdFiles", + "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", + "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", ], ) -gentbl( +gentbl_cc_library( name = "quantize_lce_inc_gen", tbl_outs = [ - ("-gen-rewriters", "transforms/generated_quantize.inc"), + ( + ["-gen-rewriters"], + "transforms/generated_quantize.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/quantize_patterns.td", - td_includes = ["external/org_tensorflow"], - td_srcs = [ + deps = [ ":lce_ops_td_file", "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", ], ) -gentbl( +gentbl_cc_library( name = "fuse_padding_inc_gen", tbl_outs = [ - ("-gen-rewriters", "transforms/generated_fuse_padding.inc"), + ( + ["-gen-rewriters"], + "transforms/generated_fuse_padding.inc", + ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/fuse_padding.td", - td_includes = 
["external/org_tensorflow"], - td_srcs = [ - "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite_ops_td_files", + deps = [ + ":fuse_padding_file", ], ) @@ -193,6 +255,7 @@ cc_library( cc_library( name = "larq_compute_engine", srcs = [ + "ir/lce_dialect.cc.inc", "ir/lce_dialect.h.inc", "ir/lce_enum.cc.inc", "ir/lce_enum.h.inc", @@ -398,7 +461,6 @@ cc_library( "@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_config", "@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_passes", "@org_tensorflow//tensorflow/compiler/mlir/tensorflow", - "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:decode_constant_pass", "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tensorflow_passes", "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_passes", ], @@ -411,12 +473,16 @@ cc_library( "tf_to_tfl_flatbuffer.h", ], deps = [ + ":lce_tfl_passes", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@org_tensorflow//tensorflow/compiler/mlir:op_or_arg_name_mapper", "@org_tensorflow//tensorflow/compiler/mlir/lite:flatbuffer_export", + "@org_tensorflow//tensorflow/compiler/mlir/lite/metrics:error_collector", + "@org_tensorflow//tensorflow/compiler/mlir/lite/quantization:quantization_config", "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:error_util", + "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_saved_model_freeze_variables", "@org_tensorflow//tensorflow/stream_executor/lib", ], ) @@ -426,11 +492,7 @@ cc_library( srcs = ["python/common.cc"], hdrs = ["python/common.h"], deps = [ - ":lce_tfl_passes", ":tf_to_tfl_flatbuffer", - "@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite", - "@org_tensorflow//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", - "@org_tensorflow//tensorflow/compiler/mlir/lite/python:tf_tfl_flatbuffer_helpers", "@org_tensorflow//tensorflow/core:ops", "@pybind11", ], @@ -461,7 +523,7 @@ pybind_extension( 
"@org_tensorflow//tensorflow/compiler/mlir/lite:tensorflow_lite", "@org_tensorflow//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "@org_tensorflow//tensorflow/compiler/mlir/lite/python:tf_tfl_flatbuffer_helpers", - "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:export_graphdef", "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:import_utils", "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "@org_tensorflow//tensorflow/compiler/mlir/tensorflow:tf_dialect_passes", diff --git a/larq_compute_engine/mlir/ir/lce_ops.cc b/larq_compute_engine/mlir/ir/lce_ops.cc index f9cc1d1ff..1dc143c5d 100644 --- a/larq_compute_engine/mlir/ir/lce_ops.cc +++ b/larq_compute_engine/mlir/ir/lce_ops.cc @@ -3,8 +3,12 @@ #include "flatbuffers/flexbuffers.h" #include "larq_compute_engine/core/bitpacking/bitpack.h" #include "larq_compute_engine/mlir/transforms/bitpack.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "tensorflow/lite/schema/schema_generated.h" +// Generated dialect defs. +#include "larq_compute_engine/mlir/ir/lce_dialect.cc.inc" + #define GET_OP_CLASSES #include "larq_compute_engine/mlir/ir/lce_enum.cc.inc" #include "larq_compute_engine/mlir/ir/lce_ops.cc.inc" @@ -73,5 +77,12 @@ void LarqDialect::initialize() { #include "larq_compute_engine/mlir/ir/lce_ops.cc.inc" >(); } + +Operation* LarqDialect::materializeConstant(OpBuilder& builder, Attribute value, + Type type, Location loc) { + if (arith::ConstantOp::isBuildableWith(value, type)) + return builder.create(loc, type, value); + return nullptr; +} } // namespace lq } // namespace mlir diff --git a/larq_compute_engine/mlir/ir/lce_ops.td b/larq_compute_engine/mlir/ir/lce_ops.td index c763ae065..9db54dab4 100644 --- a/larq_compute_engine/mlir/ir/lce_ops.td +++ b/larq_compute_engine/mlir/ir/lce_ops.td @@ -46,6 +46,8 @@ def LarqDialect : Dialect { TF graphs to be deployed on Larq Compute Engine. 
}]; + let hasConstantMaterializer = 1; + let cppNamespace = "::mlir::lq"; } diff --git a/larq_compute_engine/mlir/lce_mlir_opt.cc b/larq_compute_engine/mlir/lce_mlir_opt.cc index 367805181..a1ef47e25 100644 --- a/larq_compute_engine/mlir/lce_mlir_opt.cc +++ b/larq_compute_engine/mlir/lce_mlir_opt.cc @@ -9,9 +9,9 @@ int main(int argc, char** argv) { mlir::registerTransformsPasses(); mlir::DialectRegistry registry; - registry.insert(); + registry.insert(); return failed(mlir::MlirOptMain(argc, argv, "Larq Compute Engine pass driver\n", registry, /*preloadDialectsInContext=*/false)); diff --git a/larq_compute_engine/mlir/python/common.cc b/larq_compute_engine/mlir/python/common.cc index 0934e652f..9f3c2548b 100644 --- a/larq_compute_engine/mlir/python/common.cc +++ b/larq_compute_engine/mlir/python/common.cc @@ -18,14 +18,9 @@ limitations under the License. #include -#include "larq_compute_engine/mlir/tf_tfl_passes.h" #include "larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h" -#include "larq_compute_engine/mlir/transforms/passes.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/Pass/Pass.h" #include "pybind11/pybind11.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/status.h" @@ -77,8 +72,10 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs) { pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer( mlir::OwningModuleRef* module, mlir::MLIRContext& context, const LCETarget target, const pybind11::object& default_ranges, - const int num_inputs, const bool should_quantize, - const bool mark_as_post_training_quant) { + const std::unordered_set& saved_model_tags, + llvm::StringRef saved_model_dir, + llvm::Optional session, const int num_inputs, + const bool should_quantize, const bool mark_as_post_training_quant) { mlir::TFL::QuantizationSpecs quant_specs; if (should_quantize) { // Normally 
we'd only set `inference_type` to QINT8 when there are @@ -118,18 +115,10 @@ pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer( } } - mlir::PassManager pm(&context, mlir::OpPassManager::Nesting::Implicit); - tensorflow::SetCrashReproducer(pm); - - tensorflow::AddTFToLCETFLConversionPasses(quant_specs, &pm, target); - - // Convert back to outlined while format for export back to flatbuffer. - pm.addPass(mlir::TFL::CreateWhileOutlinePass()); - pm.addPass(mlir::TFL::CreateRuntimeVerifyPass()); - std::string result; - auto status = ConvertTFExecutorToFlatbuffer( - module->get(), /*export_to_mlir=*/false, &result, &pm); + auto status = ConvertTFExecutorToTFLOrFlatbuffer( + module->get(), /*export_to_mlir=*/false, target, quant_specs, + saved_model_tags, saved_model_dir, session, &result); if (!status.ok()) { throw std::runtime_error("Could not translate to flatbuffer."); diff --git a/larq_compute_engine/mlir/python/common.h b/larq_compute_engine/mlir/python/common.h index 19403d14f..72e661191 100644 --- a/larq_compute_engine/mlir/python/common.h +++ b/larq_compute_engine/mlir/python/common.h @@ -3,6 +3,7 @@ #include "mlir/Pass/Pass.h" #include "pybind11/pybind11.h" #include "tensorflow/core/platform/status.h" +#include "tensorflow/core/public/session.h" namespace tensorflow { @@ -13,7 +14,9 @@ Status GetNumInputs(mlir::OwningModuleRef* module, int* num_inputs); pybind11::bytes ConvertMLIRModuleToTFLiteFlatBuffer( mlir::OwningModuleRef* module, mlir::MLIRContext& context, const LCETarget target, const pybind11::object& default_ranges, - const int num_inputs, const bool should_quantize, - const bool mark_as_post_training_quant); + const std::unordered_set& saved_model_tags, + llvm::StringRef saved_model_dir, + llvm::Optional session, const int num_inputs, + const bool should_quantize, const bool mark_as_post_training_quant); } // namespace tensorflow diff --git a/larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc 
b/larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc index ec2427ca9..d630aabe7 100644 --- a/larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc +++ b/larq_compute_engine/mlir/python/graphdef_tfl_flatbuffer.cc @@ -64,7 +64,9 @@ pybind11::bytes ConvertGraphDefToTFLiteFlatBuffer( return ConvertMLIRModuleToTFLiteFlatBuffer( &module.ValueOrDie(), context, target, default_ranges, - input_arrays.size(), should_quantize, + /*saved_model_tags=*/{}, + /*saved_model_dir=*/"", /*session=*/llvm::None, input_arrays.size(), + should_quantize, /*mark_as_post_training_quant=*/false); } diff --git a/larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc b/larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc index 0e36eb197..2247bd3c7 100644 --- a/larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc +++ b/larq_compute_engine/mlir/python/saved_model_tfl_flatbuffer.cc @@ -42,8 +42,8 @@ pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer( auto target = GetLCETarget(target_str); - if (exported_names.size() != 1) { - throw std::runtime_error("Only a single exported name is supported."); + if (exported_names.empty()) { + throw std::runtime_error("Need at least one exported name."); } tensorflow::GraphImportConfig specs; @@ -84,7 +84,8 @@ pybind11::bytes ConvertSavedModelToTFLiteFlatBuffer( } return ConvertMLIRModuleToTFLiteFlatBuffer( - &module.ValueOrDie(), context, target, default_ranges, num_inputs, + &module.ValueOrDie(), context, target, default_ranges, tags, + saved_model_dir, bundle ? 
bundle->GetSession() : nullptr, num_inputs, /*should_quantize=*/true, /*mark_as_post_training_quant=*/true); } diff --git a/larq_compute_engine/mlir/tests/bitpack-weights.mlir b/larq_compute_engine/mlir/tests/bitpack-weights.mlir index 9a184be0d..86fbdd696 100644 --- a/larq_compute_engine/mlir/tests/bitpack-weights.mlir +++ b/larq_compute_engine/mlir/tests/bitpack-weights.mlir @@ -2,11 +2,11 @@ // CHECK-LABEL: @bitpack_bconv2d_filters func @bitpack_bconv2d_filters(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> { - %cst = constant dense<1.0> : tensor<16x3x3x3xf32> + %cst = arith.constant dense<1.0> : tensor<16x3x3x3xf32> %0 = "lq.Bconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> return %0 : tensor<256x30x30x16xf32> - // CHECK: %cst = constant dense<0> : tensor<16x3x3x1xi32> + // CHECK: %cst = arith.constant dense<0> : tensor<16x3x3x1xi32> // CHECK: %0 = "lq.Bconv2d"(%arg0, %cst, %arg1, %arg2, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x1xi32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> // CHECK-NEXT: return %0 } diff --git a/larq_compute_engine/mlir/tests/const-fold.mlir b/larq_compute_engine/mlir/tests/const-fold.mlir index 08bee5f4d..034f51360 100644 --- a/larq_compute_engine/mlir/tests/const-fold.mlir +++ b/larq_compute_engine/mlir/tests/const-fold.mlir @@ -2,26 +2,26 @@ // CHECK-LABEL: @quantize func @quantize() -> (tensor<1x1x2x1xi32>, 
tensor<1x1x2x1xi32>) { - %pos = constant dense< 0.5> : tensor<1x1x2x32xf32> - %neg = constant dense<-0.5> : tensor<1x1x2x32xf32> + %pos = arith.constant dense< 0.5> : tensor<1x1x2x32xf32> + %neg = arith.constant dense<-0.5> : tensor<1x1x2x32xf32> %0 = "lq.Quantize"(%pos) {} : (tensor<1x1x2x32xf32>) -> tensor<1x1x2x1xi32> %1 = "lq.Quantize"(%neg) {} : (tensor<1x1x2x32xf32>) -> tensor<1x1x2x1xi32> return %0, %1 : tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32> - // CHECK: %[[neg:.*]] = constant dense<-1> : tensor<1x1x2x1xi32> - // CHECK: %[[pos:.*]] = constant dense<0> : tensor<1x1x2x1xi32> + // CHECK: %[[neg:.*]] = arith.constant dense<-1> : tensor<1x1x2x1xi32> + // CHECK: %[[pos:.*]] = arith.constant dense<0> : tensor<1x1x2x1xi32> // CHECK: return %[[pos]], %[[neg]] : tensor<1x1x2x1xi32>, tensor<1x1x2x1xi32> } // CHECK-LABEL: @dequantize func @dequantize() -> (tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32>) { - %pos = constant dense<0> : tensor<1x1x2x1xi32> - %neg = constant dense<-1> : tensor<1x1x2x1xi32> + %pos = arith.constant dense<0> : tensor<1x1x2x1xi32> + %neg = arith.constant dense<-1> : tensor<1x1x2x1xi32> %0 = "lq.Dequantize"(%pos) {} : (tensor<1x1x2x1xi32>) -> tensor<1x1x2x32xf32> %1 = "lq.Dequantize"(%neg) {} : (tensor<1x1x2x1xi32>) -> tensor<1x1x2x32xf32> return %0, %1 : tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32> - // CHECK: %[[neg:.*]] = constant dense<-1.000000e+00> : tensor<1x1x2x32xf32> - // CHECK: %[[pos:.*]] = constant dense<1.000000e+00> : tensor<1x1x2x32xf32> + // CHECK: %[[neg:.*]] = arith.constant dense<-1.000000e+00> : tensor<1x1x2x32xf32> + // CHECK: %[[pos:.*]] = arith.constant dense<1.000000e+00> : tensor<1x1x2x32xf32> // CHECK: return %[[pos]], %[[neg]] : tensor<1x1x2x32xf32>, tensor<1x1x2x32xf32> } diff --git a/larq_compute_engine/mlir/tests/fuse_padding.mlir b/larq_compute_engine/mlir/tests/fuse_padding.mlir index 26a573ea2..10e9544d2 100644 --- a/larq_compute_engine/mlir/tests/fuse_padding.mlir +++ 
b/larq_compute_engine/mlir/tests/fuse_padding.mlir @@ -2,9 +2,9 @@ // CHECK-LABEL: @fuse_pad_into_conv_valid func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> { - %cst0 = constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> - %cst1 = constant dense<1.0> : tensor<16x3x3x8xf32> - %cst2 = constant dense<1.0> : tensor<16xf32> + %cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> + %cst1 = arith.constant dense<1.0> : tensor<16x3x3x8xf32> + %cst2 = arith.constant dense<1.0> : tensor<16xf32> %0 = "tfl.pad"(%arg0, %cst0) : (tensor<1x64x64x8xf32>, tensor<4x2xi32>) -> tensor<1x66x66x8xf32> %1 = "tfl.conv_2d"(%0, %cst1, %cst2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x8xf32>, tensor<16x3x3x8xf32>, tensor<16xf32>) -> tensor<1x64x64x16xf32> return %1 : tensor<1x64x64x16xf32> @@ -15,10 +15,10 @@ func @fuse_pad_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x1 // CHECK-LABEL: @fuse_padv2_into_conv_valid func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> { - %cst0 = constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> - %cst1 = constant dense<0.0> : tensor - %cst2 = constant dense<1.0> : tensor<16x3x3x8xf32> - %cst3 = constant dense<1.0> : tensor<16xf32> + %cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> + %cst1 = arith.constant dense<0.0> : tensor + %cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32> + %cst3 = arith.constant dense<1.0> : tensor<16xf32> %0 = "tfl.padv2"(%arg0, %cst0, %cst1) : (tensor<1x64x64x8xf32>, tensor<4x2xi32>, tensor) -> tensor<1x66x66x8xf32> %1 = "tfl.conv_2d"(%0, %cst2, %cst3) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : 
(tensor<1x66x66x8xf32>, tensor<16x3x3x8xf32>, tensor<16xf32>) -> tensor<1x64x64x16xf32> return %1 : tensor<1x64x64x16xf32> @@ -29,9 +29,9 @@ func @fuse_padv2_into_conv_valid(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64 // CHECK-LABEL: @fuse_pad_into_dwconv_valid func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x64x16xf32> { - %cst0 = constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> - %cst1 = constant dense<1.0> : tensor<1x3x3x16xf32> - %cst2 = constant dense<1.0> : tensor<16xf32> + %cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> + %cst1 = arith.constant dense<1.0> : tensor<1x3x3x16xf32> + %cst2 = arith.constant dense<1.0> : tensor<16xf32> %0 = "tfl.pad"(%arg0, %cst0) : (tensor<1x64x64x16xf32>, tensor<4x2xi32>) -> tensor<1x66x66x16xf32> %1 = "tfl.depthwise_conv_2d"(%0, %cst1, %cst2) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x16xf32>, tensor<1x3x3x16xf32>, tensor<16xf32>) -> tensor<1x64x64x16xf32> return %1 : tensor<1x64x64x16xf32> @@ -42,10 +42,10 @@ func @fuse_pad_into_dwconv_valid(%arg0: tensor<1x64x64x16xf32>) -> tensor<1x64x6 // CHECK-LABEL: @do_not_fuse_padv2_into_conv_wrong_pad_value func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x64x64x16xf32> { - %cst0 = constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> - %cst1 = constant dense<1.0> : tensor - %cst2 = constant dense<1.0> : tensor<16x3x3x8xf32> - %cst3 = constant dense<1.0> : tensor<16xf32> + %cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> + %cst1 = arith.constant dense<1.0> : tensor + %cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32> + %cst3 = arith.constant dense<1.0> : tensor<16xf32> %0 = "tfl.padv2"(%arg0, %cst0, %cst1) : (tensor<1x64x64x8xf32>, tensor<4x2xi32>, tensor) 
-> tensor<1x66x66x8xf32> %1 = "tfl.conv_2d"(%0, %cst2, %cst3) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x8xf32>, tensor<16x3x3x8xf32>, tensor<16xf32>) -> tensor<1x64x64x16xf32> return %1 : tensor<1x64x64x16xf32> @@ -55,10 +55,10 @@ func @do_not_fuse_padv2_into_conv_wrong_pad_value(%arg0: tensor<1x64x64x8xf32>) // CHECK-LABEL: @do_not_fuse_pad_into_conv_same func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x66x66x16xf32> { - %cst0 = constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> - %cst1 = constant dense<1.0> : tensor - %cst2 = constant dense<1.0> : tensor<16x3x3x8xf32> - %cst3 = constant dense<1.0> : tensor<16xf32> + %cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> + %cst1 = arith.constant dense<1.0> : tensor + %cst2 = arith.constant dense<1.0> : tensor<16x3x3x8xf32> + %cst3 = arith.constant dense<1.0> : tensor<16xf32> %0 = "tfl.padv2"(%arg0, %cst0, %cst1) : (tensor<1x64x64x8xf32>, tensor<4x2xi32>, tensor) -> tensor<1x66x66x8xf32> %1 = "tfl.conv_2d"(%0, %cst2, %cst3) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x8xf32>, tensor<16x3x3x8xf32>, tensor<16xf32>) -> tensor<1x66x66x16xf32> return %1 : tensor<1x66x66x16xf32> @@ -68,9 +68,9 @@ func @do_not_fuse_pad_into_conv_same(%arg0: tensor<1x64x64x8xf32>) -> tensor<1x6 // CHECK-LABEL: @do_not_fuse_pad_into_dwconv_channelpad func @do_not_fuse_pad_into_dwconv_channelpad(%arg0: tensor<1x64x64x12xf32>) -> tensor<1x64x64x16xf32> { - %cst0 = constant dense<[[0, 0], [1, 1], [1, 1], [1, 3]]> : tensor<4x2xi32> - %cst1 = constant dense<1.0> : tensor<1x3x3x16xf32> - %cst2 = constant dense<1.0> : tensor<16xf32> + %cst0 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [1, 3]]> : tensor<4x2xi32> + %cst1 = 
arith.constant dense<1.0> : tensor<1x3x3x16xf32> + %cst2 = arith.constant dense<1.0> : tensor<16xf32> %0 = "tfl.pad"(%arg0, %cst0) : (tensor<1x64x64x12xf32>, tensor<4x2xi32>) -> tensor<1x66x66x16xf32> %1 = "tfl.depthwise_conv_2d"(%0, %cst1, %cst2) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x66x66x16xf32>, tensor<1x3x3x16xf32>, tensor<16xf32>) -> tensor<1x64x64x16xf32> return %1 : tensor<1x64x64x16xf32> diff --git a/larq_compute_engine/mlir/tests/optimize.mlir b/larq_compute_engine/mlir/tests/optimize.mlir index 1bc967335..a5b0c3e8c 100644 --- a/larq_compute_engine/mlir/tests/optimize.mlir +++ b/larq_compute_engine/mlir/tests/optimize.mlir @@ -3,7 +3,7 @@ // CHECK-LABEL: @optimize_quantize_greater_equal_zero func @optimize_quantize_greater_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> { - %cst = constant dense<0.0> : tensor + %cst = arith.constant dense<0.0> : tensor %0 = "tfl.greater_equal"(%arg0, %cst) : (tensor<48x48x64xf32>, tensor) -> tensor<48x48x64xi1> %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32> return %1 : tensor<48x48x2xi32> @@ -25,7 +25,7 @@ func @optimize_quantize_greater_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg // CHECK-LABEL: @optimize_quantize_less_equal_zero func @optimize_quantize_less_equal_zero(%arg0: tensor<48x48x64xf32>) -> tensor<48x48x2xi32> { - %cst = constant dense<0.0> : tensor<64xf32> + %cst = arith.constant dense<0.0> : tensor<64xf32> %0 = "tfl.less_equal"(%cst, %arg0) : (tensor<64xf32>, tensor<48x48x64xf32>) -> tensor<48x48x64xi1> %1 = "lq.Quantize"(%0) : (tensor<48x48x64xi1>) -> tensor<48x48x2xi32> return %1 : tensor<48x48x2xi32> @@ -47,64 +47,64 @@ func @optimize_quantize_less_equal_non_zero(%arg0: tensor<48x48x64xf32>, %arg1: // CHECK-LABEL: @fuse_add_into_bconv2d func @fuse_add_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: 
tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> { - %cst = constant dense<1.5> : tensor<16xf32> - %post_activation_bias = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst = arith.constant dense<1.5> : tensor<16xf32> + %post_activation_bias = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %post_activation_bias, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> %1 = "tfl.add"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> - // CHECK-NEXT: %cst = constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> + // CHECK-NEXT: %cst = arith.constant dense<[2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01, 1.650000e+01, 1.750000e+01]> // CHECK-NEXT: %0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %cst, %arg3) // CHECK-NEXT: return %0 } // CHECK-LABEL: @fuse_sub_into_bconv2d func @fuse_sub_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>, %arg3: none) -> tensor<256x30x30x16xf32> { - %cst = constant dense<0.5> : tensor<16xf32> - %post_activation_bias = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 
11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst = arith.constant dense<0.5> : tensor<16xf32> + %post_activation_bias = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %post_activation_bias, %arg3) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> %1 = "tfl.sub"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> - // CHECK-NEXT: %cst = constant dense<[5.000000e-01, 1.500000e+00, 2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01]> + // CHECK-NEXT: %cst = arith.constant dense<[5.000000e-01, 1.500000e+00, 2.500000e+00, 3.500000e+00, 4.500000e+00, 5.500000e+00, 6.500000e+00, 7.500000e+00, 8.500000e+00, 9.500000e+00, 1.050000e+01, 1.150000e+01, 1.250000e+01, 1.350000e+01, 1.450000e+01, 1.550000e+01]> // CHECK-NEXT: %0 = "lq.Bconv2d"(%arg0, %arg1, %arg2, %cst, %arg3) // CHECK-NEXT: return %0 } // CHECK-LABEL: @fuse_div_into_bconv2d func @fuse_div_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: none) -> tensor<256x30x30x16xf32> { - %cst = constant dense<0.5> : tensor<16xf32> - %post_activation_bias = constant dense<1.5> : tensor<16xf32> - %post_activation_multiplier = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst = arith.constant dense<0.5> : tensor<16xf32> + %post_activation_bias = arith.constant dense<1.5> : tensor<16xf32> + 
%post_activation_multiplier = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> %1 = "tfl.div"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> - // CHECK-NEXT: %cst = constant dense<[2.000000e+00, 4.000000e+00, 6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01, 1.400000e+01, 1.600000e+01, 1.800000e+01, 2.000000e+01, 2.200000e+01, 2.400000e+01, 2.600000e+01, 2.800000e+01, 3.000000e+01, 3.200000e+01]> - // CHECK-NEXT: %cst_0 = constant dense<3.000000e+00> + // CHECK-NEXT: %cst = arith.constant dense<[2.000000e+00, 4.000000e+00, 6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01, 1.400000e+01, 1.600000e+01, 1.800000e+01, 2.000000e+01, 2.200000e+01, 2.400000e+01, 2.600000e+01, 2.800000e+01, 3.000000e+01, 3.200000e+01]> + // CHECK-NEXT: %cst_0 = arith.constant dense<3.000000e+00> // CHECK-NEXT: %0 = "lq.Bconv2d"(%arg0, %arg1, %cst, %cst_0, %arg2) // CHECK-NEXT: return %0 } // CHECK-LABEL: @fuse_mul_into_bconv2d func @fuse_mul_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: none) -> tensor<256x30x30x16xf32> { - %cst = constant dense<2.0> : tensor<16xf32> - %post_activation_bias = constant dense<1.5> : tensor<16xf32> - %post_activation_multiplier = constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> + %cst = arith.constant dense<2.0> : tensor<16xf32> + %post_activation_bias = 
arith.constant dense<1.5> : tensor<16xf32> + %post_activation_multiplier = arith.constant dense<[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> %1 = "tfl.mul"(%0, %cst) {fused_activation_function = "NONE"} : (tensor<256x30x30x16xf32>, tensor<16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> - // CHECK-NEXT: %cst = constant dense<[2.000000e+00, 4.000000e+00, 6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01, 1.400000e+01, 1.600000e+01, 1.800000e+01, 2.000000e+01, 2.200000e+01, 2.400000e+01, 2.600000e+01, 2.800000e+01, 3.000000e+01, 3.200000e+01]> - // CHECK-NEXT: %cst_0 = constant dense<3.000000e+00> + // CHECK-NEXT: %cst = arith.constant dense<[2.000000e+00, 4.000000e+00, 6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01, 1.400000e+01, 1.600000e+01, 1.800000e+01, 2.000000e+01, 2.200000e+01, 2.400000e+01, 2.600000e+01, 2.800000e+01, 3.000000e+01, 3.200000e+01]> + // CHECK-NEXT: %cst_0 = arith.constant dense<3.000000e+00> // CHECK-NEXT: %0 = "lq.Bconv2d"(%arg0, %arg1, %cst, %cst_0, %arg2) // CHECK-NEXT: return %0 } // CHECK-LABEL: @fuse_relu_into_bconv2d func @fuse_relu_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: none) -> tensor<256x30x30x16xf32> { - %post_activation_multiplier = constant dense<1.0> : tensor<16xf32> - %post_activation_bias = constant dense<0.0> : tensor<16xf32> + %post_activation_multiplier = arith.constant dense<1.0> : tensor<16xf32> + %post_activation_bias = arith.constant dense<0.0> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, 
%arg1, %post_activation_multiplier, %post_activation_bias, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> @@ -115,8 +115,8 @@ func @fuse_relu_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x // CHECK-LABEL: @fuse_relu6_into_bconv2d func @fuse_relu6_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: none) -> tensor<256x30x30x16xf32> { - %post_activation_multiplier = constant dense<1.0> : tensor<16xf32> - %post_activation_bias = constant dense<0.0> : tensor<16xf32> + %post_activation_multiplier = arith.constant dense<1.0> : tensor<16xf32> + %post_activation_bias = arith.constant dense<0.0> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> %1 = "tfl.relu6"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> @@ -127,8 +127,8 @@ func @fuse_relu6_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3 // CHECK-LABEL: @fuse_relu1_into_bconv2d func @fuse_relu1_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: none) -> tensor<256x30x30x16xf32> { - %post_activation_multiplier = constant dense<1.0> : tensor<16xf32> - %post_activation_bias = constant dense<0.0> : 
tensor<16xf32> + %post_activation_multiplier = arith.constant dense<1.0> : tensor<16xf32> + %post_activation_bias = arith.constant dense<0.0> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> %1 = "tfl.relu_n1_to_1"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> @@ -139,8 +139,8 @@ func @fuse_relu1_into_bconv2d(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3 // CHECK-LABEL: @fuse_relu_into_bconv2d_padding_same_one func @fuse_relu_into_bconv2d_padding_same_one(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: none) -> tensor<256x32x32x16xf32> { - %post_activation_multiplier = constant dense<1.0> : tensor<16xf32> - %post_activation_bias = constant dense<0.0> : tensor<16xf32> + %post_activation_multiplier = arith.constant dense<1.0> : tensor<16xf32> + %post_activation_bias = arith.constant dense<0.0> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x32x32x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x32x32x16xf32>) -> tensor<256x32x32x16xf32> return %1 : tensor<256x32x32x16xf32> @@ -151,8 +151,8 @@ func @fuse_relu_into_bconv2d_padding_same_one(%arg0: tensor<256x32x32x1xi32>, %a // CHECK-LABEL: @do_not_fuse_relu_into_bconv2d_padding_same_zero func 
@do_not_fuse_relu_into_bconv2d_padding_same_zero(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: none) -> tensor<256x32x32x16xf32> { - %post_activation_multiplier = constant dense<1.0> : tensor<16xf32> - %post_activation_bias = constant dense<0.0> : tensor<16xf32> + %post_activation_multiplier = arith.constant dense<1.0> : tensor<16xf32> + %post_activation_bias = arith.constant dense<0.0> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x32x32x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x32x32x16xf32>) -> tensor<256x32x32x16xf32> return %1 : tensor<256x32x32x16xf32> @@ -164,8 +164,8 @@ func @do_not_fuse_relu_into_bconv2d_padding_same_zero(%arg0: tensor<256x32x32x1x // CHECK-LABEL: @do_not_fuse_relu_into_bconv2d_no_post_activation_bias func @do_not_fuse_relu_into_bconv2d_no_post_activation_bias(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: none) -> tensor<256x30x30x16xf32> { - %post_activation_multiplier = constant dense<1.0> : tensor<16xf32> - %post_activation_bias = constant dense<5.0> : tensor<16xf32> + %post_activation_multiplier = arith.constant dense<1.0> : tensor<16xf32> + %post_activation_bias = arith.constant dense<5.0> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> 
tensor<256x30x30x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> @@ -177,8 +177,8 @@ func @do_not_fuse_relu_into_bconv2d_no_post_activation_bias(%arg0: tensor<256x32 // CHECK-LABEL: @do_not_fuse_relu_into_bconv2d_no_post_activation_multiplier func @do_not_fuse_relu_into_bconv2d_no_post_activation_multiplier(%arg0: tensor<256x32x32x1xi32>, %arg1: tensor<16x3x3x3xf32>, %arg2: none) -> tensor<256x30x30x16xf32> { - %post_activation_multiplier = constant dense<0.8> : tensor<16xf32> - %post_activation_bias = constant dense<0.0> : tensor<16xf32> + %post_activation_multiplier = arith.constant dense<0.8> : tensor<16xf32> + %post_activation_bias = arith.constant dense<0.0> : tensor<16xf32> %0 = "lq.Bconv2d"(%arg0, %arg1, %post_activation_multiplier, %post_activation_bias, %arg2) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x30x30x16xf32> %1 = "tfl.relu"(%0) : (tensor<256x30x30x16xf32>) -> tensor<256x30x30x16xf32> return %1 : tensor<256x30x30x16xf32> @@ -216,9 +216,9 @@ func @do_not_reorder_maxpool_2d_quantize_multiple_uses(%arg0: tensor<256x32x32x6 // CHECK-LABEL: @bitpack_activation_thresholds_with_negative_post_multipliers func @bitpack_activation_thresholds_with_negative_post_multipliers(%arg0: tensor<256x32x32x1xi32>) -> tensor<256x32x32x1xi32> { - %filter = constant dense<1.0> : tensor<8x2x2x1xf32> - %post_activation_multiplier = constant dense<[-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]> : tensor<8xf32> - %post_activation_bias = constant dense<[-10.0, 8.0, 0.4, 1.0, -0.01, 0.5, -1.0, 2.71]> : tensor<8xf32> + %filter = arith.constant dense<1.0> : tensor<8x2x2x1xf32> + %post_activation_multiplier = arith.constant dense<[-4.0, 
-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]> : tensor<8xf32> + %post_activation_bias = arith.constant dense<[-10.0, 8.0, 0.4, 1.0, -0.01, 0.5, -1.0, 2.71]> : tensor<8xf32> %cst = constant unit %0 = "lq.Bconv2d"(%arg0, %filter, %post_activation_multiplier, %post_activation_bias, %cst) {channels_in = 1 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<8x2x2x1xf32>, tensor<8xf32>, tensor<8xf32>, none) -> tensor<256x32x32x8xf32> %1 = "lq.Quantize"(%0) : (tensor<256x32x32x8xf32>) -> tensor<256x32x32x1xi32> @@ -228,13 +228,13 @@ func @bitpack_activation_thresholds_with_negative_post_multipliers(%arg0: tensor // have had their signs flipped. The syntax is very ugly because filecheck // tries to perform string substitution when it encounters `[[` and so we have // to wrap the whole thing in a regex, and escape the square brackets. 
- // CHECK: %cst = constant {{dense<\[\[\[\[-1.000000e\+00\], \[-1.000000e\+00\]\], \[\[-1.000000e\+00\], \[-1.000000e\+00\]\]\], \[\[\[-1.000000e\+00\], \[-1.000000e\+00\]\], \[\[-1.000000e\+00\], \[-1.000000e\+00\]\]\], \[\[\[-1.000000e\+00\], \[-1.000000e\+00\]\], \[\[-1.000000e\+00\], \[-1.000000e\+00\]\]\], \[\[\[-1.000000e\+00\], \[-1.000000e\+00\]\], \[\[-1.000000e\+00\], \[-1.000000e\+00\]\]\], \[\[\[1.000000e\+00\], \[1.000000e\+00\]\], \[\[1.000000e\+00\], \[1.000000e\+00\]\]\], \[\[\[1.000000e\+00\], \[1.000000e\+00\]\], \[\[1.000000e\+00\], \[1.000000e\+00\]\]\], \[\[\[1.000000e\+00\], \[1.000000e\+00\]\], \[\[1.000000e\+00\], \[1.000000e\+00\]\]\], \[\[\[1.000000e\+00\], \[1.000000e\+00\]\], \[\[1.000000e\+00\], \[1.000000e\+00\]\]\]\]>}} : tensor<8x2x2x1xf32> + // CHECK: %cst = arith.constant {{dense<\[\[\[\[-1.000000e\+00\], \[-1.000000e\+00\]\], \[\[-1.000000e\+00\], \[-1.000000e\+00\]\]\], \[\[\[-1.000000e\+00\], \[-1.000000e\+00\]\], \[\[-1.000000e\+00\], \[-1.000000e\+00\]\]\], \[\[\[-1.000000e\+00\], \[-1.000000e\+00\]\], \[\[-1.000000e\+00\], \[-1.000000e\+00\]\]\], \[\[\[-1.000000e\+00\], \[-1.000000e\+00\]\], \[\[-1.000000e\+00\], \[-1.000000e\+00\]\]\], \[\[\[1.000000e\+00\], \[1.000000e\+00\]\], \[\[1.000000e\+00\], \[1.000000e\+00\]\]\], \[\[\[1.000000e\+00\], \[1.000000e\+00\]\], \[\[1.000000e\+00\], \[1.000000e\+00\]\]\], \[\[\[1.000000e\+00\], \[1.000000e\+00\]\], \[\[1.000000e\+00\], \[1.000000e\+00\]\]\], \[\[\[1.000000e\+00\], \[1.000000e\+00\]\], \[\[1.000000e\+00\], \[1.000000e\+00\]\]\]\]>}} : tensor<8x2x2x1xf32> // The `none` value that will replace the post-multiplier and post-bias. // CHECK-NEXT: %cst_0 = constant unit // Verify correct thresholds. These have been manually computed. 
- // CHECK-NEXT: %cst_1 = constant dense<[0, 3, 2, 2, -2147483648, 2, 1, 2]> : tensor<8xi32> + // CHECK-NEXT: %cst_1 = arith.constant dense<[0, 3, 2, 2, -2147483648, 2, 1, 2]> : tensor<8xi32> // CHECK-NEXT: %0 = "lq.Bconv2d"(%arg0, %cst, %cst_0, %cst_0, %cst_1) {channels_in = 1 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<8x2x2x1xf32>, none, none, tensor<8xi32>) -> tensor<256x32x32x1xi32> // CHECK-NEXT: return %0 @@ -242,9 +242,9 @@ func @bitpack_activation_thresholds_with_negative_post_multipliers(%arg0: tensor // CHECK-LABEL: @bitpack_activations_valid_padding func @bitpack_activations_valid_padding(%arg0: tensor<256x32x32x1xi32>) -> tensor<256x30x30x3xi32> { - %filter = constant dense<1.0> : tensor<65x3x3x3xf32> - %post_activation_multiplier = constant dense<0.5> : tensor<65xf32> - %post_activation_bias = constant dense<-1.0> : tensor<65xf32> + %filter = arith.constant dense<1.0> : tensor<65x3x3x3xf32> + %post_activation_multiplier = arith.constant dense<0.5> : tensor<65xf32> + %post_activation_bias = arith.constant dense<-1.0> : tensor<65xf32> %cst = constant unit %0 = "lq.Bconv2d"(%arg0, %filter, %post_activation_multiplier, %post_activation_bias, %cst) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>, none) -> tensor<256x30x30x65xf32> %1 = "lq.Quantize"(%0) : (tensor<256x30x30x65xf32>) -> tensor<256x30x30x3xi32> @@ -256,9 +256,9 @@ func @bitpack_activations_valid_padding(%arg0: tensor<256x32x32x1xi32>) -> tenso // CHECK-LABEL: @bitpack_activations_same_one_padding func @bitpack_activations_same_one_padding(%arg0: 
tensor<256x32x32x1xi32>) -> tensor<256x32x32x3xi32> { - %filter = constant dense<1.0> : tensor<65x3x3x3xf32> - %post_activation_multiplier = constant dense<0.5> : tensor<65xf32> - %post_activation_bias = constant dense<-1.0> : tensor<65xf32> + %filter = arith.constant dense<1.0> : tensor<65x3x3x3xf32> + %post_activation_multiplier = arith.constant dense<0.5> : tensor<65xf32> + %post_activation_bias = arith.constant dense<-1.0> : tensor<65xf32> %cst = constant unit %0 = "lq.Bconv2d"(%arg0, %filter, %post_activation_multiplier, %post_activation_bias, %cst) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>, none) -> tensor<256x32x32x65xf32> %1 = "lq.Quantize"(%0) : (tensor<256x32x32x65xf32>) -> tensor<256x32x32x3xi32> @@ -270,9 +270,9 @@ func @bitpack_activations_same_one_padding(%arg0: tensor<256x32x32x1xi32>) -> te // CHECK-LABEL: @do_not_bitpack_activations_same_zero_padding func @do_not_bitpack_activations_same_zero_padding(%arg0: tensor<256x32x32x1xi32>) -> tensor<256x32x32x3xi32> { - %filter = constant dense<1.0> : tensor<65x3x3x3xf32> - %post_activation_multiplier = constant dense<0.5> : tensor<65xf32> - %post_activation_bias = constant dense<-1.0> : tensor<65xf32> + %filter = arith.constant dense<1.0> : tensor<65x3x3x3xf32> + %post_activation_multiplier = arith.constant dense<0.5> : tensor<65xf32> + %post_activation_bias = arith.constant dense<-1.0> : tensor<65xf32> %cst = constant unit %0 = "lq.Bconv2d"(%arg0, %filter, %post_activation_multiplier, %post_activation_bias, %cst) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, 
tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>, none) -> tensor<256x32x32x65xf32> %1 = "lq.Quantize"(%0) : (tensor<256x32x32x65xf32>) -> tensor<256x32x32x3xi32> @@ -285,9 +285,9 @@ func @do_not_bitpack_activations_same_zero_padding(%arg0: tensor<256x32x32x1xi32 // CHECK-LABEL: @do_not_bitpack_activations_multiple_uses func @do_not_bitpack_activations_multiple_uses(%arg0: tensor<256x32x32x1xi32>) -> (tensor<256x30x30x65xf32>, tensor<256x30x30x3xi32>) { - %filter = constant dense<1.0> : tensor<65x3x3x3xf32> - %post_activation_multiplier = constant dense<0.5> : tensor<65xf32> - %post_activation_bias = constant dense<-1.0> : tensor<65xf32> + %filter = arith.constant dense<1.0> : tensor<65x3x3x3xf32> + %post_activation_multiplier = arith.constant dense<0.5> : tensor<65xf32> + %post_activation_bias = arith.constant dense<-1.0> : tensor<65xf32> %cst = constant unit %0 = "lq.Bconv2d"(%arg0, %filter, %post_activation_multiplier, %post_activation_bias, %cst) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<256x32x32x1xi32>, tensor<65x3x3x3xf32>, tensor<65xf32>, tensor<65xf32>, none) -> tensor<256x30x30x65xf32> %1 = "lq.Quantize"(%0) : (tensor<256x30x30x65xf32>) -> tensor<256x30x30x3xi32> diff --git a/larq_compute_engine/mlir/tests/prepare-tf.mlir b/larq_compute_engine/mlir/tests/prepare-tf.mlir index aaeb4d5c4..69490cef6 100644 --- a/larq_compute_engine/mlir/tests/prepare-tf.mlir +++ b/larq_compute_engine/mlir/tests/prepare-tf.mlir @@ -3,8 +3,8 @@ // CHECK-LABEL: @fuse_bsign_tf_where func @fuse_bsign_tf_where(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { - %cst_l = constant dense<1.0> : tensor<8x16xf32> - %cst_r = constant dense<-1.0> : tensor<8x16xf32> + %cst_l = arith.constant dense<1.0> : tensor<8x16xf32> + %cst_r = arith.constant dense<-1.0> : tensor<8x16xf32> %0 = "tf.SelectV2"(%arg0, 
%cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> @@ -15,8 +15,8 @@ func @fuse_bsign_tf_where(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { // CHECK-LABEL: @fuse_bsign_tf_where_inverted func @fuse_bsign_tf_where_inverted(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { - %cst_l = constant dense<-1.0> : tensor<8x16xf32> - %cst_r = constant dense<1.0> : tensor<8x16xf32> + %cst_l = arith.constant dense<-1.0> : tensor<8x16xf32> + %cst_r = arith.constant dense<1.0> : tensor<8x16xf32> %0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> @@ -28,12 +28,12 @@ func @fuse_bsign_tf_where_inverted(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { // CHECK-LABEL: @fuse_bsign_tf_where_broadcast_cond func @fuse_bsign_tf_where_broadcast_cond(%arg0: tensor<8x1xi1>) -> tensor<8x16xf32> { - %cst_l = constant dense<1.0> : tensor<8x16xf32> - %cst_r = constant dense<-1.0> : tensor<8x16xf32> + %cst_l = arith.constant dense<1.0> : tensor<8x16xf32> + %cst_r = arith.constant dense<-1.0> : tensor<8x16xf32> %0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x1xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> - // CHECK-NEXT: %cst = constant dense<[8, 16]> : tensor<2xi64> + // CHECK-NEXT: %cst = arith.constant dense<[8, 16]> : tensor<2xi64> // CHECK-NEXT: %0 = "tf.BroadcastTo"(%arg0, %cst) : (tensor<8x1xi1>, tensor<2xi64>) -> tensor<8x16xi1> // CHECK-NEXT: %1 = "lq.Quantize"(%0) : (tensor<8x16xi1>) -> tensor<8x1xi32> // CHECK-NEXT: %2 = "lq.Dequantize"(%1) : (tensor<8x1xi32>) -> tensor<8x16xf32> @@ -42,8 +42,8 @@ func @fuse_bsign_tf_where_broadcast_cond(%arg0: tensor<8x1xi1>) -> tensor<8x16xf // CHECK-LABEL: @fuse_bsign_tf_where_broadcast_lhs_rhs func @fuse_bsign_tf_where_broadcast_lhs_rhs(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { - %cst_l = constant dense<1.0> : tensor - %cst_r = constant 
dense<-1.0> : tensor<8x1xf32> + %cst_l = arith.constant dense<1.0> : tensor + %cst_r = arith.constant dense<-1.0> : tensor<8x1xf32> %0 = "tf.SelectV2"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor, tensor<8x1xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> @@ -54,8 +54,8 @@ func @fuse_bsign_tf_where_broadcast_lhs_rhs(%arg0: tensor<8x16xi1>) -> tensor<8x // CHECK-LABEL: @fuse_bsign_tf_where_select_v1_op func @fuse_bsign_tf_where_select_v1_op(%arg0: tensor<8x16xi1>) -> tensor<8x16xf32> { - %cst_l = constant dense<1.0> : tensor<8x16xf32> - %cst_r = constant dense<-1.0> : tensor<8x16xf32> + %cst_l = arith.constant dense<1.0> : tensor<8x16xf32> + %cst_r = arith.constant dense<-1.0> : tensor<8x16xf32> %0 = "tf.Select"(%arg0, %cst_l, %cst_r) : (tensor<8x16xi1>, tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> return %0 : tensor<8x16xf32> @@ -67,7 +67,7 @@ func @fuse_bsign_tf_where_select_v1_op(%arg0: tensor<8x16xi1>) -> tensor<8x16xf3 // CHECK-LABEL: @fuse_bsign_legacy_tf_sign func @fuse_bsign_legacy_tf_sign(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Sign"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> - %cst = constant dense<0.1> : tensor + %cst = arith.constant dense<0.1> : tensor %2 = "tf.AddV2"(%0, %cst) : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> %3 = "tf.Sign"(%2) : (tensor<8x16xf32>) -> tensor<8x16xf32> return %3 : tensor<8x16xf32> @@ -84,9 +84,9 @@ func @fuse_bconv2d_valid_padding(%arg0: tensor<1x112x112x1xi32>) -> tensor<1x112 %1 = "tf.Conv2D"(%0, %cst) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x112x112x2xf32>, tensor<1x3x2x2xf32>) -> tensor<1x112x110x2xf32> return %1 : tensor<1x112x110x2xf32> - // CHECK: %cst = constant - // CHECK: %[[post_activation_multiplier:.*]] = constant dense<1.000000e+00> : tensor<2xf32> - // CHECK: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<2xf32> + // CHECK: %cst = arith.constant + // CHECK: %[[post_activation_multiplier:.*]] = arith.constant 
dense<1.000000e+00> : tensor<2xf32> + // CHECK: %[[post_activation_bias:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> // CHECK: %[[output_threshold:.*]] = constant unit // CHECK: %[[transpose:.*]] = "tf.Transpose" // CHECK-NEXT: %[[conv:.*]] = "lq.Bconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]], %[[output_threshold:.*]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x1xi32>, tensor<2x1x3x2xf32>, tensor<2xf32>, tensor<2xf32>, none) -> tensor<1x112x110x2xf32> @@ -100,9 +100,9 @@ func @target_specific_fuse_bconv2d_same_zero_padding(%arg0: tensor<1x112x112x1xi %1 = "tf.Conv2D"(%0, %cst) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x112x112x2xf32> return %1 : tensor<1x112x112x2xf32> - // CHECK-ARM: %cst = constant - // CHECK-ARM: %[[post_activation_multiplier:.*]] = constant dense<1.000000e+00> : tensor<2xf32> - // CHECK-ARM: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<2xf32> + // CHECK-ARM: %cst = arith.constant + // CHECK-ARM: %[[post_activation_multiplier:.*]] = arith.constant dense<1.000000e+00> : tensor<2xf32> + // CHECK-ARM: %[[post_activation_bias:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> // CHECK-ARM: %[[output_threshold:.*]] = constant unit // CHECK-ARM: %[[transpose:.*]] = "tf.Transpose" // CHECK-ARM-NEXT: %[[conv:.*]] = "lq.Bconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]], %[[output_threshold:.*]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x1xi32>, tensor<2x1x2x2xf32>, tensor<2xf32>, 
tensor<2xf32>, none) -> tensor<1x112x112x2xf32> @@ -120,9 +120,9 @@ func @fuse_bconv2d_grouped_convolution(%arg0: tensor<1x112x112x4xi32>) -> tensor %1 = "tf.Conv2D"(%0, %cst) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x112x112x128xf32>, tensor<3x3x64x16xf32>) -> tensor<1x110x110x16xf32> return %1 : tensor<1x110x110x16xf32> - // CHECK: %cst = constant - // CHECK: %[[post_activation_multiplier:.*]] = constant dense<1.000000e+00> : tensor<16xf32> - // CHECK: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<16xf32> + // CHECK: %cst = arith.constant + // CHECK: %[[post_activation_multiplier:.*]] = arith.constant dense<1.000000e+00> : tensor<16xf32> + // CHECK: %[[post_activation_bias:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> // CHECK: %[[output_threshold:.*]] = constant unit // CHECK: %[[transpose:.*]] = "tf.Transpose" // CHECK-NEXT: %[[conv:.*]] = "lq.Bconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]], %[[output_threshold:.*]]) {channels_in = 128 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x4xi32>, tensor<16x3x3x64xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<1x110x110x16xf32> @@ -143,14 +143,14 @@ func @do_not_fuse_bconv2d_grouped_convolution_group_size_not_mul_32(%arg0: tenso // CHECK-LABEL: @fuse_scaled_bconv2d func @fuse_scaled_bconv2d(%arg0: tensor<1x112x112x1xi32>) -> tensor<1x112x110x2xf32> { - %cst = constant dense<[[[[0.3, -0.1], [0.3, 0.1]], [[-0.3, 0.1], [-0.3, 0.1]], [[-0.3, -0.1], [0.3, 0.1]]]]> : tensor<1x3x2x2xf32> + %cst = arith.constant dense<[[[[0.3, -0.1], [0.3, 0.1]], [[-0.3, 0.1], [-0.3, 0.1]], [[-0.3, -0.1], [0.3, 0.1]]]]> : tensor<1x3x2x2xf32> %0 = "lq.Dequantize"(%arg0) : (tensor<1x112x112x1xi32>) -> tensor<1x112x112x2xf32> %1 = "tf.Conv2D"(%0, %cst) {padding = "VALID", 
strides = [1, 1, 1, 1]} : (tensor<1x112x112x2xf32>, tensor<1x3x2x2xf32>) -> tensor<1x112x110x2xf32> return %1 : tensor<1x112x110x2xf32> - // CHECK: %cst = constant - // CHECK: %[[post_activation_multiplier:.*]] = constant dense<[3.000000e-01, 1.000000e-01]> : tensor<2xf32> - // CHECK: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<2xf32> + // CHECK: %cst = arith.constant + // CHECK: %[[post_activation_multiplier:.*]] = arith.constant dense<[3.000000e-01, 1.000000e-01]> : tensor<2xf32> + // CHECK: %[[post_activation_bias:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> // CHECK: %[[output_threshold:.*]] = constant unit // CHECK: %[[transpose:.*]] = "tf.Transpose" // CHECK-NEXT: %[[conv:.*]] = "lq.Bconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]], %[[output_threshold:.*]]) {channels_in = 2 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x112x112x1xi32>, tensor<2x1x3x2xf32>, tensor<2xf32>, tensor<2xf32>, none) -> tensor<1x112x110x2xf32> @@ -159,9 +159,9 @@ func @fuse_scaled_bconv2d(%arg0: tensor<1x112x112x1xi32>) -> tensor<1x112x110x2x // CHECK-LABEL: @fuse_dilated_bconv func @fuse_dilated_bconv(%arg0: tensor<1x128x128x1xi32>) -> tensor<1x128x128x8xf32> { - %cst = constant dense<[2, 2]> : tensor<2xi32> - %cst_0 = constant dense<4> : tensor<2x2xi32> - %cst_1 = constant dense<1.0> : tensor<5x5x3x8xf32> + %cst = arith.constant dense<[2, 2]> : tensor<2xi32> + %cst_0 = arith.constant dense<4> : tensor<2x2xi32> + %cst_1 = arith.constant dense<1.0> : tensor<5x5x3x8xf32> %cst_2 = constant unit %0 = "lq.Dequantize"(%arg0) : (tensor<1x128x128x1xi32>) -> tensor<1x128x128x3xf32> %1 = "tf.SpaceToBatchND"(%0, %cst, %cst_0) : (tensor<1x128x128x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<4x68x68x3xf32> @@ -169,8 +169,8 @@ func 
@fuse_dilated_bconv(%arg0: tensor<1x128x128x1xi32>) -> tensor<1x128x128x8xf %3 = "tf.BatchToSpaceND"(%2, %cst, %cst_0) : (tensor<4x64x64x8xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor<1x128x128x8xf32> return %3 : tensor<1x128x128x8xf32> - // CHECK: %[[post_activation_multiplier:.*]] = constant dense<1.000000e+00> : tensor<8xf32> - // CHECK: %[[post_activation_bias:.*]] = constant dense<0.000000e+00> : tensor<8xf32> + // CHECK: %[[post_activation_multiplier:.*]] = arith.constant dense<1.000000e+00> : tensor<8xf32> + // CHECK: %[[post_activation_bias:.*]] = arith.constant dense<0.000000e+00> : tensor<8xf32> // CHECK: %[[output_threshold:.*]] = constant unit // CHECK: %[[transpose:.*]] = "tf.Transpose" // CHECK-NEXT: %[[conv:.*]] = "lq.Bconv2d"(%arg0, %[[transpose]], %[[post_activation_multiplier]], %[[post_activation_bias]], %[[output_threshold:.*]]) {channels_in = 3 : i32, dilation_height_factor = 2 : i32, dilation_width_factor = 2 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "VALID", stride_height = 1 : i32, stride_width = 1 : i32} : (tensor<1x128x128x1xi32>, tensor<8x5x5x3xf32>, tensor<8xf32>, tensor<8xf32>, none) -> tensor<1x128x128x8xf32> @@ -179,12 +179,12 @@ func @fuse_dilated_bconv(%arg0: tensor<1x128x128x1xi32>) -> tensor<1x128x128x8xf // CHECK-LABEL: @do_not_fuse_bconv2d_non_binary_weights func @do_not_fuse_bconv2d_non_binary_weights(%arg0: tensor<1x112x112x1xi32>) -> tensor<1x112x112x2xf32> { - %cst = constant dense<[[[[3.0, -1.0], [0.1, 1.0]], [[-1.0, 1.0], [-1.0, 1.0]]]]> : tensor<1x2x2x2xf32> + %cst = arith.constant dense<[[[[3.0, -1.0], [0.1, 1.0]], [[-1.0, 1.0], [-1.0, 1.0]]]]> : tensor<1x2x2x2xf32> %0 = "lq.Dequantize"(%arg0) : (tensor<1x112x112x1xi32>) -> tensor<1x112x112x2xf32> %1 = "tf.Conv2D"(%0, %cst) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x112x112x2xf32> return %1 : tensor<1x112x112x2xf32> - // CHECK-NEXT: %cst = constant + // CHECK-NEXT: %cst = 
arith.constant // CHECK-NEXT: %0 = "lq.Dequantize"(%arg0) // CHECK-NEXT: %1 = "tf.Conv2D"(%0, %cst) // CHECK-NEXT: return %1 @@ -192,12 +192,12 @@ func @do_not_fuse_bconv2d_non_binary_weights(%arg0: tensor<1x112x112x1xi32>) -> // CHECK-LABEL: @do_not_fuse_bconv2d_zero_weight func @do_not_fuse_bconv2d_zero_weight(%arg0: tensor<1x112x112x1xi32>) -> tensor<1x112x112x2xf32> { - %cst = constant dense<0.0> : tensor<1x2x2x2xf32> + %cst = arith.constant dense<0.0> : tensor<1x2x2x2xf32> %0 = "lq.Dequantize"(%arg0) : (tensor<1x112x112x1xi32>) -> tensor<1x112x112x2xf32> %1 = "tf.Conv2D"(%0, %cst) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x112x112x2xf32> return %1 : tensor<1x112x112x2xf32> - // CHECK-NEXT: %cst = constant + // CHECK-NEXT: %cst = arith.constant // CHECK-NEXT: %0 = "lq.Dequantize"(%arg0) // CHECK-NEXT: %1 = "tf.Conv2D"(%0, %cst) // CHECK-NEXT: return %1 @@ -205,16 +205,16 @@ func @do_not_fuse_bconv2d_zero_weight(%arg0: tensor<1x112x112x1xi32>) -> tensor< // CHECK-LABEL: @fuse_bconv2d_same_one_padding func @fuse_bconv2d_same_one_padding(%arg0: tensor<256x32x32x1xi32>) -> tensor<256x16x16x16xf32> { - %cst = constant dense<1.0> : tensor<3x3x3x16xf32> - %cst0 = constant dense<1.0> : tensor - %cst1 = constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> + %cst = arith.constant dense<1.0> : tensor<3x3x3x16xf32> + %cst0 = arith.constant dense<1.0> : tensor + %cst1 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> %0 = "lq.Dequantize"(%arg0) : (tensor<256x32x32x1xi32>) -> tensor<256x32x32x3xf32> %1 = "tf.PadV2"(%0, %cst1, %cst0) : (tensor<256x32x32x3xf32>, tensor<4x2xi32>, tensor) -> tensor<256x34x34x3xf32> %2 = "tf.Conv2D"(%1, %cst) {padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<256x34x34x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x16x16x16xf32> return %2 : tensor<256x16x16x16xf32> - // CHECK: %[[CST1:.*]] = constant dense<1.000000e+00> : tensor<16xf32> - // CHECK: 
%[[CST2:.*]] = constant dense<0.000000e+00> : tensor<16xf32> + // CHECK: %[[CST1:.*]] = arith.constant dense<1.000000e+00> : tensor<16xf32> + // CHECK: %[[CST2:.*]] = arith.constant dense<0.000000e+00> : tensor<16xf32> // CHECK: %[[CST3:.*]] = constant unit // CHECK: %[[TRP:.*]] = "tf.Transpose" // CHECK: %[[CONV:.*]] = "lq.Bconv2d"(%arg0, %[[TRP]], %[[CST1]], %[[CST2]], %[[CST3:.*]]) {channels_in = 3 : i32, dilation_height_factor = 1 : i32, dilation_width_factor = 1 : i32, fused_activation_function = "NONE", pad_values = 1 : i32, padding = "SAME", stride_height = 2 : i32, stride_width = 2 : i32} : (tensor<256x32x32x1xi32>, tensor<16x3x3x3xf32>, tensor<16xf32>, tensor<16xf32>, none) -> tensor<256x16x16x16xf32> @@ -222,9 +222,9 @@ func @fuse_bconv2d_same_one_padding(%arg0: tensor<256x32x32x1xi32>) -> tensor<25 // CHECK-LABEL: @do_not_fuse_bconv2d_padding_same_twice func @do_not_fuse_bconv2d_padding_same_twice(%arg0: tensor<256x32x32x1xi32>) -> tensor<256x34x34x16xf32> { - %cst = constant dense<1.0> : tensor<3x3x3x16xf32> - %cst0 = constant dense<1.0> : tensor - %cst1 = constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> + %cst = arith.constant dense<1.0> : tensor<3x3x3x16xf32> + %cst0 = arith.constant dense<1.0> : tensor + %cst1 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> %0 = "lq.Dequantize"(%arg0) : (tensor<256x32x32x1xi32>) -> tensor<256x32x32x3xf32> %1 = "tf.PadV2"(%0, %cst1, %cst0) : (tensor<256x32x32x3xf32>, tensor<4x2xi32>, tensor) -> tensor<256x34x34x3xf32> %2 = "tf.Conv2D"(%1, %cst) {padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<256x34x34x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x34x34x16xf32> @@ -237,9 +237,9 @@ func @do_not_fuse_bconv2d_padding_same_twice(%arg0: tensor<256x32x32x1xi32>) -> // CHECK-LABEL: @do_not_fuse_bconv2d_unsupported_constant_padding func @do_not_fuse_bconv2d_unsupported_constant_padding(%arg0: tensor<256x32x32x1xi32>) -> tensor<256x32x32x16xf32> { - %cst = constant dense<1.0> : 
tensor<3x3x3x16xf32> - %cst0 = constant dense<0.0> : tensor - %cst1 = constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> + %cst = arith.constant dense<1.0> : tensor<3x3x3x16xf32> + %cst0 = arith.constant dense<0.0> : tensor + %cst1 = arith.constant dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi32> %0 = "lq.Dequantize"(%arg0) : (tensor<256x32x32x1xi32>) -> tensor<256x32x32x3xf32> %1 = "tf.PadV2"(%0, %cst1, %cst0) : (tensor<256x32x32x3xf32>, tensor<4x2xi32>, tensor) -> tensor<256x34x34x3xf32> %2 = "tf.Conv2D"(%1, %cst) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<256x34x34x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> @@ -252,9 +252,9 @@ func @do_not_fuse_bconv2d_unsupported_constant_padding(%arg0: tensor<256x32x32x1 // CHECK-LABEL: @do_not_fuse_bconv2d_padding_wrong_size func @do_not_fuse_bconv2d_padding_wrong_size(%arg0: tensor<256x32x32x1xi32>) -> tensor<256x34x34x16xf32> { - %cst = constant dense<1.0> : tensor<3x3x3x16xf32> - %cst0 = constant dense<1.0> : tensor - %cst1 = constant dense<[[0, 0], [2, 2], [2, 2], [0, 0]]> : tensor<4x2xi32> + %cst = arith.constant dense<1.0> : tensor<3x3x3x16xf32> + %cst0 = arith.constant dense<1.0> : tensor + %cst1 = arith.constant dense<[[0, 0], [2, 2], [2, 2], [0, 0]]> : tensor<4x2xi32> %0 = "lq.Dequantize"(%arg0) : (tensor<256x32x32x1xi32>) -> tensor<256x32x32x3xf32> %1 = "tf.PadV2"(%0, %cst1, %cst0) : (tensor<256x32x32x3xf32>, tensor<4x2xi32>, tensor) -> tensor<256x36x36x3xf32> %2 = "tf.Conv2D"(%1, %cst) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<256x36x36x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x34x34x16xf32> @@ -267,9 +267,9 @@ func @do_not_fuse_bconv2d_padding_wrong_size(%arg0: tensor<256x32x32x1xi32>) -> // CHECK-LABEL: @do_not_fuse_bconv2d_unsymmetric_padding func @do_not_fuse_bconv2d_unsymmetric_padding(%arg0: tensor<256x32x32x1xi32>) -> tensor<256x32x32x16xf32> { - %cst = constant dense<1.0> : tensor<3x3x3x16xf32> - %cst0 = constant dense<1.0> : tensor - %cst1 = 
constant dense<[[0, 0], [2, 0], [2, 0], [0, 0]]> : tensor<4x2xi32> + %cst = arith.constant dense<1.0> : tensor<3x3x3x16xf32> + %cst0 = arith.constant dense<1.0> : tensor + %cst1 = arith.constant dense<[[0, 0], [2, 0], [2, 0], [0, 0]]> : tensor<4x2xi32> %0 = "lq.Dequantize"(%arg0) : (tensor<256x32x32x1xi32>) -> tensor<256x32x32x3xf32> %1 = "tf.PadV2"(%0, %cst1, %cst0) : (tensor<256x32x32x3xf32>, tensor<4x2xi32>, tensor) -> tensor<256x34x34x3xf32> %2 = "tf.Conv2D"(%1, %cst) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<256x34x34x3xf32>, tensor<3x3x3x16xf32>) -> tensor<256x32x32x16xf32> @@ -282,9 +282,9 @@ func @do_not_fuse_bconv2d_unsymmetric_padding(%arg0: tensor<256x32x32x1xi32>) -> // CHECK-LABEL: @do_not_fuse_bconv2d_non_spatial_padding func @do_not_fuse_bconv2d_non_spatial_padding(%arg0: tensor<256x32x32x1xi32>) -> tensor<258x32x32x16xf32> { - %cst = constant dense<1.0> : tensor<3x3x5x16xf32> - %cst0 = constant dense<1.0> : tensor - %cst1 = constant dense<[[1, 1], [1, 1], [1, 1], [1, 1]]> : tensor<4x2xi32> + %cst = arith.constant dense<1.0> : tensor<3x3x5x16xf32> + %cst0 = arith.constant dense<1.0> : tensor + %cst1 = arith.constant dense<[[1, 1], [1, 1], [1, 1], [1, 1]]> : tensor<4x2xi32> %0 = "lq.Dequantize"(%arg0) : (tensor<256x32x32x1xi32>) -> tensor<256x32x32x3xf32> %1 = "tf.PadV2"(%0, %cst1, %cst0) : (tensor<256x32x32x3xf32>, tensor<4x2xi32>, tensor) -> tensor<258x34x34x5xf32> %2 = "tf.Conv2D"(%1, %cst) {padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<258x34x34x5xf32>, tensor<3x3x5x16xf32>) -> tensor<258x32x32x16xf32> diff --git a/larq_compute_engine/mlir/tests/quantize.mlir b/larq_compute_engine/mlir/tests/quantize.mlir index c906ffda4..a0c7393ef 100644 --- a/larq_compute_engine/mlir/tests/quantize.mlir +++ b/larq_compute_engine/mlir/tests/quantize.mlir @@ -2,18 +2,18 @@ // CHECK-LABEL: quantize_bconv2d func @quantize_bconv2d(%arg0: tensor<1x224x224x1xi32>, %arg1: tensor<32x3x3x1xi32>, %arg2: none) -> tensor<1x112x112x32x!quant.uniform> { - 
%cst0 = constant dense<-1.23697901> : tensor<32xf32> + %cst0 = arith.constant dense<-1.23697901> : tensor<32xf32> %0 = "tfl.quantize"(%cst0) {qtype = tensor<32x!quant.uniform>} : (tensor<32xf32>) -> tensor<32x!quant.uniform> %1 = "tfl.dequantize"(%0) : (tensor<32x!quant.uniform>) -> tensor<32xf32> - %cst1 = constant dense<1.10976315> : tensor<32xf32> + %cst1 = arith.constant dense<1.10976315> : tensor<32xf32> %2 = "tfl.quantize"(%cst1) {qtype = tensor<32x!quant.uniform>} : (tensor<32xf32>) -> tensor<32x!quant.uniform> %3 = "tfl.dequantize"(%2) : (tensor<32x!quant.uniform>) -> tensor<32xf32> %4 = "lq.Bconv2d"(%arg0, %arg1, %1, %3, %arg2) {channels_in = 3 : i32, dilation_height_factor = 2 : i32, dilation_width_factor = 3 : i32, fused_activation_function = "NONE", pad_values = 0 : i32, padding = "SAME", stride_height = 4 : i32, stride_width = 5 : i32} : (tensor<1x224x224x1xi32>, tensor<32x3x3x1xi32>, tensor<32xf32>, tensor<32xf32>, none) -> tensor<1x112x112x32xf32> %5 = "tfl.quantize"(%4) {qtype = tensor<1x112x112x32x!quant.uniform>} : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x32x!quant.uniform> return %5 : tensor<1x112x112x32x!quant.uniform> -// CHECK: %[[cst0:.*]] = constant dense<-1.23697901> : tensor<32xf32> -// CHECK: %[[cst1:.*]] = constant dense<1.10976315> : tensor<32xf32> +// CHECK: %[[cst0:.*]] = arith.constant dense<-1.23697901> : tensor<32xf32> +// CHECK: %[[cst1:.*]] = arith.constant dense<1.10976315> : tensor<32xf32> // CHECK: %[[conv:.*]] = "lq.Bconv2d"(%arg0, %arg1, %[[cst0]], %[[cst1]], %arg2) // CHECK: return %[[conv]] : tensor<1x112x112x32x!quant.uniform> } @@ -37,3 +37,13 @@ func @quantize_lce_dequantize(%arg0: tensor<1x112x112x1xi32>) -> tensor<1x112x11 // CHECK-NEXT: %0 = "lq.Dequantize"(%arg0) : (tensor<1x112x112x1xi32>) -> tensor<1x112x112x32x!quant.uniform> // CHECK-NEXT: return %0 : tensor<1x112x112x32x!quant.uniform> } + +// CHECK-LABEL: dequantize_lce_quantize +func @dequantize_lce_quantize(%arg0: tensor<1x112x112x32x!quant.uniform>) 
-> tensor<1x112x112x1xi32> { + %0 = "tfl.dequantize"(%arg0) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x32xf32> + %1 = "lq.Quantize"(%0) : (tensor<1x112x112x32xf32>) -> tensor<1x112x112x1xi32> + return %1 : tensor<1x112x112x1xi32> + +// CHECK: %[[quant:.*]] = "lq.Quantize"(%arg0) : (tensor<1x112x112x32x!quant.uniform>) -> tensor<1x112x112x1xi32> +// CHECK-NEXT: return %[[quant]] : tensor<1x112x112x1xi32> +} diff --git a/larq_compute_engine/mlir/tests/set_batch_size.mlir b/larq_compute_engine/mlir/tests/set_batch_size.mlir index 203291c98..9c4d83eff 100644 --- a/larq_compute_engine/mlir/tests/set_batch_size.mlir +++ b/larq_compute_engine/mlir/tests/set_batch_size.mlir @@ -18,9 +18,9 @@ func @simple(%arg0: tensor, %arg1: tensor<2x6xf32>) -> (tensor // Both inputs have a dynamic batch size // CHECK-LABEL: @dual_input_model -func @dual_input_model(%arg0: tensor {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor {tf_saved_model.index_path = ["input_1"]}, %arg2: tensor>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { - %0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor>>) -> tensor<6xf32> - %1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor>>) -> tensor<4x6xf32> +func @dual_input_model(%arg0: tensor {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor {tf_saved_model.index_path = ["input_1"]}, %arg2: tensor>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = 
"serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor>>) -> tensor<6xf32> + %1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor>>) -> tensor<4x6xf32> %2 = "tf.MatMul"(%arg1, %1) {device = "", transpose_a = false, transpose_b = false} : (tensor, tensor<4x6xf32>) -> tensor %3 = "tf.BiasAdd"(%2, %0) {data_format = "NHWC", device = ""} : (tensor, tensor<6xf32>) -> tensor %4 = "tf.AddV2"(%3, %arg0) {device = ""} : (tensor, tensor) -> tensor @@ -30,15 +30,15 @@ func @dual_input_model(%arg0: tensor {tf_saved_model.index_path = ["inp // CHECK: %arg0: tensor<1x6xf32> {tf_saved_model.index_path = ["input_2"]} // CHECK: %arg1: tensor<1x4xf32> {tf_saved_model.index_path = ["input_1"]} // The resource objects and attributes should be unchanged - // CHECK: %arg2: tensor>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + // CHECK: %arg2: tensor>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { } // This is the same model, but one of the two inputs has been given a fixed batch size in Python // CHECK-LABEL: @dual_input_one_fixed_size -func @dual_input_one_fixed_size(%arg0: tensor {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor<1x4xf32> 
{tf_saved_model.index_path = ["input_1"]}, %arg2: tensor>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { - %0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor>>) -> tensor<6xf32> - %1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor>>) -> tensor<4x6xf32> +func @dual_input_one_fixed_size(%arg0: tensor {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor<1x4xf32> {tf_saved_model.index_path = ["input_1"]}, %arg2: tensor>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor>>) -> tensor<6xf32> + %1 = "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor>>) -> tensor<4x6xf32> %2 = "tf.MatMul"(%arg1, %1) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x4xf32>, tensor<4x6xf32>) -> tensor<1x6xf32> %3 = "tf.BiasAdd"(%2, %0) {data_format = "NHWC", device = ""} : (tensor<1x6xf32>, tensor<6xf32>) -> tensor<1x6xf32> %4 = "tf.AddV2"(%3, %arg0) {device = ""} : (tensor<1x6xf32>, tensor) -> tensor @@ -47,5 +47,5 @@ func @dual_input_one_fixed_size(%arg0: tensor {tf_saved_model.index_pat return %6 : tensor // CHECK: %arg0: tensor<1x6xf32> {tf_saved_model.index_path = ["input_2"]} // CHECK: %arg1: tensor<1x4xf32> {tf_saved_model.index_path = ["input_1"]} - // CHECK: %arg2: tensor>> {tf_saved_model.bound_input = @"dense/bias"}, 
%arg3: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { -} \ No newline at end of file + // CHECK: %arg2: tensor>> {tf_saved_model.bound_input = @"dense/bias"}, %arg3: tensor>> {tf_saved_model.bound_input = @"dense/kernel"}) -> (tensor {tf_saved_model.index_path = ["tf.__operators__.add"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_2:0,serving_default_input_1:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { +} diff --git a/larq_compute_engine/mlir/tf_tfl_passes.cc b/larq_compute_engine/mlir/tf_tfl_passes.cc index 07b4202ed..20d622647 100644 --- a/larq_compute_engine/mlir/tf_tfl_passes.cc +++ b/larq_compute_engine/mlir/tf_tfl_passes.cc @@ -7,7 +7,6 @@ #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" @@ -28,7 +27,7 @@ void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, mlir::OpPassManager* pass_manager) { pass_manager->addNestedPass( mlir::TFL::CreatePrepareQuantizePass(quant_specs)); - pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass()); + pass_manager->addNestedPass(mlir::TFL::CreateLCEQuantizePass()); if (quant_specs.default_ranges.first.hasValue() || quant_specs.default_ranges.second.hasValue()) { pass_manager->addNestedPass( @@ -36,19 +35,25 @@ void AddQuantizationPasses(const 
mlir::TFL::QuantizationSpecs& quant_specs, quant_specs.default_ranges.first.getValueOr(0.0), quant_specs.default_ranges.second.getValueOr(0.0), quant_specs.IsSignedInferenceType())); - pass_manager->addPass(mlir::TFL::CreateLCEQuantizePass()); + pass_manager->addNestedPass( + mlir::TFL::CreateLCEQuantizePass()); } pass_manager->addNestedPass(mlir::TFL::CreateQuantizePass()); bool emit_quant_adaptor_ops = quant_specs.inference_type != quant_specs.inference_input_type; pass_manager->addNestedPass( mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); + pass_manager->addNestedPass(mlir::TFL::CreateLCEQuantizePass()); + pass_manager->addNestedPass( + mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); } } // namespace -void AddTFToLCETFLConversionPasses( - const mlir::TFL::QuantizationSpecs& quant_specs, - mlir::OpPassManager* pass_manager, const LCETarget target) { +// This is the early part of the conversion in isolation. This enables a caller +// to inject more information in the middle of the conversion before resuming +// it. +void AddPreVariableFreezingTFToLCETFLConversionPasses( + mlir::OpPassManager* pass_manager) { // This pass wraps all the tf.FakeQuant ops in a custom op so they are not // folded before being converted to tfl.quantize and tfl.dequantize ops. auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps(); @@ -79,7 +84,15 @@ void AddTFToLCETFLConversionPasses( // during which resources dont get frozen in the python layer. pass_manager->addNestedPass( mlir::TFDevice::CreateDecomposeResourceOpsPass()); +} +// This is the later part of the conversion in isolation. This enables a caller +// to resume the conversion after injecting more information in the middle of +// it. 
+void AddPostVariableFreezingTFToLCETFLConversionPasses( + llvm::StringRef saved_model_dir, + const mlir::TFL::QuantizationSpecs& quant_specs, + mlir::OpPassManager* pass_manager, const LCETarget target) { // Note: // We need to fuse composite ops before LowerStaticTensorList pass. // The tensorflow list is not supported right now by that pass. @@ -100,7 +113,7 @@ void AddTFToLCETFLConversionPasses( // Set the batch size of the function input to 1 and let shape inference // propagate this in the next pass. - pass_manager->addPass(mlir::CreateSetBatchSizePass()); + pass_manager->addNestedPass(mlir::CreateSetBatchSizePass()); // Add a shape inference pass to optimize away the unnecessary casts. pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); @@ -116,12 +129,9 @@ void AddTFToLCETFLConversionPasses( // function inliner interface. pass_manager->addPass(mlir::createInlinerPass()); - // TODO(jpienaar): Revise post dialect constants. - pass_manager->addNestedPass( - mlir::TF::CreateDecodeConstantPass()); // Remove passthrough ops early so constant folding can happen before // LCE ops are injected - pass_manager->addPass(mlir::TFL::CreateOpRemovalPass()); + pass_manager->addNestedPass(mlir::TFL::CreateOpRemovalPass()); // The following pass used to be just after createSymbolDCEPass but we move it // before createCanonicalizerPass because without it, the tf.Sign op is not @@ -131,6 +141,21 @@ void AddTFToLCETFLConversionPasses( // constant ops. pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass()); + if (!saved_model_dir.empty()) { + // This pass 'freezes' tf saved model asset ops and inlines as string values + // in a format of the tf constant op. + pass_manager->addPass( + mlir::tf_saved_model::CreateFreezeAssetsPass(saved_model_dir.str())); + } + + // Reduce operands of TFL::While without changing the outcome. + // It needs to stay here because: + // 1. WhileOps are in TFL dialect. + // 2. The body and cond are inlined. + // 3. 
We need to do this before while canonicalization, otherwise it would be + // difficult to find dependencies. + pass_manager->addNestedPass( + mlir::TFL::CreateReduceWhileOperandsPass()); // Canonicalization includes const folding, which is utilized here to optimize // away ops that can't get constant folded after PrepareTF pass. For example, // tf.Conv2D is split into tf.Transpose and tfl.Conv2D. @@ -139,6 +164,7 @@ void AddTFToLCETFLConversionPasses( // This pass does dead code elimination based on symbol visibility. pass_manager->addPass(mlir::createSymbolDCEPass()); + // Run shape inference after variables are converted to constants. pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); // Force layout supported by TFLite, this will transpose the data // to match 'kTFLiteDataLayout' @@ -148,7 +174,8 @@ void AddTFToLCETFLConversionPasses( mlir::TF::CreateLayoutOptimizationPipeline(pass_manager->nest(), layout_optimization_options); // Inject Larq Compute Engine Ops - pass_manager->addPass(mlir::TFL::CreatePrepareLCEPass(target)); + pass_manager->addNestedPass( + mlir::TFL::CreatePrepareLCEPass(target)); // Prepare for TFLite dialect, rerun canonicalization, and then legalize to // the TFLite dialect. pass_manager->addNestedPass( @@ -159,22 +186,32 @@ void AddTFToLCETFLConversionPasses( // TODO(fengliuai): remove this pass if TableGen patterns have a better // to control the shapes for the intermediate results. pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); + // Inline function calls that left in the graph after folding functional // control flow ops (IfOp, CaseOp). pass_manager->addPass(mlir::createInlinerPass()); + // This pass removes the asset file dependencies in hash table use cases. + pass_manager->addNestedPass( + mlir::TF::CreateInitTextFileToImportPass(saved_model_dir.str())); + // This pass removes the asset file dependencies in hash table use cases. 
pass_manager->addNestedPass( mlir::TF::CreateInitTextFileToImportPass()); pass_manager->addNestedPass( mlir::TFL::CreateLegalizeTFPass(true)); + pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass()); + pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass()); pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass()); - pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target)); + pass_manager->addNestedPass( + mlir::TFL::CreateOptimizeLCEPass(target)); pass_manager->addNestedPass( mlir::TFL::CreateOptimizePass(true)); - pass_manager->addPass(mlir::TFL::CreateOptimizeLCEPass(target)); - pass_manager->addPass(mlir::TFL::CreateBitpackWeightsLCEPass()); + pass_manager->addNestedPass( + mlir::TFL::CreateOptimizeLCEPass(target)); + pass_manager->addNestedPass( + mlir::TFL::CreateBitpackWeightsLCEPass()); // This pass operates on TensorFlow ops but is triggered after legalization // so that it can target constants introduced once TensorFlow Identity ops @@ -187,13 +224,14 @@ void AddTFToLCETFLConversionPasses( pass_manager->addNestedPass(mlir::createCanonicalizerPass()); pass_manager->addNestedPass(mlir::createCSEPass()); - pass_manager->addPass(mlir::TFL::CreateFusePaddingPass()); + pass_manager->addNestedPass(mlir::TFL::CreateFusePaddingPass()); // Run quantization after all the floating point model conversion is // completed. if (quant_specs.RunPropagationAndRewriteQuantizationPasses()) { AddQuantizationPasses(quant_specs, pass_manager); } + pass_manager->addPass(mlir::createCanonicalizerPass()); // This pass should be always at the end of the model // conversion (even after quantization). Some TFL ops like unidirectional @@ -209,8 +247,12 @@ void AddTFToLCETFLConversionPasses( // model dialect. 
pass_manager->addPass( mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass()); + pass_manager->addPass(mlir::TFL::CreateUnfoldLargeSplatConstantPass()); + pass_manager->addPass(mlir::TFL::CreateWhileOutlinePass()); + pass_manager->addNestedPass( + mlir::TFL::CreateRuntimeVerifyPass()); - pass_manager->addPass(mlir::TFL::CreateLegalizeLCEPass()); + pass_manager->addNestedPass(mlir::TFL::CreateLegalizeLCEPass()); } } // namespace tensorflow diff --git a/larq_compute_engine/mlir/tf_tfl_passes.h b/larq_compute_engine/mlir/tf_tfl_passes.h index 3b9dccb88..f5d3d923d 100644 --- a/larq_compute_engine/mlir/tf_tfl_passes.h +++ b/larq_compute_engine/mlir/tf_tfl_passes.h @@ -9,8 +9,11 @@ namespace tensorflow { -// Add the TF to TFLite passes into a pass_manager. -void AddTFToLCETFLConversionPasses( +void AddPreVariableFreezingTFToLCETFLConversionPasses( + mlir::OpPassManager* pass_manager); + +void AddPostVariableFreezingTFToLCETFLConversionPasses( + llvm::StringRef saved_model_dir, const mlir::TFL::QuantizationSpecs& quant_specs, mlir::OpPassManager* pass_manager, const LCETarget target = LCETarget::ARM); diff --git a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc index 84382d23b..e5a4a9520 100644 --- a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc +++ b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.cc @@ -1,10 +1,14 @@ #include "larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h" +#include "larq_compute_engine/mlir/tf_tfl_passes.h" +#include "larq_compute_engine/mlir/transforms/passes.h" #include "llvm/Support/raw_ostream.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/PassManager.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" #include 
"tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -27,11 +31,13 @@ mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) { : mlir::WalkResult::advance(); }); if (result.wasInterrupted()) { - module.emitError( - "The graph has Control Flow V1 ops. TFLite converter doesn't support " - "Control Flow V1 ops. Consider using Control Flow V2 ops instead. See " - "https://www.tensorflow.org/api_docs/python/tf/compat/v1/" - "enable_control_flow_v2."); + mlir::TFL::AttachErrorCode( + module.emitError( + "The graph has Control Flow V1 ops. TFLite converter doesn't " + "support Control Flow V1 ops. Consider using Control Flow V2 ops " + "instead. See https://www.tensorflow.org/api_docs/python/tf/compat/" + "v1/enable_control_flow_v2."), + tflite::metrics::ConverterErrorData::ERROR_UNSUPPORTED_CONTROL_FLOW_V1); return mlir::failure(); } return mlir::success(); @@ -49,10 +55,12 @@ class TruncateOpOrArgLocNameMapper : public OpOrArgLocNameMapper { }; } // namespace - -Status ConvertTFExecutorToFlatbuffer(mlir::ModuleOp module, bool export_to_mlir, - std::string* result, - mlir::PassManager* pass_manager) { +Status ConvertTFExecutorToTFLOrFlatbuffer( + mlir::ModuleOp module, bool export_to_mlir, const LCETarget target, + mlir::TFL::QuantizationSpecs quant_specs, + const std::unordered_set& saved_model_tags, + llvm::StringRef saved_model_dir, + llvm::Optional session, std::string* result) { // Explicitly disable dumping Op details on failures. 
module.getContext()->printOpOnDiagnostic(false); @@ -70,22 +78,64 @@ Status ConvertTFExecutorToFlatbuffer(mlir::ModuleOp module, bool export_to_mlir, mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(), /*propagate=*/true); + if (failed(IsValidGraph(module))) { + return statusHandler.ConsumeStatus(); + } + + mlir::PassManager pass_manager(module.getContext()); + mlir::applyPassManagerCLOptions(pass_manager); + pass_manager.addInstrumentation( + std::make_unique( + pass_manager.getContext())); - if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) { + tensorflow::AddPreVariableFreezingTFToLCETFLConversionPasses(&pass_manager); + if (failed(pass_manager.run(module))) { return statusHandler.ConsumeStatus(); } + // Freeze variables if a session is provided. + if (session.hasValue()) { + mlir::TFL::ErrorCollectorInstrumentation collector(module.getContext()); + if (failed(mlir::tf_saved_model::FreezeVariables(module, + session.getValue()))) { + auto status = statusHandler.ConsumeStatus(); + mlir::TFL::ErrorCollector* collector = + mlir::TFL::ErrorCollector::GetErrorCollector(); + if (!collector->CollectedErrors().empty()) { + return errors::InvalidArgument("Variable constant folding has failed."); + } + return status; + } + } + pass_manager.clear(); + tensorflow::AddPostVariableFreezingTFToLCETFLConversionPasses( + saved_model_dir, quant_specs, &pass_manager, target); + if (failed(pass_manager.run(module))) { + auto status = statusHandler.ConsumeStatus(); + mlir::TFL::ErrorCollector* collector = + mlir::TFL::ErrorCollector::GetErrorCollector(); + for (const auto& error_data : collector->CollectedErrors()) { + if (error_data.subcomponent() == "FreezeGlobalTensorsPass") { + return errors::InvalidArgument("Variable constant folding is failed."); + } + } + return status; + } + if (export_to_mlir) { llvm::raw_string_ostream os(*result); module.print(os); - return Status::OK(); + return statusHandler.ConsumeStatus(); } - // This is the only 
modification compared to the upstream tensorflow file + // Write MLIR TFLite dialect into FlatBuffer TruncateOpOrArgLocNameMapper op_or_arg_name_mapper; + toco::TocoFlags toco_flags; + toco_flags.set_force_select_tf_ops(false); + toco_flags.set_allow_custom_ops(true); tflite::FlatbufferExportOptions options; - options.emit_builtin_tflite_ops = true; - options.emit_custom_ops = true; + options.toco_flags = toco_flags; + options.saved_model_tags = saved_model_tags; options.op_or_arg_name_mapper = &op_or_arg_name_mapper; if (!tflite::MlirToFlatBufferTranslateFunction(module, options, result)) { return statusHandler.ConsumeStatus(); diff --git a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h index 3f8d54161..f1aa84d1a 100644 --- a/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h +++ b/larq_compute_engine/mlir/tf_to_tfl_flatbuffer.h @@ -1,18 +1,24 @@ #ifndef LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_ #define LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_ +#include + +#include "larq_compute_engine/mlir/transforms/passes.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/PassManager.h" +#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" +#include "tensorflow/core/public/session.h" #include "tensorflow/stream_executor/lib/statusor.h" - namespace tensorflow { // This is a fork of ConvertTFExecutorToTFLOrFlatbuffer to enable custom // OpOrArgLocNameMapper -// https://github.com/tensorflow/tensorflow/blob/v2.5.0/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h#L55-L69 -Status ConvertTFExecutorToFlatbuffer(mlir::ModuleOp module, bool export_to_mlir, - std::string* result, - mlir::PassManager* pass_manager); +// https://github.com/tensorflow/tensorflow/blob/v2.8.0/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h#L60-L78 +Status ConvertTFExecutorToTFLOrFlatbuffer( + mlir::ModuleOp module, bool export_to_mlir, const LCETarget target, + mlir::TFL::QuantizationSpecs quant_specs, + const 
std::unordered_set& saved_model_tags, + llvm::StringRef saved_model_dir, + llvm::Optional session, std::string* result); } // namespace tensorflow #endif // LARQ_COMPUTE_ENGINE_MLIR_TF_TO_TFL_FLATBUFFER_H_ diff --git a/larq_compute_engine/mlir/transforms/bitpack_activations_patterns.td b/larq_compute_engine/mlir/transforms/bitpack_activations_patterns.td index 2bc3d535d..611d5b221 100644 --- a/larq_compute_engine/mlir/transforms/bitpack_activations_patterns.td +++ b/larq_compute_engine/mlir/transforms/bitpack_activations_patterns.td @@ -1,4 +1,5 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "larq_compute_engine/mlir/ir/lce_ops.td" @@ -13,32 +14,32 @@ def CreateNoneAttrValue : NativeCodeCall<"$_builder.getUnitAttr()">; def GetSignsOfVectorAndBroadcast4D : NativeCodeCall<"GetSignsOfVectorAndBroadcast4D($0)">; def GetBitpackedOutputThresholds : NativeCodeCall<"GetBitpackedOutputThresholds($_builder, $0, $1, $2, $3)">; -class WriteBitpackedActivationsPat : +class WriteBitpackedActivationsPat : Pat<(LQ_QuantizeOp (LQ_Bconv2dOp:$output $input, - (ConstantOp F32ElementsAttr:$filter), - (ConstantOp F32ElementsAttr:$post_activation_multiplier), - (ConstantOp F32ElementsAttr:$post_activation_bias), + (Arith_ConstantOp F32ElementsAttr:$filter), + (Arith_ConstantOp F32ElementsAttr:$post_activation_multiplier), + (Arith_ConstantOp F32ElementsAttr:$post_activation_bias), (ConstantOp UnitAttr), $channels_in, $dilation_height_factor, $dilation_width_factor, $fused_activation_function, ConstantAttr, - ConstantAttr, + padding_type, $stride_height, $stride_width)), (LQ_Bconv2dOp $input, (TFL_MulOp - (ConstantOp $filter), - (ConstantOp + (Arith_ConstantOp $filter), + (Arith_ConstantOp (GetSignsOfVectorAndBroadcast4D $post_activation_multiplier)), TFL_AF_None), (ConstantOp (CreateNoneAttrValue)), (ConstantOp (CreateNoneAttrValue)), - (ConstantOp + (Arith_ConstantOp 
(GetBitpackedOutputThresholds $filter, $post_activation_multiplier, @@ -49,9 +50,9 @@ class WriteBitpackedActivationsPat : $dilation_width_factor, $fused_activation_function, ConstantAttr, - ConstantAttr, + padding_type, $stride_height, $stride_width), [(HasOneUse $output)], (addBenefit 10)>; -def : WriteBitpackedActivationsPat<"VALID", "0">; -def : WriteBitpackedActivationsPat<"SAME", "1">; +def : WriteBitpackedActivationsPat; +def : WriteBitpackedActivationsPat; diff --git a/larq_compute_engine/mlir/transforms/bitpack_weights.cc b/larq_compute_engine/mlir/transforms/bitpack_weights.cc index 95c9f0367..8a6a8e179 100644 --- a/larq_compute_engine/mlir/transforms/bitpack_weights.cc +++ b/larq_compute_engine/mlir/transforms/bitpack_weights.cc @@ -12,6 +12,12 @@ namespace TFL { namespace { struct BitpackWeightsLCE : public PassWrapper { + llvm::StringRef getArgument() const final { + return "tfl-lce-bitpack-weights"; + } + llvm::StringRef getDescription() const final { + return "Bitpack binary weights"; + } void runOnFunction() override; }; @@ -39,8 +45,7 @@ std::unique_ptr> CreateBitpackWeightsLCEPass() { return std::make_unique(); } -static PassRegistration pass("tfl-lce-bitpack-weights", - "Bitpack binary weights"); +static PassRegistration pass; } // namespace TFL } // namespace mlir diff --git a/larq_compute_engine/mlir/transforms/bitpack_weights_patterns.td b/larq_compute_engine/mlir/transforms/bitpack_weights_patterns.td index 7003984c4..47f978d33 100644 --- a/larq_compute_engine/mlir/transforms/bitpack_weights_patterns.td +++ b/larq_compute_engine/mlir/transforms/bitpack_weights_patterns.td @@ -1,4 +1,5 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "larq_compute_engine/mlir/ir/lce_ops.td" include "larq_compute_engine/mlir/transforms/op_removal_patterns.td" @@ -9,7 +10,7 @@ def Bitpack : NativeCodeCall<"Bitpack(&$_builder, $0)">; def : 
Pat<(LQ_Bconv2dOp $input, - (ConstantOp Conv2DFilter:$filter), + (Arith_ConstantOp Conv2DFilter:$filter), $post_activation_multiplier, $post_activation_bias, $output_threshold, @@ -23,7 +24,7 @@ def : Pat<(LQ_Bconv2dOp $stride_width), (LQ_Bconv2dOp $input, - (ConstantOp (Bitpack $filter)), + (Arith_ConstantOp (Bitpack $filter)), $post_activation_multiplier, $post_activation_bias, $output_threshold, diff --git a/larq_compute_engine/mlir/transforms/fuse_padding.cc b/larq_compute_engine/mlir/transforms/fuse_padding.cc index 99046f7f1..2761478e6 100644 --- a/larq_compute_engine/mlir/transforms/fuse_padding.cc +++ b/larq_compute_engine/mlir/transforms/fuse_padding.cc @@ -37,6 +37,10 @@ bool IsSamePaddingPartial(Attribute paddings_attr, Value input, Value output, // Prepare LCE operations in functions for subsequent legalization. struct FusePadding : public PassWrapper { + llvm::StringRef getArgument() const final { return "tfl-fuse-padding"; } + llvm::StringRef getDescription() const final { + return "Fuse padding ops into (Depthwise)Convs"; + } FusePadding() = default; FusePadding(const FusePadding& pass) {} @@ -59,8 +63,7 @@ std::unique_ptr> CreateFusePaddingPass() { return std::make_unique(); } -static PassRegistration pass( - "tfl-fuse-padding", "Fuse padding ops into (Depthwise)Convs."); +static PassRegistration pass; } // namespace TFL } // namespace mlir diff --git a/larq_compute_engine/mlir/transforms/fuse_padding.td b/larq_compute_engine/mlir/transforms/fuse_padding.td index 9aeaa62d6..d94825b4a 100644 --- a/larq_compute_engine/mlir/transforms/fuse_padding.td +++ b/larq_compute_engine/mlir/transforms/fuse_padding.td @@ -1,4 +1,5 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" def HasOneUse : Constraint>; @@ -20,13 +21,13 @@ def SamePaddingWidth : Constraint, + TFL_PAD_Valid, $stride_h, $stride_w), (TFL_Conv2DOp $input, @@ -35,7 +36,7 @@ def : 
Pat<(TFL_Conv2DOp:$conv_output $h_factor, $w_factor, $act_fn, - ConstantAttr, + TFL_PAD_Same, $stride_h, $stride_w), [(HasOneUse $pad_output), @@ -49,14 +50,14 @@ def : Pat<(TFL_Conv2DOp:$conv_output def : Pat<(TFL_Conv2DOp:$conv_output (TFL_PadV2Op:$pad_output $input, - (ConstantOp $paddings), - (ConstantOp $pad_values)), + (Arith_ConstantOp $paddings), + (Arith_ConstantOp $pad_values)), $filter, $bias, $h_factor, $w_factor, $act_fn, - ConstantAttr, + TFL_PAD_Valid, $stride_h, $stride_w), (TFL_Conv2DOp $input, @@ -65,7 +66,7 @@ def : Pat<(TFL_Conv2DOp:$conv_output $h_factor, $w_factor, $act_fn, - ConstantAttr, + TFL_PAD_Same, $stride_h, $stride_w), [(HasOneUse $pad_output), @@ -79,13 +80,13 @@ def : Pat<(TFL_Conv2DOp:$conv_output def : Pat<(TFL_DepthwiseConv2DOp:$conv_output (TFL_PadOp:$pad_output $input, - (ConstantOp $paddings)), + (Arith_ConstantOp $paddings)), $filter, $bias, $h_factor, $w_factor, $act_fn, - ConstantAttr, + TFL_PAD_Valid, $stride_h, $stride_w, $depth_multiplier), @@ -95,7 +96,7 @@ def : Pat<(TFL_DepthwiseConv2DOp:$conv_output $h_factor, $w_factor, $act_fn, - ConstantAttr, + TFL_PAD_Same, $stride_h, $stride_w, $depth_multiplier), @@ -109,14 +110,14 @@ def : Pat<(TFL_DepthwiseConv2DOp:$conv_output def : Pat<(TFL_DepthwiseConv2DOp:$conv_output (TFL_PadV2Op:$pad_output $input, - (ConstantOp $paddings), - (ConstantOp $pad_values)), + (Arith_ConstantOp $paddings), + (Arith_ConstantOp $pad_values)), $filter, $bias, $h_factor, $w_factor, $act_fn, - ConstantAttr, + TFL_PAD_Valid, $stride_h, $stride_w, $depth_multiplier), @@ -126,7 +127,7 @@ def : Pat<(TFL_DepthwiseConv2DOp:$conv_output $h_factor, $w_factor, $act_fn, - ConstantAttr, + TFL_PAD_Same, $stride_h, $stride_w, $depth_multiplier), diff --git a/larq_compute_engine/mlir/transforms/legalize_tflite.cc b/larq_compute_engine/mlir/transforms/legalize_tflite.cc index 4b1fc5b81..861ca84af 100644 --- a/larq_compute_engine/mlir/transforms/legalize_tflite.cc +++ 
b/larq_compute_engine/mlir/transforms/legalize_tflite.cc @@ -10,6 +10,10 @@ namespace TFL { namespace { struct LegalizeLCE : public PassWrapper { + llvm::StringRef getArgument() const final { return "tfl-legalize-lce"; } + llvm::StringRef getDescription() const final { + return "Legalize LCE ops in TensorFlow Lite dialect"; + } void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } @@ -57,8 +61,7 @@ std::unique_ptr> CreateLegalizeLCEPass() { return std::make_unique(); } -static PassRegistration pass( - "tfl-legalize-lce", "Legalize LCE ops in TensorFlow Lite dialect"); +static PassRegistration pass; } // namespace TFL } // namespace mlir diff --git a/larq_compute_engine/mlir/transforms/op_removal.cc b/larq_compute_engine/mlir/transforms/op_removal.cc index 051874e51..d508eee5d 100644 --- a/larq_compute_engine/mlir/transforms/op_removal.cc +++ b/larq_compute_engine/mlir/transforms/op_removal.cc @@ -12,6 +12,10 @@ namespace { // Op removal of pass through ops to make following patterns easier and enable // early constant folding struct OpRemoval : public PassWrapper { + llvm::StringRef getArgument() const final { return "lce-op-removal-tf"; } + llvm::StringRef getDescription() const final { + return "Remove pass through TensorFlow ops"; + } void runOnFunction() override; }; @@ -32,8 +36,7 @@ std::unique_ptr> CreateOpRemovalPass() { return std::make_unique(); } -static PassRegistration pass("lce-op-removal-tf", - "Remove pass through TensorFlow ops"); +static PassRegistration pass; } // namespace TFL } // namespace mlir diff --git a/larq_compute_engine/mlir/transforms/op_removal_patterns.td b/larq_compute_engine/mlir/transforms/op_removal_patterns.td index 8534f40fc..333985559 100644 --- a/larq_compute_engine/mlir/transforms/op_removal_patterns.td +++ b/larq_compute_engine/mlir/transforms/op_removal_patterns.td @@ -1,4 +1,5 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" include 
"tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def NonOpaqueElementsAttr : ElementsAttrBase< @@ -6,7 +7,7 @@ def NonOpaqueElementsAttr : ElementsAttrBase< "non-opaque constant tensor">; // Convert to std constant for statically shaped, non-opaque constants. -def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (ConstantOp $value), +def : Pat<(TF_ConstOp:$res NonOpaqueElementsAttr:$value), (Arith_ConstantOp $value), [(AnyStaticShapeTensor $res)]>; // Partially supported in TFLite, treated as passthrough IdentityOp diff --git a/larq_compute_engine/mlir/transforms/optimize.cc b/larq_compute_engine/mlir/transforms/optimize.cc index cfff25a1a..4a29b9dd9 100644 --- a/larq_compute_engine/mlir/transforms/optimize.cc +++ b/larq_compute_engine/mlir/transforms/optimize.cc @@ -20,6 +20,10 @@ namespace { // Optimize LCE operations in functions. struct OptimizeLCE : public PassWrapper { + llvm::StringRef getArgument() const final { return "tfl-optimize-lce"; } + llvm::StringRef getDescription() const final { + return "Optimize within the TensorFlow Lite dialect"; + } OptimizeLCE() = default; OptimizeLCE(const OptimizeLCE& pass) {} OptimizeLCE(const LCETarget target) { target_.setValue(target); } @@ -202,7 +206,7 @@ DenseElementsAttr GetSignsOfVectorAndBroadcast4D(Attribute vector_attr) { std::vector signs(vector_length); for (std::size_t i = 0; i < vector_length; ++i) { - const auto sign = vector.getValue({i}) >= 0.0f ? 1.0f : -1.0f; + const auto sign = vector.getValues()[i] >= 0.0f ? 
1.0f : -1.0f; signs[i] = FloatAttr::get(element_type, sign); } @@ -284,8 +288,7 @@ std::unique_ptr> CreateOptimizeLCEPass( return std::make_unique(target); } -static PassRegistration pass( - "tfl-optimize-lce", "Optimize within the TensorFlow Lite dialect"); +static PassRegistration pass; } // namespace TFL } // namespace mlir diff --git a/larq_compute_engine/mlir/transforms/optimize_patterns_common.td b/larq_compute_engine/mlir/transforms/optimize_patterns_common.td index e48b4a735..c5d68a891 100644 --- a/larq_compute_engine/mlir/transforms/optimize_patterns_common.td +++ b/larq_compute_engine/mlir/transforms/optimize_patterns_common.td @@ -1,4 +1,5 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "larq_compute_engine/mlir/ir/lce_ops.td" @@ -14,7 +15,7 @@ class ConstantValue : AttrConstraint))), + (Arith_ConstantOp ConstantValue<"0.0f">))), (LQ_QuantizeOp $input), [(HasOneUse $ge_op)], (addBenefit 150)>; @@ -42,7 +43,7 @@ multiclass FuseAddOrSubWithBConv2D { $input, $filter, $post_activation_multiplier, - (ConstantOp F32ElementsAttr:$post_activation_bias), + (Arith_ConstantOp F32ElementsAttr:$post_activation_bias), $output_threshold, $channels_in, $dilation_height_factor, @@ -52,13 +53,13 @@ multiclass FuseAddOrSubWithBConv2D { $padding, $stride_height, $stride_width), - (ConstantOp F32ElementsAttr:$value), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$value), TFL_AF_None), (LQ_Bconv2dOp $input, $filter, $post_activation_multiplier, - (binaryOp (ConstantOp $post_activation_bias), - (ConstantOp $value), TFL_AF_None), + (binaryOp (Arith_ConstantOp $post_activation_bias), + (Arith_ConstantOp $value), TFL_AF_None), $output_threshold, $channels_in, $dilation_height_factor, @@ -79,8 +80,8 @@ multiclass FuseMulOrDivWithBConv2D { (LQ_Bconv2dOp:$conv_output $input, $filter, - (ConstantOp F32ElementsAttr:$post_activation_multiplier), - (ConstantOp 
F32ElementsAttr:$post_activation_bias), + (Arith_ConstantOp F32ElementsAttr:$post_activation_multiplier), + (Arith_ConstantOp F32ElementsAttr:$post_activation_bias), $output_threshold, $channels_in, $dilation_height_factor, @@ -90,14 +91,14 @@ multiclass FuseMulOrDivWithBConv2D { $padding, $stride_height, $stride_width), - (ConstantOp F32ElementsAttr:$value), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$value), TFL_AF_None), (LQ_Bconv2dOp $input, $filter, - (binaryOp (ConstantOp $post_activation_multiplier), - (ConstantOp $value), TFL_AF_None), - (binaryOp (ConstantOp $post_activation_bias), - (ConstantOp $value), TFL_AF_None), + (binaryOp (Arith_ConstantOp $post_activation_multiplier), + (Arith_ConstantOp $value), TFL_AF_None), + (binaryOp (Arith_ConstantOp $post_activation_bias), + (Arith_ConstantOp $value), TFL_AF_None), $output_threshold, $channels_in, $dilation_height_factor, @@ -119,22 +120,22 @@ multiclass FuseActFnIntoConvOpPat { (LQ_Bconv2dOp:$conv_output $input, $filter, - (ConstantOp ConstantValue<"1.0f">:$post_activation_multiplier), - (ConstantOp ConstantValue<"0.0f">:$post_activation_bias), + (Arith_ConstantOp ConstantValue<"1.0f">:$post_activation_multiplier), + (Arith_ConstantOp ConstantValue<"0.0f">:$post_activation_bias), $output_threshold, $channels_in, $dilation_height_factor, $dilation_width_factor, TFL_AF_None, $pad_values, - ConstantAttr:$padding, + TFL_PAD_Valid:$padding, $stride_height, $stride_width)), (LQ_Bconv2dOp $input, $filter, - (ConstantOp $post_activation_multiplier), - (ConstantOp $post_activation_bias), + (Arith_ConstantOp $post_activation_multiplier), + (Arith_ConstantOp $post_activation_bias), $output_threshold, $channels_in, $dilation_height_factor, @@ -149,22 +150,22 @@ multiclass FuseActFnIntoConvOpPat { (LQ_Bconv2dOp:$conv_output $input, $filter, - (ConstantOp ConstantValue<"1.0f">:$post_activation_multiplier), - (ConstantOp ConstantValue<"0.0f">:$post_activation_bias), + (Arith_ConstantOp 
ConstantValue<"1.0f">:$post_activation_multiplier), + (Arith_ConstantOp ConstantValue<"0.0f">:$post_activation_bias), $output_threshold, $channels_in, $dilation_height_factor, $dilation_width_factor, TFL_AF_None, ConstantAttr:$pad_values, - ConstantAttr:$padding, + TFL_PAD_Same:$padding, $stride_height, $stride_width)), (LQ_Bconv2dOp $input, $filter, - (ConstantOp $post_activation_multiplier), - (ConstantOp $post_activation_bias), + (Arith_ConstantOp $post_activation_multiplier), + (Arith_ConstantOp $post_activation_bias), $output_threshold, $channels_in, $dilation_height_factor, diff --git a/larq_compute_engine/mlir/transforms/padding.h b/larq_compute_engine/mlir/transforms/padding.h index 0dfb17948..69a8c7545 100644 --- a/larq_compute_engine/mlir/transforms/padding.h +++ b/larq_compute_engine/mlir/transforms/padding.h @@ -33,8 +33,8 @@ inline ShapeRefType GetShape4D(Value tensor) { inline bool IsSamePadding1D(DenseElementsAttr paddings, uint64_t dimension, int input_size, int output_size, int stride) { using compute_engine::core::CeilDiv; - int pad_before = paddings.getValue({dimension, 0}); - int pad_after = paddings.getValue({dimension, 1}); + int pad_before = paddings.getValues()[{dimension, 0}]; + int pad_after = paddings.getValues()[{dimension, 1}]; const int pad_total = pad_before + pad_after; return (output_size == CeilDiv(input_size, stride)) && (pad_before == (pad_total / 2)) && @@ -42,8 +42,8 @@ inline bool IsSamePadding1D(DenseElementsAttr paddings, uint64_t dimension, } inline bool IsNoPadding(DenseElementsAttr paddings, uint64_t dimension) { - return paddings.getValue({dimension, 0}) == 0 && - paddings.getValue({dimension, 1}) == 0; + return paddings.getValues()[{dimension, 0}] == 0 && + paddings.getValues()[{dimension, 1}] == 0; } } // namespace TFL diff --git a/larq_compute_engine/mlir/transforms/prepare_patterns_common.td b/larq_compute_engine/mlir/transforms/prepare_patterns_common.td index 36b6d4249..fc51e5b6f 100644 --- 
a/larq_compute_engine/mlir/transforms/prepare_patterns_common.td +++ b/larq_compute_engine/mlir/transforms/prepare_patterns_common.td @@ -1,4 +1,5 @@ include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" include "larq_compute_engine/mlir/ir/lce_ops.td" @@ -22,8 +23,8 @@ def CreateTFShapeOp : NativeCodeCall< multiclass QuantDequantPatterns { def : Pat<(SelectOp:$select_op $cond, - (ConstantOp ConstantValue<"1.0f">), - (ConstantOp ConstantValue<"-1.0f">)), + (Arith_ConstantOp ConstantValue<"1.0f">), + (Arith_ConstantOp ConstantValue<"-1.0f">)), (LQ_DequantizeOp (LQ_QuantizeOp (CreateTFBroadcastToOp @@ -36,8 +37,8 @@ multiclass QuantDequantPatterns { [], (addBenefit 100)>; def : Pat<(SelectOp:$select_op $cond, - (ConstantOp ConstantValue<"-1.0f">), - (ConstantOp ConstantValue<"1.0f">)), + (Arith_ConstantOp ConstantValue<"-1.0f">), + (Arith_ConstantOp ConstantValue<"1.0f">)), (LQ_DequantizeOp (LQ_QuantizeOp (CreateTFBroadcastToOp @@ -85,18 +86,18 @@ def BinaryFilter : Constraint>; def GetScaleVector : NativeCodeCall<"GetScaleVector($0)">; def GetNumChannels : NativeCodeCall<"GetNumChannels($_builder, $0)">; def ValidFilterShape : Constraint>; -def IsDataFormatNHWC : ConstantAttr; +def IsDataFormatNHWC : ConstantAttr; def CreateNoneAttrValue : NativeCodeCall<"$_builder.getUnitAttr()">; // All targets support this pattern with "VALID" padding, but only the "arm" // target supports it with "SAME" padding. 
-class PrepareBConvPadValue0Pat : +class PrepareBConvPadValue0Pat : Pat<(TF_Conv2DOp (LQ_DequantizeOp:$dequantized_input $input), - (ConstantOp:$filter_op $filter), + (Arith_ConstantOp:$filter_op $filter), IsIntList1XY1:$strides, $use_cudnn, - ConstantAttr:$padding, + padding_type:$padding, $explicit_padding, IsDataFormatNHWC:$data_format, IsIntList1XY1:$dilations), @@ -104,11 +105,11 @@ class PrepareBConvPadValue0Pat : $input, (TF_TransposeOp (TF_DivOp - (ConstantOp $filter), - (ConstantOp (GetScaleVector $filter))), - (ConstantOp ConstantAttr, "{3, 0, 1, 2}">)), - (ConstantOp (GetScaleVector $filter)), - (ConstantOp (GetConstantVector<"0.0f"> $filter)), + (Arith_ConstantOp $filter), + (Arith_ConstantOp (GetScaleVector $filter))), + (Arith_ConstantOp ConstantAttr, "{3, 0, 1, 2}">)), + (Arith_ConstantOp (GetScaleVector $filter)), + (Arith_ConstantOp (GetConstantVector<"0.0f"> $filter)), (ConstantOp (CreateNoneAttrValue)), (GetNumChannels $dequantized_input), ExtractI32At<1>:$dilations, @@ -121,7 +122,7 @@ class PrepareBConvPadValue0Pat : [(BinaryFilter $filter), (ValidFilterShape $dequantized_input, $filter_op)], (addBenefit 90)>; -def : PrepareBConvPadValue0Pat<"VALID">; +def : PrepareBConvPadValue0Pat; def ConstFloatValueIsOne : Constraint< CPred<"$0.isa() && " @@ -133,30 +134,30 @@ def SamePadding : Constraint>; def : Pat<(TF_Conv2DOp:$output (TF_PadV2Op (LQ_DequantizeOp:$dequantized_input $input), - (ConstantOp $paddings), - (ConstantOp $pad_values)), - (ConstantOp:$filter_op $filter), + (Arith_ConstantOp $paddings), + (Arith_ConstantOp $pad_values)), + (Arith_ConstantOp:$filter_op $filter), IsIntList1XY1:$strides, $use_cudnn, - ConstantAttr, + TFL_PAD_Valid, $explicit_padding, IsDataFormatNHWC:$data_format, IsIntList1XY1:$dilations), (LQ_Bconv2dOp $input, (TF_TransposeOp (TF_DivOp - (ConstantOp $filter), - (ConstantOp (GetScaleVector $filter))), - (ConstantOp ConstantAttr, "{3, 0, 1, 2}">)), - (ConstantOp (GetScaleVector $filter)), - (ConstantOp 
(GetConstantVector<"0.0f"> $filter)), + (Arith_ConstantOp $filter), + (Arith_ConstantOp (GetScaleVector $filter))), + (Arith_ConstantOp ConstantAttr, "{3, 0, 1, 2}">)), + (Arith_ConstantOp (GetScaleVector $filter)), + (Arith_ConstantOp (GetConstantVector<"0.0f"> $filter)), (ConstantOp (CreateNoneAttrValue)), (GetNumChannels $dequantized_input), ExtractI32At<1>:$dilations, ExtractI32At<2>:$dilations, TFL_AF_None, ConstantAttr, - ConstantAttr, + TFL_PAD_Same, ExtractI32At<1>:$strides, ExtractI32At<2>:$strides), [(BinaryFilter $filter), diff --git a/larq_compute_engine/mlir/transforms/prepare_patterns_target_arm.td b/larq_compute_engine/mlir/transforms/prepare_patterns_target_arm.td index 436391c31..c4e28fadf 100644 --- a/larq_compute_engine/mlir/transforms/prepare_patterns_target_arm.td +++ b/larq_compute_engine/mlir/transforms/prepare_patterns_target_arm.td @@ -1,4 +1,4 @@ include "larq_compute_engine/mlir/transforms/prepare_patterns_common.td" // On ARM we support 'same-zero' padding. -def : PrepareBConvPadValue0Pat<"SAME">; +def : PrepareBConvPadValue0Pat; diff --git a/larq_compute_engine/mlir/transforms/prepare_tf.cc b/larq_compute_engine/mlir/transforms/prepare_tf.cc index 907acd56b..635868346 100644 --- a/larq_compute_engine/mlir/transforms/prepare_tf.cc +++ b/larq_compute_engine/mlir/transforms/prepare_tf.cc @@ -20,6 +20,8 @@ using compute_engine::core::bitpacking_bitwidth; // Prepare LCE operations in functions for subsequent legalization. 
struct PrepareLCE : public PassWrapper { + llvm::StringRef getArgument() const final { return "tfl-prepare-lce"; } + llvm::StringRef getDescription() const final { return "Inject LCE Ops"; } PrepareLCE() = default; PrepareLCE(const PrepareLCE& pass) {} PrepareLCE(const LCETarget target) { target_.setValue(target); } @@ -61,7 +63,7 @@ DenseElementsAttr GetScaleVector(Attribute filter_attr) { std::vector scales(channels); for (std::size_t i = 0; i < channels; ++i) { - auto scale = std::abs(filter.getValue({0, 0, 0, i})); + auto scale = std::abs(filter.getValues()[{0, 0, 0, i}]); scales[i] = FloatAttr::get(element_type, scale); } @@ -83,10 +85,10 @@ bool IsBinaryFilter(Attribute filter_attr) { for (std::size_t w = 0; w < shape[1]; ++w) { for (std::size_t i = 0; i < shape[2]; ++i) { for (std::size_t o = 0; o < shape[3]; ++o) { - auto scale = filter.getValue({0, 0, 0, o}); + auto scale = filter.getValues()[{0, 0, 0, o}]; if (std::abs(scale) <= std::numeric_limits::epsilon()) return false; - auto value = filter.getValue({h, w, i, o}); + auto value = filter.getValues()[{h, w, i, o}]; if (std::abs(std::abs(value / scale) - 1.0f) > 0.005f) return false; } } @@ -194,7 +196,7 @@ std::unique_ptr> CreatePrepareLCEPass( return std::make_unique(target); } -static PassRegistration pass("tfl-prepare-lce", "Inject LCE Ops."); +static PassRegistration pass; } // namespace TFL } // namespace mlir diff --git a/larq_compute_engine/mlir/transforms/quantize.cc b/larq_compute_engine/mlir/transforms/quantize.cc index fdb1bf5b1..9f8d3e11a 100644 --- a/larq_compute_engine/mlir/transforms/quantize.cc +++ b/larq_compute_engine/mlir/transforms/quantize.cc @@ -16,6 +16,10 @@ namespace { // Applies quantization on the model in TFL dialect. 
struct LCEQuantizePass : public PassWrapper { + llvm::StringRef getArgument() const final { return "lce-quantize"; } + llvm::StringRef getDescription() const final { + return "Apply hybrid quantization on models in TensorFlow Lite dialect"; + } void runOnFunction() override; }; @@ -34,9 +38,7 @@ std::unique_ptr> CreateLCEQuantizePass() { return std::make_unique(); } -static PassRegistration pass( - "lce-quantize", - "Apply hybrid quantization on models in TensorFlow Lite dialect"); +static PassRegistration pass; } // namespace TFL } // namespace mlir diff --git a/larq_compute_engine/mlir/transforms/quantize_patterns.td b/larq_compute_engine/mlir/transforms/quantize_patterns.td index 40a1b4262..c6e239f42 100644 --- a/larq_compute_engine/mlir/transforms/quantize_patterns.td +++ b/larq_compute_engine/mlir/transforms/quantize_patterns.td @@ -69,3 +69,7 @@ def : Pat<(TFL_QuantizeOp def : Pat<(TFL_QuantizeOp (LQ_DequantizeOp:$output $input), $qtype), (LQ_DequantizeOp $input), [(HasOneUse $output)]>; + +def : Pat<(LQ_QuantizeOp (TFL_DequantizeOp:$output $input)), + (LQ_QuantizeOp $input), + [(HasOneUse $output)]>; diff --git a/larq_compute_engine/mlir/transforms/set_batch_size.cc b/larq_compute_engine/mlir/transforms/set_batch_size.cc index a4d8a5781..7ac56f068 100644 --- a/larq_compute_engine/mlir/transforms/set_batch_size.cc +++ b/larq_compute_engine/mlir/transforms/set_batch_size.cc @@ -23,6 +23,8 @@ mlir::Type SetBatchSize(mlir::Type type) { } struct SetBatchSizePass : public PassWrapper { + llvm::StringRef getArgument() const final { return "mlir-setbatchsize"; } + llvm::StringRef getDescription() const final { return "Set batch size to 1"; } void runOnFunction() override { FuncOp func = getFunction(); @@ -62,7 +64,6 @@ std::unique_ptr> CreateSetBatchSizePass() { return std::make_unique(); } -static PassRegistration pass("mlir-setbatchsize", - "Set batch size to 1"); +static PassRegistration pass; } // namespace mlir diff --git 
a/larq_compute_engine/mlir/transforms/translate_tflite.cc b/larq_compute_engine/mlir/transforms/translate_tflite.cc index 1fe799b77..a3efd4531 100644 --- a/larq_compute_engine/mlir/transforms/translate_tflite.cc +++ b/larq_compute_engine/mlir/transforms/translate_tflite.cc @@ -11,6 +11,10 @@ namespace TFL { namespace { struct TranslateToLCE : public PassWrapper { + llvm::StringRef getArgument() const final { return "lce-translate-tfl"; } + llvm::StringRef getDescription() const final { + return "Translate TFL custom ops to LCE ops"; + } void getDependentDialects(DialectRegistry& registry) const override { registry.insert(); } @@ -76,8 +80,7 @@ std::unique_ptr> CreateTranslateToLCEPass() { return std::make_unique(); } -static PassRegistration pass( - "lce-translate-tfl", "Translate TFL custom ops to LCE ops"); +static PassRegistration pass; } // namespace TFL } // namespace mlir diff --git a/larq_compute_engine/tests/end2end_test.py b/larq_compute_engine/tests/end2end_test.py index ab6af6ba3..c201b98fa 100644 --- a/larq_compute_engine/tests/end2end_test.py +++ b/larq_compute_engine/tests/end2end_test.py @@ -38,7 +38,9 @@ def dummy(x): use_bias=False, activation=activation, )(x) - x = tf.keras.layers.BatchNormalization(momentum=0.7)(x) + x = tf.keras.layers.BatchNormalization( + momentum=0.7, beta_initializer="random_normal" + )(x) return tf.keras.layers.add([x, shortcut]) return dummy @@ -235,10 +237,8 @@ def test_simple_model(dataset, conversion_function, model_cls): def test_int8_input_output( conversion_function, model_cls, inference_input_type, inference_output_type ): - if conversion_function == convert_keras_model_as_saved_model and ( - model_cls == toy_model or version.parse(tf.__version__) < version.parse("2.2") - ): - pytest.skip("convert_keras_model_as_saved_model currently fails in this case.") + if version.parse(tf.__version__) < version.parse("2.2"): + pytest.skip("TensorFlow 2.2 or newer is required for saved model conversion.") model_lce = 
conversion_function( model_cls(), diff --git a/larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.cc b/larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.cc index 6994bb911..852442502 100644 --- a/larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.cc +++ b/larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.cc @@ -71,4 +71,4 @@ TfLiteStatus LceBenchmarkTfLiteModel::Run(int argc, char** argv) { } } // namespace benchmark -} // namespace tflite \ No newline at end of file +} // namespace tflite diff --git a/larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h b/larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h index 34fb4b32c..32c76f580 100644 --- a/larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h +++ b/larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.h @@ -44,4 +44,4 @@ class LceBenchmarkTfLiteModel : public BenchmarkTfLiteModel { } // namespace benchmark } // namespace tflite -#endif // COMPUTE_ENGINE_TFLITE_BENCHMARK_LCE_BENCHMARK_TFLITE_MODEL_H_ \ No newline at end of file +#endif // COMPUTE_ENGINE_TFLITE_BENCHMARK_LCE_BENCHMARK_TFLITE_MODEL_H_ diff --git a/larq_compute_engine/tflite/build_make/Makefile b/larq_compute_engine/tflite/build_make/Makefile deleted file mode 100644 index b844f43f8..000000000 --- a/larq_compute_engine/tflite/build_make/Makefile +++ /dev/null @@ -1,194 +0,0 @@ -# -# This is based on -# tensorflow/tensorflow/lite/tools/make/Makefile -# -# The makefile will always be run from the root of the compute engine repository - -# Make uses /bin/sh by default, which is incompatible with the bashisms seen -# below. 
-SHELL := /bin/bash - -TF_DIR := third_party/tensorflow -TF_MAKEFILE_DIR := $(TF_DIR)/tensorflow/lite/tools/make - -ifeq ($(LCE_GEN_DIR),) -$(error Please set LCE_GEN_DIR to specify an output dir) -endif - -# Try to figure out the host system -HOST_OS := -ifeq ($(OS),Windows_NT) - HOST_OS = windows -else - UNAME_S := $(shell uname -s) - ifeq ($(UNAME_S),Linux) - HOST_OS := linux - endif - ifeq ($(UNAME_S),Darwin) - HOST_OS := osx - endif -endif - -HOST_ARCH := $(shell if uname -m | grep -q i[345678]86; then echo x86_32; else uname -m; fi) - -# Override these on the make command line to target a specific architecture. For example: -# make -f tensorflow/lite/tools/make/Makefile TARGET=rpi TARGET_ARCH=armv7l -TARGET := $(HOST_OS) -TARGET_ARCH := $(HOST_ARCH) - -#LCE: Removed the following includes. It is unclear what they were for. -#-I$(TF_MAKEFILE_DIR)/../../../../../ \ -#-I$(TF_MAKEFILE_DIR)/../../../../../../ \ -#-I$(OBJDIR) - -INCLUDES := \ --Ilarq_compute_engine/tflite/cc \ --I. \ --I$(TF_DIR) \ --I$(TF_MAKEFILE_DIR)/downloads/ \ --I$(TF_MAKEFILE_DIR)/downloads/eigen \ --I$(TF_MAKEFILE_DIR)/downloads/absl \ --I$(TF_MAKEFILE_DIR)/downloads/gemmlowp \ --I$(TF_MAKEFILE_DIR)/downloads/ruy \ --I$(TF_MAKEFILE_DIR)/downloads/neon_2_sse \ --I$(TF_MAKEFILE_DIR)/downloads/farmhash/src \ --I$(TF_MAKEFILE_DIR)/downloads/flatbuffers/include \ --I$(TF_MAKEFILE_DIR)/downloads/fp16/include -# This is at the end so any globally-installed frameworks like protobuf don't -# override local versions in the source tree. -INCLUDES += -I/usr/local/include - -# These are the default libraries needed, but they can be added to or -# overridden by the platform-specific settings in target makefiles. -LIBS := \ --lstdc++ \ --lpthread \ --lm \ --lz \ --ldl - -# There are no rules for compiling objects for the host system (since we don't -# generate things like the protobuf compiler that require that), so all of -# these settings are for the target compiler. 
-CFLAGS := -O3 -DNDEBUG -fPIC $(EXTRA_CFLAGS) -CXXFLAGS := $(CFLAGS) --std=c++14 $(EXTRA_CXXFLAGS) -LDOPTS := -L/usr/local/lib -ARFLAGS := -r -TARGET_TOOLCHAIN_PREFIX := -CC_PREFIX := - -# Added by LCE: -CXXFLAGS += -DTFLITE_WITH_RUY -BUILD_WITH_RUY_PROFILER ?= false -ifeq ($(BUILD_WITH_RUY_PROFILER),true) - CXXFLAGS += -DRUY_PROFILER -endif - -ifeq ($(HOST_OS),windows) -CXXFLAGS += -fext-numeric-literals -D__LITTLE_ENDIAN__ -endif - -# Auto-detect optimization opportunity if building natively. -ifeq ($(HOST_OS),$(TARGET)) -ifeq ($(HOST_ARCH),$(TARGET_ARCH)) -ifeq ($(TARGET_ARCH),armv7l) -ifneq ($(shell cat /proc/cpuinfo | grep Features | grep neon),) - ifneq ($(shell cat /proc/cpuinfo | grep Features | grep vfpv4),) - CXXFLAGS += -mfpu=neon-vfpv4 - else - CXXFLAGS += -mfpu=neon - endif -endif # ifeq ($(TARGET_ARCH),armv7l) -endif # ifeq ($(HOST_ARCH),$(TARGET_ARCH)) -endif # ifeq ($(HOST_OS),$(TARGET)) -endif - -# This library is the main target for this makefile. It will contain a minimal -# runtime that can be linked in to other programs. -CORE_LIB_NAME := libtensorflow-lite.a -BENCHMARK_LIB_NAME := benchmark-lib.a - -# What sources we want to compile, must be kept in sync with the main Bazel -# build files. - -LCE_CORE_SRCS := $(wildcard larq_compute_engine/tflite/kernels/*.cc) - -LCE_EXAMPLE_SRCS := \ - examples/lce_minimal.cc - -LCE_BENCHMARK_SRCS := \ - larq_compute_engine/tflite/benchmark/lce_benchmark_tflite_model.cc \ - larq_compute_engine/tflite/benchmark/lce_benchmark_main.cc - -# These target-specific makefiles should modify or replace options like -# CXXFLAGS or LIBS to work for a specific targeted architecture. All logic -# based on platforms or architectures should happen within these files, to -# keep this main makefile focused on the sources and dependencies. -include $(wildcard $(TF_MAKEFILE_DIR)/targets/*_makefile.inc) - -# Where compiled objects are stored. 
-TARGET_OUT_DIR ?= $(TARGET)_$(TARGET_ARCH) -GENDIR := $(TF_MAKEFILE_DIR)/gen/$(TARGET_OUT_DIR)/ -OBJDIR := $(GENDIR)obj/ -LIBDIR := $(GENDIR)lib/ -BINDIR := $(LCE_GEN_DIR)/$(TARGET_OUT_DIR)/ - -CORE_LIB := $(LIBDIR)$(CORE_LIB_NAME) -BENCHMARK_LIB := $(LIBDIR)$(BENCHMARK_LIB_NAME) -LCE_EXAMPLE_BINARY := $(BINDIR)lce_minimal -LCE_BENCHMARK_BINARY := $(BINDIR)lce_benchmark - -CXX := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}g++ -CC := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}gcc -AR := $(CC_PREFIX)${TARGET_TOOLCHAIN_PREFIX}ar - -LCE_CORE_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(patsubst %.cpp,%.o,$(LCE_CORE_SRCS))))) - -LCE_EXAMPLE_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(LCE_EXAMPLE_SRCS)))) - -LCE_BENCHMARK_OBJS := $(addprefix $(OBJDIR), \ -$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(LCE_BENCHMARK_SRCS)))) - -# The target that's compiled if there's no command-line arguments. -all: $(LCE_EXAMPLE_BINARY) $(LCE_BENCHMARK_BINARY) - -# For normal manually-created TensorFlow Lite C++ source files. -$(OBJDIR)%.o: %.cpp - @mkdir -p $(dir $@) - $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ - -$(OBJDIR)%.o: %.cc - @mkdir -p $(dir $@) - $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ - -# For normal manually-created TensorFlow Lite C source files. 
-$(OBJDIR)%.o: %.c - @mkdir -p $(dir $@) - $(CC) $(CFLAGS) $(INCLUDES) -c $< -o $@ -$(OBJDIR)%.o: %.cpp - @mkdir -p $(dir $@) - $(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@ - -$(LCE_EXAMPLE_BINARY): $(LCE_CORE_OBJS) $(LCE_EXAMPLE_OBJS) $(CORE_LIB) - @mkdir -p $(dir $@) - $(CXX) $(CXXFLAGS) $(INCLUDES) \ - -o $(LCE_EXAMPLE_BINARY) $(LCE_CORE_OBJS) $(LCE_EXAMPLE_OBJS) \ - $(LIBFLAGS) $(CORE_LIB) $(LDFLAGS) $(LIBS) - -$(LCE_BENCHMARK_BINARY): $(LCE_CORE_OBJS) $(LCE_BENCHMARK_OBJS) $(BENCHMARK_LIB) - @mkdir -p $(dir $@) - $(CXX) $(CXXFLAGS) $(INCLUDES) \ - -o $(LCE_BENCHMARK_BINARY) $(LCE_CORE_OBJS) $(LCE_BENCHMARK_OBJS) \ - $(LIBFLAGS) $(BENCHMARK_LIB) $(LDFLAGS) $(LIBS) - -# Gets rid of all generated files. -clean: - rm -rf $(TF_MAKEFILE_DIR)/gen - -# Gets rid of target files only, leaving the host alone. Also leaves the lib -# directory untouched deliberately, so we can persist multiple architectures -# across builds for iOS and Android. -cleantarget: - rm -rf $(OBJDIR) diff --git a/larq_compute_engine/tflite/build_make/build_lce.sh b/larq_compute_engine/tflite/build_make/build_lce.sh deleted file mode 100755 index d60f7ba94..000000000 --- a/larq_compute_engine/tflite/build_make/build_lce.sh +++ /dev/null @@ -1,139 +0,0 @@ -#!/bin/bash -set -e - -usage() -{ - echo "Usage: build_lqce.sh [--native] [--rpi] [--ios] [--aarch64] [--benchmark] [--clean] - ---native Build for the host platform ---rpi Compile for Raspberry Pi (32-bit armv7) ---ios Compile for iOS ---aarch64 Compile for Aarch64 (e.g. 64-bit Raspberry Pi) - -When building on a Raspberry Pi, it is advised to use the --rpi or --aarch64 options instead of --native, in order to set the correct compiler optimization flags. - -For cross-compiling, the relevant toolchains have to be installed using the systems package manager. -The --rpi option requires the arm-linux-gnueabihf toolchain. -The --aarch64 option requires the aarch64-linux-gnu toolchain. -The --ios option requires the iOS SDK. 
- ---benchmark Compile with RUY profiling enabled ---clean Delete intermediate build files - -If doing a benchmark build when you have previously built without --benchmark -then you should pass --clean to do a complete rebuild." -} - - -if [[ $# -eq 0 ]] ; then - usage - exit 0 -fi - -native=0 -rpi=0 -ios=0 -aarch64=0 -benchmark=0 -clean=0 - -while [ "$1" != "" ]; do - case $1 in - -n | --native) - native=1 - ;; - --rpi) - rpi=1 - ;; - --ios) - ios=1 - ;; - --aarch64) - aarch64=1 - ;; - -b | --benchmark) - benchmark=1 - ;; - -c | --clean) - clean=1 - ;; - -h | --help ) - usage - exit - ;; - * ) - usage - exit 1 - ;; - esac - shift -done - - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -ROOT_DIR="${SCRIPT_DIR}/../../.." -TF_DIR="${ROOT_DIR}/third_party/tensorflow" -LCE_MAKEFILE="larq_compute_engine/tflite/build_make/Makefile" -TF_GEN_DIR="${TF_DIR}/tensorflow/lite/tools/make/gen" -export LCE_GEN_DIR="${ROOT_DIR}/gen" # export to pass it to the Makefile - -# number of hyper threads -NUM_HYPERTHREADS=$(grep -c ^processor /proc/cpuinfo 2>/dev/null || sysctl -n hw.ncpu) - -export BUILD_WITH_RUY=true -if [ "$benchmark" == "1" ]; then - export BUILD_WITH_RUY_PROFILER=true -fi - -if [ "$clean" == "1" ]; then - echo " --> clean" - rm -rf ${TF_GEN_DIR} - rm -rf ${LCE_GEN_DIR} -fi - -# Check if dependencies need to be downloaded -if [ ! 
-d "${TF_DIR}/tensorflow/lite/tools/make/downloads" ]; then - ${TF_DIR}/tensorflow/lite/tools/make/download_dependencies.sh -fi - -if [ "$native" == "1" ]; then - echo " --> native build" - # Build the tflite lib (will automatically skip when up-to-date - # this line is taken from "tensorflow/lite/tools/make/build_lib.sh" - make -j ${NUM_HYPERTHREADS} BUILD_WITH_NNAPI=false -C "${TF_DIR}" -f tensorflow/lite/tools/make/Makefile - # Build compute-engine kernels and benchmark binary - make -j ${NUM_HYPERTHREADS} BUILD_WITH_NNAPI=false -C "${ROOT_DIR}" -f ${LCE_MAKEFILE} -fi - -if [ "$rpi" == "1" ]; then - echo " --> rpi build" - # Stored in gen/rpi_armv7l - # This line is taken form "tensorflow/lite/tools/make/build_rpi_lib.sh" - make -j ${NUM_HYPERTHREADS} TARGET=rpi -C "${TF_DIR}" -f tensorflow/lite/tools/make/Makefile - # Build compute-engine kernels and benchmark binary - make -j ${NUM_HYPERTHREADS} TARGET=rpi -C "${ROOT_DIR}" -f ${LCE_MAKEFILE} -fi - -if [ "$ios" == "1" ]; then - echo " --> ios build" - profiling_opt="" - if [ "$benchmark" == "1" ]; then - profiling_opt="-p" - fi - ${TF_DIR}/tensorflow/lite/tools/make/build_ios_universal_lib.sh $profiling_opt - IOS_ARCHS="x86_64 armv7 armv7s arm64" - for arch in $BUILD_ARCHS - do - # Stored in gen/ios_$arch - make -j ${NUM_HYPERTHREADS} TARGET=ios TARGET_ARCH=${arch} -C "${ROOT_DIR}" -f ${LCE_MAKEFILE} - done -fi - -if [ "$aarch64" == "1" ]; then - echo " --> aarch64 build" - # Stored in gen/aarch64_armv8-a - # This line is taken from "tensorflow/lite/tools/make/build_aarch64_lib.sh" - make -j ${NUM_HYPERTHREADS} TARGET=aarch64 -C "${TF_DIR}" -f tensorflow/lite/tools/make/Makefile - # Build compute-engine kernels and benchmark binary - make -j ${NUM_HYPERTHREADS} TARGET=aarch64 -C "${ROOT_DIR}" -f ${LCE_MAKEFILE} -fi diff --git a/larq_compute_engine/tflite/kernels/bconv2d.cc b/larq_compute_engine/tflite/kernels/bconv2d.cc index 179ffb078..4eded66a0 100644 --- a/larq_compute_engine/tflite/kernels/bconv2d.cc +++ 
b/larq_compute_engine/tflite/kernels/bconv2d.cc @@ -276,8 +276,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Allocate the im2col tensor if necessary if (op_data->im2col_id == kTensorNotAllocated) { context->AddTensors(context, 1, &op_data->im2col_id); - node->temporaries->data[op_data->im2col_index] = op_data->im2col_id; } + node->temporaries->data[op_data->im2col_index] = op_data->im2col_id; // Resize the im2col tensor const std::int32_t bitpacked_channels_in = diff --git a/larq_compute_engine/tflite/tests/interpreter_test.py b/larq_compute_engine/tflite/tests/interpreter_test.py index 5c6b4e354..848d7f937 100644 --- a/larq_compute_engine/tflite/tests/interpreter_test.py +++ b/larq_compute_engine/tflite/tests/interpreter_test.py @@ -49,15 +49,18 @@ def test_interpreter_multi_input(use_iterator): interpreter = Interpreter(converter.convert(), num_threads=2) assert interpreter.input_types == [np.float32, np.float32] assert interpreter.output_types == [np.float32, np.float32] - assert interpreter.input_shapes == [(1, 24, 24, 2), (1, 24, 24, 1)] - assert interpreter.output_shapes == [(1, 24 * 24 * 2), (1, 24 * 24 * 1)] + assert interpreter.input_shapes == [(1, 24, 24, 1), (1, 24, 24, 2)] + assert sorted(interpreter.output_shapes) == [(1, 24 * 24 * 1), (1, 24 * 24 * 2)] def input_fn(): if use_iterator: - return ([x, y] for x, y in zip(x_np, y_np)) - return [x_np, y_np] + return ([y, x] for x, y in zip(x_np, y_np)) + return [y_np, x_np] output_x, output_y = interpreter.predict(input_fn()) + # Output order is not deterministic, decide based on shape + if output_y.shape == expected_output_x.shape: + output_x, output_y = output_y, output_x np.testing.assert_allclose(output_x, expected_output_x) np.testing.assert_allclose(output_y, expected_output_y) diff --git a/third_party/install_android.sh b/third_party/install_android.sh index f6d50753a..f7bf05575 100755 --- a/third_party/install_android.sh +++ b/third_party/install_android.sh @@ -1,11 +1,14 
@@ #!/usr/bin/env bash set -e +# **NOTE**: This requires Java 8 and won't work on newer versions. See: +# https://stackoverflow.com/questions/46402772/failed-to-install-android-sdk-java-lang-noclassdeffounderror-javax-xml-bind-a + # default LCE Android Env. variables export ANDROID_SDK_URL="https://dl.google.com/android/repository/sdk-tools-linux-3859397.zip" export ANDROID_HOME="/tmp/lce_android" export ANDROID_VERSION=29 -export ANDROID_BUILD_TOOLS_VERSION=28.0.3 +export ANDROID_BUILD_TOOLS_VERSION=30.0.2 export ANDROID_NDK_VERSION=19.2.5345600 # download android SDK diff --git a/third_party/tensorflow b/third_party/tensorflow index 3aa40c3ce..3f878cff5 160000 --- a/third_party/tensorflow +++ b/third_party/tensorflow @@ -1 +1 @@ -Subproject commit 3aa40c3ce9d16eae296f086bc4ac4d62deb2affc +Subproject commit 3f878cff5b698b82eea85db2b60d65a2e320850e diff --git a/third_party/tensorflow_patches/fix_armhf_xnnpack.patch b/third_party/tensorflow_patches/fix_armhf_xnnpack.patch new file mode 100644 index 000000000..1c382716f --- /dev/null +++ b/third_party/tensorflow_patches/fix_armhf_xnnpack.patch @@ -0,0 +1,24 @@ +diff --git a/.bazelrc b/.bazelrc +index b1d4eee905d..ada1395298f 100644 +--- a/.bazelrc ++++ b/.bazelrc +@@ -572,6 +572,7 @@ build:elinux_aarch64 --distinct_host_configuration=true + build:elinux_armhf --config=elinux + build:elinux_armhf --cpu=armhf + build:elinux_armhf --distinct_host_configuration=true ++build:elinux_armhf --copt -mfp16-format=ieee + # END TF REMOTE BUILD EXECUTION OPTIONS + + # Config-specific options should come above this line.
+diff --git a/tensorflow/lite/tools/evaluation/BUILD b/tensorflow/lite/tools/evaluation/BUILD +index 86bfa35bd1e..e22e9efd840 100644 +--- a/tensorflow/lite/tools/evaluation/BUILD ++++ b/tensorflow/lite/tools/evaluation/BUILD +@@ -53,7 +53,6 @@ cc_library( + ], + "//conditions:default": [], + }) + select({ +- "//tensorflow:linux_armhf": [], + "//tensorflow:linux_s390x": [], + "//conditions:default": [ + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", diff --git a/third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11/BUILD b/third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11/BUILD new file mode 100644 index 000000000..44172e9f5 --- /dev/null +++ b/third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11/BUILD @@ -0,0 +1,118 @@ +# This file is expanded from a template by cuda_configure.bzl +# Update cuda_configure.bzl#verify_build_defines when adding new variables. + +load(":cc_toolchain_config.bzl", "cc_toolchain_config") + +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +toolchain( + name = "toolchain-linux-x86_64", + exec_compatible_with = [ + "@bazel_tools//platforms:linux", + "@bazel_tools//platforms:x86_64", + ], + target_compatible_with = [ + "@bazel_tools//platforms:linux", + "@bazel_tools//platforms:x86_64", + ], + toolchain = ":cc-compiler-local", + toolchain_type = "@bazel_tools//tools/cpp:toolchain_type", +) + +cc_toolchain_suite( + name = "toolchain", + toolchains = { + "local|compiler": ":cc-compiler-local", + "darwin|compiler": ":cc-compiler-darwin", + "k8": ":cc-compiler-local", + "darwin": ":cc-compiler-darwin", + }, +) + +cc_toolchain( + name = "cc-compiler-local", + all_files = ":crosstool_wrapper_driver_is_not_gcc", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":crosstool_wrapper_driver_is_not_gcc", + objcopy_files = ":empty", + strip_files = ":empty", + # To support linker flags that need to go to the start of command line + # we need the toolchain to support parameter files. 
Parameter files are + # last on the command line and contain all shared libraries to link, so all + # regular options will be left of them. + supports_param_files = 1, + toolchain_config = ":cc-compiler-local-config", + toolchain_identifier = "local_linux", +) + +cc_toolchain_config( + name = "cc-compiler-local-config", + builtin_include_directories = [ + "/dt7/usr/include/c++/7", + "/dt7/usr/include/c++/7/x86_64-pc-linux-gnu", + "/dt7/usr/include/c++/7/backward", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include-fixed", + "/dt7/usr/include", + "/usr/local/cuda-11.2/targets/x86_64-linux/include", + "/usr/local/cuda-11.2/include", + "/usr/local/cuda-11.2/extras/CUPTI/include", + "/usr/include", + ], + cpu = "local", + extra_no_canonical_prefixes_flags = ["-fno-canonical-system-headers"], + host_compiler_path = "clang/bin/crosstool_wrapper_driver_is_not_gcc", + host_compiler_prefix = "/usr/bin", + host_compiler_warnings = [], + host_unfiltered_compile_flags = [], + linker_bin_path = "/usr/bin", +) + +cc_toolchain( + name = "cc-compiler-darwin", + all_files = ":crosstool_wrapper_driver_is_not_gcc", + compiler_files = ":empty", + dwp_files = ":empty", + linker_files = ":crosstool_wrapper_driver_is_not_gcc", + objcopy_files = ":empty", + strip_files = ":empty", + supports_param_files = 0, + toolchain_config = ":cc-compiler-local-darwin", + toolchain_identifier = "local_darwin", +) + +cc_toolchain_config( + name = "cc-compiler-local-darwin", + builtin_include_directories = [ + "/dt7/usr/include/c++/7", + "/dt7/usr/include/c++/7/x86_64-pc-linux-gnu", + "/dt7/usr/include/c++/7/backward", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include", + "/dt7/usr/lib/gcc/x86_64-pc-linux-gnu/7/include-fixed", + "/dt7/usr/include", + "/usr/local/cuda-11.2/targets/x86_64-linux/include", + "/usr/local/cuda-11.2/include", + "/usr/local/cuda-11.2/extras/CUPTI/include", + "/usr/include", + ], + cpu = "darwin", + 
extra_no_canonical_prefixes_flags = ["-fno-canonical-system-headers"], + host_compiler_path = "clang/bin/crosstool_wrapper_driver_is_not_gcc", + host_compiler_prefix = "/usr/bin", + host_compiler_warnings = [], + host_unfiltered_compile_flags = [], + linker_bin_path = "/usr/bin", +) + +filegroup( + name = "empty", + srcs = [], +) + +filegroup( + name = "crosstool_wrapper_driver_is_not_gcc", + srcs = ["clang/bin/crosstool_wrapper_driver_is_not_gcc"], +) diff --git a/third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11/cc_toolchain_config.bzl b/third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11/cc_toolchain_config.bzl new file mode 100644 index 000000000..ba002b454 --- /dev/null +++ b/third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11/cc_toolchain_config.bzl @@ -0,0 +1,1493 @@ +"""cc_toolchain_config rule for configuring CUDA toolchains on Linux, Mac, and Windows.""" + +load( + "@bazel_tools//tools/cpp:cc_toolchain_config_lib.bzl", + "action_config", + "env_entry", + "env_set", + "feature", + "feature_set", + "flag_group", + "flag_set", + "tool", + "tool_path", + "variable_with_value", +) +load( + "@bazel_tools//tools/build_defs/cc:action_names.bzl", + "ASSEMBLE_ACTION_NAME", + "CC_FLAGS_MAKE_VARIABLE_ACTION_NAME", + "CLIF_MATCH_ACTION_NAME", + "CPP_COMPILE_ACTION_NAME", + "CPP_HEADER_PARSING_ACTION_NAME", + "CPP_LINK_DYNAMIC_LIBRARY_ACTION_NAME", + "CPP_LINK_EXECUTABLE_ACTION_NAME", + "CPP_LINK_NODEPS_DYNAMIC_LIBRARY_ACTION_NAME", + "CPP_LINK_STATIC_LIBRARY_ACTION_NAME", + "CPP_MODULE_CODEGEN_ACTION_NAME", + "CPP_MODULE_COMPILE_ACTION_NAME", + "C_COMPILE_ACTION_NAME", + "LINKSTAMP_COMPILE_ACTION_NAME", + "LTO_BACKEND_ACTION_NAME", + "LTO_INDEXING_ACTION_NAME", + "OBJCPP_COMPILE_ACTION_NAME", + "OBJCPP_EXECUTABLE_ACTION_NAME", + "OBJC_ARCHIVE_ACTION_NAME", + "OBJC_COMPILE_ACTION_NAME", + "OBJC_EXECUTABLE_ACTION_NAME", + "OBJC_FULLY_LINK_ACTION_NAME", + "PREPROCESS_ASSEMBLE_ACTION_NAME", + "STRIP_ACTION_NAME", +) + +ACTION_NAMES = struct( + assemble = 
ASSEMBLE_ACTION_NAME, + c_compile = C_COMPILE_ACTION_NAME, + cc_flags_make_variable = CC_FLAGS_MAKE_VARIABLE_ACTION_NAME, + clif_match = CLIF_MATCH_ACTION_NAME, + cpp_compile = CPP_COMPILE_ACTION_NAME, + cpp_header_parsing = CPP_HEADER_PARSING_ACTION_NAME, + cpp_link_dynamic_library = CPP_LINK_DYNAMIC_LIBRARY_ACTION_NAME, + cpp_link_executable = CPP_LINK_EXECUTABLE_ACTION_NAME, + cpp_link_nodeps_dynamic_library = CPP_LINK_NODEPS_DYNAMIC_LIBRARY_ACTION_NAME, + cpp_link_static_library = CPP_LINK_STATIC_LIBRARY_ACTION_NAME, + cpp_module_codegen = CPP_MODULE_CODEGEN_ACTION_NAME, + cpp_module_compile = CPP_MODULE_COMPILE_ACTION_NAME, + ld_embed_data = "ld_embed_data", + linkstamp_compile = LINKSTAMP_COMPILE_ACTION_NAME, + lto_backend = LTO_BACKEND_ACTION_NAME, + lto_indexing = LTO_INDEXING_ACTION_NAME, + objc_archive = OBJC_ARCHIVE_ACTION_NAME, + objc_compile = OBJC_COMPILE_ACTION_NAME, + objc_executable = OBJC_EXECUTABLE_ACTION_NAME, + objc_fully_link = OBJC_FULLY_LINK_ACTION_NAME, + objcopy_embed_data = "objcopy_embed_data", + objcpp_compile = OBJCPP_COMPILE_ACTION_NAME, + objcpp_executable = OBJCPP_EXECUTABLE_ACTION_NAME, + preprocess_assemble = PREPROCESS_ASSEMBLE_ACTION_NAME, + strip = STRIP_ACTION_NAME, +) + +def _impl(ctx): + if (ctx.attr.cpu == "darwin"): + toolchain_identifier = "local_darwin" + elif (ctx.attr.cpu == "local"): + toolchain_identifier = "local_linux" + elif (ctx.attr.cpu == "x64_windows"): + toolchain_identifier = "local_windows" + else: + fail("Unreachable") + + host_system_name = "local" + + target_system_name = "local" + + if (ctx.attr.cpu == "darwin"): + target_cpu = "darwin" + elif (ctx.attr.cpu == "local"): + target_cpu = "local" + elif (ctx.attr.cpu == "x64_windows"): + target_cpu = "x64_windows" + else: + fail("Unreachable") + + if (ctx.attr.cpu == "local"): + target_libc = "local" + elif (ctx.attr.cpu == "darwin"): + target_libc = "macosx" + elif (ctx.attr.cpu == "x64_windows"): + target_libc = "msvcrt" + else: + fail("Unreachable") + + 
if (ctx.attr.cpu == "darwin" or + ctx.attr.cpu == "local"): + compiler = "compiler" + elif (ctx.attr.cpu == "x64_windows"): + compiler = "msvc-cl" + else: + fail("Unreachable") + + abi_version = "local" + + abi_libc_version = "local" + + cc_target_os = None + + builtin_sysroot = None + + all_link_actions = [ + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ] + + cpp_link_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_nodeps_dynamic_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_nodeps_dynamic_library, + implies = [ + "nologo", + "shared_flag", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + "has_configured_linker_path", + "def_file", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + cpp_link_static_library_action = action_config( + action_name = ACTION_NAMES.cpp_link_static_library, + implies = [ + "nologo", + "archiver_flags", + "input_param_flags", + "linker_param_file", + "msvc_env", + ], + tools = [tool(path = ctx.attr.msvc_lib_path)], + ) + + assemble_action = action_config( + action_name = ACTION_NAMES.assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + preprocess_assemble_action = action_config( + action_name = ACTION_NAMES.preprocess_assemble, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + 
"msvc_env", + "sysroot", + ], + tools = [tool(path = ctx.attr.msvc_ml_path)], + ) + + c_compile_action = action_config( + action_name = ACTION_NAMES.c_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "parse_showincludes", + "user_compile_flags", + "sysroot", + "unfiltered_compile_flags", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_compile_action = action_config( + action_name = ACTION_NAMES.cpp_compile, + implies = [ + "compiler_input_flags", + "compiler_output_flags", + "nologo", + "msvc_env", + "parse_showincludes", + "user_compile_flags", + "sysroot", + "unfiltered_compile_flags", + ], + tools = [tool(path = ctx.attr.msvc_cl_path)], + ) + + cpp_link_executable_action = action_config( + action_name = ACTION_NAMES.cpp_link_executable, + implies = [ + "nologo", + "linkstamps", + "output_execpath_flags", + "input_param_flags", + "user_link_flags", + "linker_subsystem_flag", + "linker_param_file", + "msvc_env", + "no_stripping", + ], + tools = [tool(path = ctx.attr.msvc_link_path)], + ) + + if (ctx.attr.cpu == "darwin" or + ctx.attr.cpu == "local"): + action_configs = [] + elif (ctx.attr.cpu == "x64_windows"): + action_configs = [ + assemble_action, + preprocess_assemble_action, + c_compile_action, + cpp_compile_action, + cpp_link_executable_action, + cpp_link_dynamic_library_action, + cpp_link_nodeps_dynamic_library_action, + cpp_link_static_library_action, + ] + else: + fail("Unreachable") + + no_windows_export_all_symbols_feature = feature(name = "no_windows_export_all_symbols") + + pic_feature = feature( + name = "pic", + enabled = True, + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group(flags = ["-fPIC"], expand_if_available = "pic"), + flag_group( + flags = ["-fPIE"], + expand_if_not_available = "pic", + ), + ], + ), + ], + ) + + preprocessor_defines_feature = feature( + name = "preprocessor_defines", + enabled = 
True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/D%{preprocessor_defines}"], + iterate_over = "preprocessor_defines", + ), + ], + ), + ], + ) + + generate_pdb_file_feature = feature( + name = "generate_pdb_file", + requires = [ + feature_set(features = ["dbg"]), + feature_set(features = ["fastbuild"]), + ], + ) + + linkstamps_feature = feature( + name = "linkstamps", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{linkstamp_paths}"], + iterate_over = "linkstamp_paths", + expand_if_available = "linkstamp_paths", + ), + ], + ), + ], + ) + + unfiltered_compile_flags_feature = feature( + name = "unfiltered_compile_flags", + flag_sets = ([ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ctx.attr.host_unfiltered_compile_flags, + ), + ], + ), + ] if ctx.attr.host_unfiltered_compile_flags else []), + ) + + determinism_feature = feature( + name = "determinism", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-Wno-builtin-macro-redefined", + "-D__DATE__=\"redacted\"", + "-D__TIMESTAMP__=\"redacted\"", + "-D__TIME__=\"redacted\"", + ], + ), + ], + ), + ], + ) + + nologo_feature = feature( + name = "nologo", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + 
ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + flag_groups = [flag_group(flags = ["/nologo"])], + ), + ], + ) + + supports_pic_feature = feature(name = "supports_pic", enabled = True) + + output_execpath_flags_feature = feature( + name = "output_execpath_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + default_link_flags_feature = feature( + name = "default_link_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/MACHINE:X64"])], + ), + ], + ) + + if (ctx.attr.cpu == "local"): + hardening_feature = feature( + name = "hardening", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-U_FORTIFY_SOURCE", + "-D_FORTIFY_SOURCE=1", + "-fstack-protector", + ], + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["-Wl,-z,relro,-z,now"])], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_executable], + flag_groups = [flag_group(flags = ["-pie", "-Wl,-z,relro,-z,now"])], + ), + ], + ) + elif (ctx.attr.cpu == "darwin"): + hardening_feature = feature( + name = "hardening", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = [ + "-U_FORTIFY_SOURCE", + "-D_FORTIFY_SOURCE=1", + "-fstack-protector", + ], + ), + ], + ), + flag_set( + actions = [ACTION_NAMES.cpp_link_executable], + flag_groups = [flag_group(flags = ["-pie"])], + ), + ], + ) + else: + hardening_feature = None + + supports_dynamic_linker_feature = feature(name = "supports_dynamic_linker", 
enabled = True) + + targets_windows_feature = feature( + name = "targets_windows", + enabled = True, + implies = ["copy_dynamic_libraries_to_binary"], + ) + + msvc_env_feature = feature( + name = "msvc_env", + env_sets = [ + env_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_static_library, + ], + env_entries = [ + env_entry(key = "PATH", value = ctx.attr.msvc_env_path), + env_entry( + key = "INCLUDE", + value = ctx.attr.msvc_env_include, + ), + env_entry(key = "LIB", value = ctx.attr.msvc_env_lib), + env_entry(key = "TMP", value = ctx.attr.msvc_env_tmp), + env_entry(key = "TEMP", value = ctx.attr.msvc_env_tmp), + ], + ), + ], + ) + + linker_subsystem_flag_feature = feature( + name = "linker_subsystem_flag", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/SUBSYSTEM:CONSOLE"])], + ), + ], + ) + + dynamic_link_msvcrt_no_debug_feature = feature( + name = "dynamic_link_msvcrt_no_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MD"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrt.lib"])], + ), + ], + requires = [ + feature_set(features = ["fastbuild"]), + feature_set(features = ["opt"]), + ], + ) + + warnings_feature = feature( + name = "warnings", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = ["-Wall"] + ctx.attr.host_compiler_warnings, + ), + ], + ), + ], + ) + + dynamic_link_msvcrt_debug_feature = feature( + name = "dynamic_link_msvcrt_debug", + 
flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MDd"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:msvcrtd.lib"])], + ), + ], + requires = [feature_set(features = ["dbg"])], + ) + + compiler_output_flags_feature = feature( + name = "compiler_output_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.assemble], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}", "/Zi"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fo%{output_file}"], + expand_if_not_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + expand_if_not_available = "output_assembly_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/Fa%{output_file}"], + expand_if_available = "output_assembly_file", + ), + ], + expand_if_available = "output_file", + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["/P", "/Fi%{output_file}"], + expand_if_available = "output_preprocess_file", + ), + ], + expand_if_available = "output_file", + ), + ], + ), + ], + ) + + default_compile_flags_feature = feature( + name = "default_compile_flags", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.linkstamp_compile, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + 
ACTION_NAMES.lto_backend, + ACTION_NAMES.clif_match, + ], + flag_groups = [ + flag_group( + flags = [ + "/DCOMPILER_MSVC", + "/DNOMINMAX", + "/D_WIN32_WINNT=0x0600", + "/D_CRT_SECURE_NO_DEPRECATE", + "/D_CRT_SECURE_NO_WARNINGS", + "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS", + "/bigobj", + "/Zm500", + "/J", + "/Gy", + "/GF", + "/EHsc", + "/wd4351", + "/wd4291", + "/wd4250", + "/wd4996", + ], + ), + ], + ), + ], + ) + + static_link_msvcrt_debug_feature = feature( + name = "static_link_msvcrt_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MTd"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmtd.lib"])], + ), + ], + requires = [feature_set(features = ["dbg"])], + ) + + static_link_msvcrt_feature = feature(name = "static_link_msvcrt") + + if (ctx.attr.cpu == "darwin" or + ctx.attr.cpu == "local"): + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-g"])], + ), + ], + implies = ["common"], + ) + elif (ctx.attr.cpu == "x64_windows"): + dbg_feature = feature( + name = "dbg", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEBUG:FULL", "/INCREMENTAL:NO"])], + ), + ], + implies = ["generate_pdb_file"], + ) + else: + dbg_feature = None + + undefined_dynamic_feature = feature( + name = "undefined-dynamic", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_executable, + ], + flag_groups = [flag_group(flags = ["-undefined", "dynamic_lookup"])], + ), + ], + ) + + parse_showincludes_feature = feature( + name = 
"parse_showincludes", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_header_parsing, + ], + flag_groups = [flag_group(flags = ["/showIncludes"])], + ), + ], + ) + + linker_param_file_feature = feature( + name = "linker_param_file", + flag_sets = [ + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["@%{linker_param_file}"], + expand_if_available = "linker_param_file", + ), + ], + ), + ], + ) + + static_link_msvcrt_no_debug_feature = feature( + name = "static_link_msvcrt_no_debug", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/MT"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["/DEFAULTLIB:libcmt.lib"])], + ), + ], + requires = [ + feature_set(features = ["fastbuild"]), + feature_set(features = ["opt"]), + ], + ) + + supports_interface_shared_libraries_feature = feature( + name = "supports_interface_shared_libraries", + enabled = True, + ) + + disable_assertions_feature = feature( + name = "disable-assertions", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-DNDEBUG"])], + ), + ], + ) + + if (ctx.attr.cpu == "x64_windows"): + fastbuild_feature = feature( + name = "fastbuild", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/Od", "/Z7", "/DDEBUG"])], + ), + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group(flags = ["/DEBUG:FASTLINK", "/INCREMENTAL:NO"]), + ], + ), + ], + implies = ["generate_pdb_file"], + ) + elif (ctx.attr.cpu == "darwin" or + ctx.attr.cpu == "local"): + fastbuild_feature = feature(name = "fastbuild", implies = ["common"]) + else: + 
fastbuild_feature = None + + user_compile_flags_feature = feature( + name = "user_compile_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["%{user_compile_flags}"], + iterate_over = "user_compile_flags", + expand_if_available = "user_compile_flags", + ), + ], + ), + ], + ) + + compiler_input_flags_feature = feature( + name = "compiler_input_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ], + flag_groups = [ + flag_group( + flags = ["/c", "%{source_file}"], + expand_if_available = "source_file", + ), + ], + ), + ], + ) + + no_legacy_features_feature = feature(name = "no_legacy_features") + + archiver_flags_feature = feature( + name = "archiver_flags", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + flags = ["/OUT:%{output_execpath}"], + expand_if_available = "output_execpath", + ), + ], + ), + ], + ) + + redirector_feature = feature( + name = "redirector", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ], + flag_groups = [ + flag_group( + flags = [ + "-B", + "external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py", + ], + ), + ], + ), + ], + ) + + linker_bin_path_feature = feature( + name = "linker-bin-path", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = 
[flag_group(flags = ["-B" + ctx.attr.linker_bin_path])], + ), + ], + ) + + if (ctx.attr.cpu == "local"): + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = ["-g0", "-O2", "-ffunction-sections", "-fdata-sections"], + ), + ], + ), + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_executable, + ], + flag_groups = [flag_group(flags = ["-Wl,--gc-sections"])], + ), + ], + implies = ["common", "disable-assertions"], + ) + elif (ctx.attr.cpu == "darwin"): + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [ + flag_group( + flags = ["-g0", "-O2", "-ffunction-sections", "-fdata-sections"], + ), + ], + ), + ], + implies = ["common", "disable-assertions"], + ) + elif (ctx.attr.cpu == "x64_windows"): + opt_feature = feature( + name = "opt", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["/O2", "/DNDEBUG"])], + ), + ], + ) + else: + opt_feature = None + + include_paths_feature = feature( + name = "include_paths", + enabled = True, + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ], + flag_groups = [ + flag_group( + flags = ["/I%{quote_include_paths}"], + iterate_over = "quote_include_paths", + ), + flag_group( + flags = ["/I%{include_paths}"], + iterate_over = "include_paths", + ), + flag_group( + flags = ["/I%{system_include_paths}"], + iterate_over = "system_include_paths", + ), + ], + ), + ], + ) + + shared_flag_feature = feature( + name = "shared_flag", + flag_sets = [ + flag_set( + actions = [ + 
ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["/DLL"])], + ), + ], + ) + + windows_export_all_symbols_feature = feature(name = "windows_export_all_symbols") + + frame_pointer_feature = feature( + name = "frame-pointer", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.c_compile, ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-fno-omit-frame-pointer"])], + ), + ], + ) + + build_id_feature = feature( + name = "build-id", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["-Wl,--build-id=md5", "-Wl,--hash-style=gnu"], + ), + ], + ), + ], + ) + + sysroot_feature = feature( + name = "sysroot", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.assemble, + ACTION_NAMES.preprocess_assemble, + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_header_parsing, + ACTION_NAMES.cpp_module_compile, + ACTION_NAMES.cpp_module_codegen, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["--sysroot=%{sysroot}"], + iterate_over = "sysroot", + expand_if_available = "sysroot", + ), + ], + ), + ], + ) + + def_file_feature = feature( + name = "def_file", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["/DEF:%{def_file_path}", "/ignore:4070"], + expand_if_available = "def_file_path", + ), + ], + ), + ], + ) + + if (ctx.attr.cpu == "darwin"): + stdlib_feature = feature( + name = "stdlib", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-lc++"])], + ), + ], + ) + elif (ctx.attr.cpu == "local"): + stdlib_feature = feature( + name = "stdlib", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [flag_group(flags = ["-lstdc++"])], + ), + ], + ) + else: + stdlib_feature = None 
+ + no_stripping_feature = feature(name = "no_stripping") + + alwayslink_feature = feature( + name = "alwayslink", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ACTION_NAMES.cpp_link_executable, + ], + flag_groups = [flag_group(flags = ["-Wl,-no-as-needed"])], + ), + ], + ) + + input_param_flags_feature = feature( + name = "input_param_flags", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = ["/IMPLIB:%{interface_library_output_path}"], + expand_if_available = "interface_library_output_path", + ), + ], + ), + flag_set( + actions = all_link_actions + + [ACTION_NAMES.cpp_link_static_library], + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link", + flag_groups = [ + flag_group( + iterate_over = "libraries_to_link.object_files", + flag_groups = [flag_group(flags = ["%{libraries_to_link.object_files}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file_group", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "object_file", + ), + ), + flag_group( + flag_groups = [flag_group(flags = ["%{libraries_to_link.name}"])], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "interface_library", + ), + ), + flag_group( + flag_groups = [ + flag_group( + flags = ["%{libraries_to_link.name}"], + expand_if_false = "libraries_to_link.is_whole_archive", + ), + flag_group( + flags = ["/WHOLEARCHIVE:%{libraries_to_link.name}"], + expand_if_true = "libraries_to_link.is_whole_archive", + ), + ], + expand_if_equal = variable_with_value( + name = "libraries_to_link.type", + value = "static_library", + ), + ), + ], + expand_if_available = 
"libraries_to_link", + ), + ], + ), + ], + ) + + if (ctx.attr.cpu == "local"): + no_canonical_prefixes_feature = feature( + name = "no-canonical-prefixes", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [ + flag_group( + flags = [ + "-no-canonical-prefixes", + ] + ctx.attr.extra_no_canonical_prefixes_flags, + ), + ], + ), + ], + ) + elif (ctx.attr.cpu == "darwin"): + no_canonical_prefixes_feature = feature( + name = "no-canonical-prefixes", + flag_sets = [ + flag_set( + actions = [ + ACTION_NAMES.c_compile, + ACTION_NAMES.cpp_compile, + ACTION_NAMES.cpp_link_executable, + ACTION_NAMES.cpp_link_dynamic_library, + ACTION_NAMES.cpp_link_nodeps_dynamic_library, + ], + flag_groups = [flag_group(flags = ["-no-canonical-prefixes"])], + ), + ], + ) + else: + no_canonical_prefixes_feature = None + + has_configured_linker_path_feature = feature(name = "has_configured_linker_path") + + copy_dynamic_libraries_to_binary_feature = feature(name = "copy_dynamic_libraries_to_binary") + + user_link_flags_feature = feature( + name = "user_link_flags", + flag_sets = [ + flag_set( + actions = all_link_actions, + flag_groups = [ + flag_group( + flags = ["%{user_link_flags}"], + iterate_over = "user_link_flags", + expand_if_available = "user_link_flags", + ), + ], + ), + ], + ) + + cpp11_feature = feature( + name = "c++11", + flag_sets = [ + flag_set( + actions = [ACTION_NAMES.cpp_compile], + flag_groups = [flag_group(flags = ["-std=c++11"])], + ), + ], + ) + + if (ctx.attr.cpu == "local"): + common_feature = feature( + name = "common", + implies = [ + "stdlib", + "c++11", + "determinism", + "alwayslink", + "hardening", + "warnings", + "frame-pointer", + "build-id", + "no-canonical-prefixes", + "linker-bin-path", + ], + ) + elif (ctx.attr.cpu == "darwin"): + common_feature = feature( + name = 
"common", + implies = [ + "stdlib", + "c++11", + "determinism", + "hardening", + "warnings", + "frame-pointer", + "no-canonical-prefixes", + "linker-bin-path", + "undefined-dynamic", + ], + ) + else: + common_feature = None + + if (ctx.attr.cpu == "local"): + features = [ + cpp11_feature, + stdlib_feature, + determinism_feature, + alwayslink_feature, + pic_feature, + hardening_feature, + warnings_feature, + frame_pointer_feature, + build_id_feature, + no_canonical_prefixes_feature, + disable_assertions_feature, + linker_bin_path_feature, + common_feature, + opt_feature, + fastbuild_feature, + dbg_feature, + supports_dynamic_linker_feature, + supports_pic_feature, + ] + elif (ctx.attr.cpu == "darwin"): + features = [ + cpp11_feature, + stdlib_feature, + determinism_feature, + pic_feature, + hardening_feature, + warnings_feature, + frame_pointer_feature, + no_canonical_prefixes_feature, + disable_assertions_feature, + linker_bin_path_feature, + undefined_dynamic_feature, + common_feature, + opt_feature, + fastbuild_feature, + dbg_feature, + supports_dynamic_linker_feature, + supports_pic_feature, + ] + elif (ctx.attr.cpu == "x64_windows"): + features = [ + no_legacy_features_feature, + redirector_feature, + nologo_feature, + has_configured_linker_path_feature, + no_stripping_feature, + targets_windows_feature, + copy_dynamic_libraries_to_binary_feature, + default_compile_flags_feature, + msvc_env_feature, + include_paths_feature, + preprocessor_defines_feature, + parse_showincludes_feature, + generate_pdb_file_feature, + shared_flag_feature, + linkstamps_feature, + output_execpath_flags_feature, + archiver_flags_feature, + input_param_flags_feature, + linker_subsystem_flag_feature, + user_link_flags_feature, + default_link_flags_feature, + linker_param_file_feature, + static_link_msvcrt_feature, + static_link_msvcrt_no_debug_feature, + dynamic_link_msvcrt_no_debug_feature, + static_link_msvcrt_debug_feature, + dynamic_link_msvcrt_debug_feature, + dbg_feature, + 
fastbuild_feature, + opt_feature, + user_compile_flags_feature, + sysroot_feature, + unfiltered_compile_flags_feature, + compiler_output_flags_feature, + compiler_input_flags_feature, + def_file_feature, + windows_export_all_symbols_feature, + no_windows_export_all_symbols_feature, + supports_dynamic_linker_feature, + supports_interface_shared_libraries_feature, + ] + else: + fail("Unreachable") + + cxx_builtin_include_directories = ctx.attr.builtin_include_directories + + if (ctx.attr.cpu == "x64_windows"): + tool_paths = [ + tool_path(name = "ar", path = ctx.attr.msvc_lib_path), + tool_path(name = "ml", path = ctx.attr.msvc_ml_path), + tool_path(name = "cpp", path = ctx.attr.msvc_cl_path), + tool_path(name = "gcc", path = ctx.attr.msvc_cl_path), + tool_path(name = "gcov", path = "wrapper/bin/msvc_nop.bat"), + tool_path(name = "ld", path = ctx.attr.msvc_link_path), + tool_path(name = "nm", path = "wrapper/bin/msvc_nop.bat"), + tool_path( + name = "objcopy", + path = "wrapper/bin/msvc_nop.bat", + ), + tool_path( + name = "objdump", + path = "wrapper/bin/msvc_nop.bat", + ), + tool_path( + name = "strip", + path = "wrapper/bin/msvc_nop.bat", + ), + ] + elif (ctx.attr.cpu == "local"): + tool_paths = [ + tool_path(name = "gcc", path = ctx.attr.host_compiler_path), + tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + "/ar"), + tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"), + tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"), + tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"), + tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"), + tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"), + tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"), + tool_path(name = "strip", path = 
ctx.attr.host_compiler_prefix + "/strip"), + ] + elif (ctx.attr.cpu == "darwin"): + tool_paths = [ + tool_path(name = "gcc", path = ctx.attr.host_compiler_path), + tool_path(name = "ar", path = ctx.attr.host_compiler_prefix + "/libtool"), + tool_path(name = "compat-ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "cpp", path = ctx.attr.host_compiler_prefix + "/cpp"), + tool_path(name = "dwp", path = ctx.attr.host_compiler_prefix + "/dwp"), + tool_path(name = "gcov", path = ctx.attr.host_compiler_prefix + "/gcov"), + tool_path(name = "ld", path = ctx.attr.host_compiler_prefix + "/ld"), + tool_path(name = "nm", path = ctx.attr.host_compiler_prefix + "/nm"), + tool_path(name = "objcopy", path = ctx.attr.host_compiler_prefix + "/objcopy"), + tool_path(name = "objdump", path = ctx.attr.host_compiler_prefix + "/objdump"), + tool_path(name = "strip", path = ctx.attr.host_compiler_prefix + "/strip"), + ] + else: + fail("Unreachable") + + out = ctx.actions.declare_file(ctx.label.name) + ctx.actions.write(out, "Fake executable") + return [ + cc_common.create_cc_toolchain_config_info( + ctx = ctx, + features = features, + action_configs = action_configs, + artifact_name_patterns = [], + cxx_builtin_include_directories = cxx_builtin_include_directories, + toolchain_identifier = toolchain_identifier, + host_system_name = host_system_name, + target_system_name = target_system_name, + target_cpu = target_cpu, + target_libc = target_libc, + compiler = compiler, + abi_version = abi_version, + abi_libc_version = abi_libc_version, + tool_paths = tool_paths, + make_variables = [], + builtin_sysroot = builtin_sysroot, + cc_target_os = cc_target_os, + ), + DefaultInfo( + executable = out, + ), + ] + +cc_toolchain_config = rule( + attrs = { + "cpu": attr.string( + mandatory = True, + values = [ + "darwin", + "local", + "x64_windows", + ], + ), + "builtin_include_directories": attr.string_list(), + "extra_no_canonical_prefixes_flags": attr.string_list(), + 
"host_compiler_path": attr.string(), + "host_compiler_prefix": attr.string(), + "host_compiler_warnings": attr.string_list(), + "host_unfiltered_compile_flags": attr.string_list(), + "linker_bin_path": attr.string(), + "msvc_cl_path": attr.string(default = "msvc_not_used"), + "msvc_env_include": attr.string(default = "msvc_not_used"), + "msvc_env_lib": attr.string(default = "msvc_not_used"), + "msvc_env_path": attr.string(default = "msvc_not_used"), + "msvc_env_tmp": attr.string(default = "msvc_not_used"), + "msvc_lib_path": attr.string(default = "msvc_not_used"), + "msvc_link_path": attr.string(default = "msvc_not_used"), + "msvc_ml_path": attr.string(default = "msvc_not_used"), + }, + executable = True, + provides = [CcToolchainConfigInfo], + implementation = _impl, +) diff --git a/third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11/clang/bin/crosstool_wrapper_driver_is_not_gcc b/third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11/clang/bin/crosstool_wrapper_driver_is_not_gcc new file mode 100755 index 000000000..01b454807 --- /dev/null +++ b/third_party/toolchains/gcc7_manylinux2010-nvcc-cuda11/clang/bin/crosstool_wrapper_driver_is_not_gcc @@ -0,0 +1,281 @@ +#!/usr/bin/env python +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Crosstool wrapper for compiling CUDA programs. 
+ +SYNOPSIS: + crosstool_wrapper_is_not_gcc [options passed in by cc_library() + or cc_binary() rule] + +DESCRIPTION: + This script is expected to be called by the cc_library() or cc_binary() bazel + rules. When the option "-x cuda" is present in the list of arguments passed + to this script, it invokes the nvcc CUDA compiler. Most arguments are passed + as is as a string to --compiler-options of nvcc. When "-x cuda" is not + present, this wrapper invokes hybrid_driver_is_not_gcc with the input + arguments as is. +""" + +from __future__ import print_function + +from argparse import ArgumentParser +import os +import subprocess +import re +import sys +import pipes + +# Template values set by cuda_autoconf. +CPU_COMPILER = ('/dt7/usr/bin/gcc') +GCC_HOST_COMPILER_PATH = ('/dt7/usr/bin/gcc') + +NVCC_PATH = '/usr/local/cuda-11.2/bin/nvcc' +PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH) +NVCC_VERSION = '11.2' + +def Log(s): + print('gpus/crosstool: {0}'.format(s)) + + +def GetOptionValue(argv, option): + """Extract the list of values for option from the argv list. + + Args: + argv: A list of strings, possibly the argv passed to main(). + option: The option whose value to extract, with the leading '-'. + + Returns: + A list of values, either directly following the option, + (eg., -opt val1 val2) or values collected from multiple occurrences of + the option (eg., -opt val1 -opt val2). + """ + + parser = ArgumentParser() + parser.add_argument(option, nargs='*', action='append') + option = option.lstrip('-').replace('-', '_') + args, _ = parser.parse_known_args(argv) + if not args or not vars(args)[option]: + return [] + else: + return sum(vars(args)[option], []) + + +def GetHostCompilerOptions(argv): + """Collect the -isystem, -iquote, and --sysroot option values from argv. + + Args: + argv: A list of strings, possibly the argv passed to main(). + + Returns: + The string that can be used as the --compiler-options to nvcc. 
def _update_options(nvcc_options):
    """Rename nvcc options whose spelling differs from the one callers pass.

    Args:
      nvcc_options: A list of nvcc option names without the leading '--'.

    Returns:
      The input list itself when NVCC_VERSION is "7.0"; otherwise a new list
      in which "relaxed-constexpr" is replaced by "expt-relaxed-constexpr".
    """
    if NVCC_VERSION in ("7.0",):
        return nvcc_options

    update_options = {"relaxed-constexpr": "expt-relaxed-constexpr"}
    return [update_options.get(opt, opt) for opt in nvcc_options]


def GetNvccOptions(argv):
    """Collect the -nvcc_options values from argv.

    Args:
      argv: A list of strings, possibly the argv passed to main().

    Returns:
      The string that can be passed directly to nvcc: every collected value
      prefixed with '--' and joined with single spaces, or '' when argv
      carries no -nvcc_options.
    """
    parser = ArgumentParser()
    parser.add_argument('-nvcc_options', nargs='*', action='append')

    args, _ = parser.parse_known_args(argv)

    if args.nvcc_options:
        # action='append' yields one list per occurrence; flatten them.
        options = _update_options(sum(args.nvcc_options, []))
        return ' '.join(['--' + a for a in options])
    return ''


def system(cmd):
    """Invokes cmd with os.system().

    Args:
      cmd: The command.

    Returns:
      The exit code if the process exited with exit() or -signal
      if the process was terminated by a signal.
    """
    retv = os.system(cmd)
    if os.WIFEXITED(retv):
        return os.WEXITSTATUS(retv)
    else:
        return -os.WTERMSIG(retv)


def _optimization_flag(opt_levels):
    """Translate the host compiler's -O values into the flag given to nvcc.

    Args:
      opt_levels: The values that followed -O on the command line, e.g.
        ['2'] for -O2. May be empty.

    Returns:
      ' -O2' when the first value requests optimization, ' -g' otherwise.
    """
    if len(opt_levels) == 0:
        return ' -g'

    level = opt_levels[0]
    try:
        optimized = int(level) > 0
    except ValueError:
        # Non-numeric levels (-Os, -Oz, -Ofast, -Og, ...) used to crash the
        # wrapper in int(). Size/speed levels count as optimized builds;
        # anything else (e.g. -Og) falls back to a debug build.
        optimized = level in ('s', 'z', 'fast')
    return ' -O2' if optimized else ' -g'


def InvokeNvcc(argv, log=False):
    """Call nvcc with arguments assembled from argv.

    Args:
      argv: A list of strings, possibly the argv passed to main(). The values
        are concatenated into a shell command line, so they must already be
        shell-quoted (main() does this).
      log: True if logging is requested.

    Returns:
      1 when argv names no source file or does not name exactly one -o
      output; otherwise the return value of calling system('nvcc ' + args).
      If generating the dependency file fails, that exit status is returned
      and the compile step is not run.
    """
    host_compiler_options = GetHostCompilerOptions(argv)
    nvcc_compiler_options = GetNvccOptions(argv)
    opt_option = GetOptionValue(argv, '-O')
    m_options = GetOptionValue(argv, '-m')
    m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
    include_options = GetOptionValue(argv, '-I')
    out_file = GetOptionValue(argv, '-o')
    depfiles = GetOptionValue(argv, '-MF')
    defines = GetOptionValue(argv, '-D')
    defines = ''.join([' -D' + define for define in defines])
    undefines = GetOptionValue(argv, '-U')
    undefines = ''.join([' -U' + define for define in undefines])
    std_options = GetOptionValue(argv, '-std')
    # Supported -std flags as of CUDA 9.0. Only keep last to mimic gcc/clang.
    nvcc_allowed_std_options = ["c++03", "c++11", "c++14"]
    std_options = ''.join(
        [' -std=' + define
         for define in std_options
         if define in nvcc_allowed_std_options][-1:])
    fatbin_options = ''.join(
        [' --fatbin-options=' + option
         for option in GetOptionValue(argv, '-Xcuda-fatbinary')])

    # The list of source files get passed after the -c option. I don't know of
    # any other reliable way to just get the list of source files to be
    # compiled.
    src_files = GetOptionValue(argv, '-c')

    # Pass -w through from host to nvcc, but don't do anything fancier with
    # warnings-related flags, since they're not necessarily the same across
    # compilers.
    warning_options = ' -w' if '-w' in argv else ''

    if len(src_files) == 0:
        return 1
    if len(out_file) != 1:
        return 1

    opt = _optimization_flag(opt_option)

    includes = ''.join([' -I ' + include for include in include_options])

    # Unfortunately, there are other options that have -c prefix too.
    # So allowing only those look like C/C++ files.
    src_files = [f for f in src_files
                 if re.search(r'\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
    srcs = ' '.join(src_files)
    out = ' -o ' + out_file[0]

    nvccopts = '-D_FORCE_INLINES '
    for capability in GetOptionValue(argv, "--cuda-gpu-arch"):
        # Strip the leading 'sm_' to get the bare compute capability.
        capability = capability[len('sm_'):]
        nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s\" ' % (
            capability, capability)
    for capability in GetOptionValue(argv, '--cuda-include-ptx'):
        capability = capability[len('sm_'):]
        nvccopts += r'-gencode=arch=compute_%s,\"code=compute_%s\" ' % (
            capability, capability)
    nvccopts += nvcc_compiler_options
    nvccopts += undefines
    nvccopts += defines
    nvccopts += std_options
    nvccopts += m_options
    nvccopts += warning_options
    nvccopts += fatbin_options

    if depfiles:
        # Generate the dependency file
        depfile = depfiles[0]
        cmd = (NVCC_PATH + ' ' + nvccopts +
               ' --compiler-options "' + host_compiler_options + '"' +
               ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
               ' -I .' +
               ' -x cu ' + opt + includes + ' ' + srcs + ' -M -o ' + depfile)
        if log:
            Log(cmd)
        exit_status = system(cmd)
        if exit_status != 0:
            return exit_status

    cmd = (NVCC_PATH + ' ' + nvccopts +
           ' --compiler-options "' + host_compiler_options + ' -fPIC"' +
           ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
           ' -I .' +
           ' -x cu ' + opt + includes + ' -c ' + srcs + out)

    # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
    # Need to investigate and fix.
    cmd = 'PATH=' + PREFIX_DIR + ':$PATH ' + cmd
    if log:
        Log(cmd)
    return system(cmd)


def main():
    """Dispatch the compile to nvcc (for -x cuda) or to the CPU compiler.

    Returns:
      The exit status of the compiler that was invoked.
    """
    # pipes.quote is the same function as shlex.quote on Python 3, and the
    # pipes module was removed in Python 3.13; fall back only where shlex
    # has no quote().
    try:
        from shlex import quote
    except ImportError:
        from pipes import quote

    parser = ArgumentParser()
    parser.add_argument('-x', nargs=1)
    parser.add_argument('--cuda_log', action='store_true')
    args, leftover = parser.parse_known_args(sys.argv[1:])

    if args.x and args.x[0] == 'cuda':
        if args.cuda_log:
            Log('-x cuda')
        # InvokeNvcc builds a shell command line, so quote every argument.
        leftover = [quote(s) for s in leftover]
        if args.cuda_log:
            Log('using nvcc')
        return InvokeNvcc(leftover, log=args.cuda_log)

    # Strip our flags before passing through to the CPU compiler for files
    # which are not -x cuda. We can't just pass 'leftover' because it also
    # strips -x. We not only want to pass -x to the CPU compiler, but also
    # keep it in its relative location in the argv list (the compiler is
    # actually sensitive to this).
    cpu_compiler_flags = [flag for flag in sys.argv[1:]
                          if not flag.startswith('--cuda_log')]

    return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)


if __name__ == '__main__':
    sys.exit(main())