From f810ee8ac92fc73dbc1309b1f0232e5617faca5b Mon Sep 17 00:00:00 2001 From: shardy authors Date: Mon, 30 Sep 2024 06:14:44 -0700 Subject: [PATCH] Use llvm error instead of asserts. This ensures the check isn't optimized away by the compiler. PiperOrigin-RevId: 680546512 --- docs/sdy_dialect.md | 57 + docs/sdy_export_passes.md | 49 + shardy/dialect/sdy/ir/dialect.cc | 4 +- shardy/dialect/sdy/ir/ops.td | 46 + .../test/named_computation_parse_print.mlir | 24 + .../test/named_computation_verification.mlir | 41 + shardy/dialect/sdy/ir/verifiers.cc | 40 + shardy/dialect/sdy/transforms/export/BUILD | 1 + .../sdy/transforms/export/export_pipeline.cc | 1 + .../export/insert_explicit_reshards.cc | 42 + .../dialect/sdy/transforms/export/passes.td | 52 + .../export/test/insert_explicit_reshards.mlir | 20 + .../sdy/transforms/import/import_pipeline.cc | 3 + .../import/sharding_group_import.cc | 137 +- .../import/test/import_pipeline.mlir | 56 + .../test/sharding_group_constraints.mlir | 73 + .../import/test/sharding_group_import.mlir | 104 + .../propagation/auto_partitioner_registry.cc | 17 +- .../propagation/auto_partitioner_registry.h | 2 +- .../propagation/op_sharding_rule_registry.cc | 27 +- .../test/op_sharding_rule_registry.mlir | 31 +- third_party/llvm/generated.patch | 4094 ----------------- third_party/llvm/workspace.bzl | 4 +- third_party/stablehlo/temporary.patch | 616 +++ third_party/stablehlo/workspace.bzl | 4 +- 25 files changed, 1386 insertions(+), 4159 deletions(-) create mode 100644 shardy/dialect/sdy/ir/test/named_computation_parse_print.mlir create mode 100644 shardy/dialect/sdy/ir/test/named_computation_verification.mlir create mode 100644 shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc create mode 100644 shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir diff --git a/docs/sdy_dialect.md b/docs/sdy_dialect.md index 5442c1e..4719975 100755 --- a/docs/sdy_dialect.md +++ b/docs/sdy_dialect.md @@ -215,6 +215,63 @@ Interfaces: `Symbol` +### `sdy.named_computation` (sdy::NamedComputationOp) + +_Named computation operation_ + + +Syntax: + +``` +operation ::= `sdy.named_computation` `<`$name`>` `` `(` $operands `)` + custom($body) + attr-dict + `:` functional-type($operands, results) +``` + +Groups a computation, i.e. a block of operations, and gives it a name. +Propagation will flow in/out of the region as if everything was inlined. + +This can be used to handle propagating through call instructions to other +functions. Any users of Shardy should write an import/export pass that +converts their call ops to `sdy.named_computation` ops, duplicating/copying +the body of the called function into the body of the `named_computation`. + +The type of each block arguments and returned values in the region must be +the same as the type of the operands and results type of the op. + +Example: + +```mlir +%1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) { + sdy.return %arg1 : tensor<16x32xf32> +} : (tensor<16x32xf32>) -> tensor<16x32xf32> +``` + +Traits: `IsolatedFromAbove`, `RecursiveMemoryEffects`, `RecursivelySpeculatableImplTrait`, `SingleBlockImplicitTerminator`, `SingleBlock` + +Interfaces: `ConditionallySpeculatable` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
name::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `operands` | variadic of any type + +#### Results: + +| Result | Description | +| :----: | ----------- | +«unnamed» | variadic of any type + + ### `sdy.propagation_barrier` (sdy::PropagationBarrierOp) _Propagation barrier operation_ diff --git a/docs/sdy_export_passes.md b/docs/sdy_export_passes.md index 7a7e3ef..d5f678f 100755 --- a/docs/sdy_export_passes.md +++ b/docs/sdy_export_passes.md @@ -1,4 +1,53 @@ +### `-sdy-insert-explicit-reshards` + +_Inserts explicit reshards to make all operations have compatible shardings._ + +A compatible sharding essentially means that the operation can accept the +sharded operands and produce a sharded result without requiring any reshard +communications (note that the operation might still require communication +such as all-reduce or halo-swaps). + +After propagation, some opeartions may still have incompatible shardings. + +Please note, when an axis (or sub-axis) is used to shard non-corresponding +dimensions (e.g. non-contracting dimensions in matmul) across multiple +tensors, or when an axis shards a dimension in one tensor but not the +corresponding dimension in the other tensor, it is said that the operation +has a sharding conflict. Hence, after this pass, the opeartions become +conflict-free. + +This pass injects reshard operations explicitly so that, for each operation, +corresponding dimensions become sharded in the same way across all operands +and results, and every axis (or sub-axis) can only be used to shard a single +dimension type. + +A clarifying example: + +Input: +```mlir +mesh = <"x"=4, "y"=2> +%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"y"},{"x"}\]>} +%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>} +stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>} + : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> +``` + +Output: +```mlir +sdy.mesh = <"x"=4, "y"=2> +%lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"x"}, {"y"}\]>} +%rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>} +%0 = sdy.reshard %rhs <@mesh, \[{"y"}, {}\]> : tensor<32x16xf32> +stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>} + : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> +``` + +In the example above, there is a conflict since `lhs` and `rhs` tensors +are both sharded on axis "x" on their non-contracting dimensions. Here, +`rhs` tensor is resharded, before the dot operation, explicitly to be +sharded only on its first dimension and on axis "x". This way, the dot +opearation becomes compatible. ### `-sdy-sharding-constraint-to-reshard` _Converts ShardingConstraintOp into ReshardOp._ diff --git a/shardy/dialect/sdy/ir/dialect.cc b/shardy/dialect/sdy/ir/dialect.cc index 00cf709..1fa27b7 100644 --- a/shardy/dialect/sdy/ir/dialect.cc +++ b/shardy/dialect/sdy/ir/dialect.cc @@ -59,8 +59,8 @@ struct ShardyDialectInlinerInterface : public DialectInlinerInterface { return true; } - // ManualComputationOp is an op with a region, and it should be allowed to be - // inlined into another op. + // `ManualComputationOp` and `NamedComputationOp` are ops with a region, and + // it should be allowed to be inlined into another op. bool isLegalToInline(Region*, Region*, bool, IRMapping&) const final { return true; } diff --git a/shardy/dialect/sdy/ir/ops.td b/shardy/dialect/sdy/ir/ops.td index 135f9b5..f43548e 100644 --- a/shardy/dialect/sdy/ir/ops.td +++ b/shardy/dialect/sdy/ir/ops.td @@ -372,4 +372,50 @@ def PropagationBarrierOp : Sdy_Op<"propagation_barrier", let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// NamedComputationOp +//===----------------------------------------------------------------------===// + +def NamedComputationOp : Sdy_Op<"named_computation", + [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"ReturnOp">, + RecursivelySpeculatable, IsolatedFromAbove]> { + let summary = "named computation operation"; + let description = [{ + Groups a computation, i.e. a block of operations, and gives it a name. + Propagation will flow in/out of the region as if everything was inlined. + + This can be used to handle propagating through call instructions to other + functions. Any users of Shardy should write an import/export pass that + converts their call ops to `sdy.named_computation` ops, duplicating/copying + the body of the called function into the body of the `named_computation`. + + The type of each block arguments and returned values in the region must be + the same as the type of the operands and results type of the op. + + Example: + + ```mlir + %1 = sdy.named_computation<"foo">(%0) (%arg1: tensor<16x32xf32>) { + sdy.return %arg1 : tensor<16x32xf32> + } : (tensor<16x32xf32>) -> tensor<16x32xf32> + ``` + }]; + + let arguments = (ins + StrAttr:$name, + Variadic:$operands + ); + let results = (outs Variadic); + let regions = (region SizedRegion<1>:$body); + + let assemblyFormat = [{ + `<`$name`>` `` `(` $operands `)` + custom($body) + attr-dict + `:` functional-type($operands, results) + }]; + + let hasVerifier = 1; +} + #endif // SDY_OPS diff --git a/shardy/dialect/sdy/ir/test/named_computation_parse_print.mlir b/shardy/dialect/sdy/ir/test/named_computation_parse_print.mlir new file mode 100644 index 0000000..45de49d --- /dev/null +++ b/shardy/dialect/sdy/ir/test/named_computation_parse_print.mlir @@ -0,0 +1,24 @@ +// RUN: sdy_opt %s 2>&1 | FileCheck %s + +// CHECK-LABEL: func @one_input_output +func.func @one_input_output(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + // CHECK-NEXT: %0 = sdy.named_computation<"foo">(%arg0) (%arg1: tensor<8x2xi32>) { + // CHECK-NEXT: sdy.return %arg1 : tensor<8x2xi32> + // CHECK-NEXT: } : (tensor<8x2xi32>) -> tensor<8x2xi32> + %0 = sdy.named_computation<"foo">(%arg0) (%arg1: tensor<8x2xi32>) { + sdy.return %arg1 : tensor<8x2xi32> + } : (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} + + +// CHECK-LABEL: func @two_inputs_outputs +func.func @two_inputs_outputs(%arg0: tensor<8x2xi32>, %arg1: tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) { + // CHECK-NEXT: %0:2 = sdy.named_computation<"named_computation">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) { + //CHECK-NEXT: sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32> + // CHECK-NEXT: } : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) + %0:2 = sdy.named_computation<"named_computation">(%arg0, %arg1) (%arg2: tensor<8x2xi32>, %arg3: tensor<4x2xi32>) { + sdy.return %arg2, %arg3 : tensor<8x2xi32>, tensor<4x2xi32> + } : (tensor<8x2xi32>, tensor<4x2xi32>) -> (tensor<8x2xi32>, tensor<4x2xi32>) + return %0#0, %0#1 : tensor<8x2xi32>, tensor<4x2xi32> +} diff --git a/shardy/dialect/sdy/ir/test/named_computation_verification.mlir b/shardy/dialect/sdy/ir/test/named_computation_verification.mlir new file mode 100644 index 0000000..c24892e --- /dev/null +++ b/shardy/dialect/sdy/ir/test/named_computation_verification.mlir @@ -0,0 +1,41 @@ +// RUN: sdy_opt %s -split-input-file -verify-diagnostics + +func.func @invalid_operand_type(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + // expected-error@+1 {{expected the type of the 0'th block argument to match the type of the corresponding operand: 'tensor<4x2xi32>' vs 'tensor<8x2xi32>'}} + %0 = sdy.named_computation<"bar">(%arg0) (%arg1: tensor<4x2xi32>) { + %1 = stablehlo.custom_call @foo(%arg1) : (tensor<4x2xi32>) -> tensor<8x2xi32> + sdy.return %1 : tensor<8x2xi32> + }: (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} + +// ----- + +func.func @operand_count_mismatch(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + // expected-error@+1 {{number of block arguments must match the number of operands: 2 != 1}} + %0 = sdy.named_computation<"bar">(%arg0) (%arg1: tensor<8x2xi32>, %arg2: tensor<8x2xi32>) { + sdy.return %arg1 : tensor<8x2xi32> + }: (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} + +// ----- + +func.func @invalid_result_type(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + // expected-error@+1 {{expected the type of the 0'th returned value to match the type of the corresponding result: 'tensor<4x2xi32>' vs 'tensor<8x2xi32>'}} + %0 = sdy.named_computation<"bar">(%arg0) (%arg1: tensor<8x2xi32>) { + %1 = stablehlo.custom_call @foo(%arg1) : (tensor<8x2xi32>) -> tensor<4x2xi32> + sdy.return %1 : tensor<4x2xi32> + }: (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} + +// ----- + +func.func @result_count_mismatch(%arg0: tensor<8x2xi32>) -> tensor<8x2xi32> { + // expected-error@+1 {{number of returned values must match the number of results: 2 != 1}} + %0 = sdy.named_computation<"bar">(%arg0) (%arg1: tensor<8x2xi32>) { + sdy.return %arg1, %arg1 : tensor<8x2xi32>, tensor<8x2xi32> + }: (tensor<8x2xi32>) -> tensor<8x2xi32> + return %0 : tensor<8x2xi32> +} diff --git a/shardy/dialect/sdy/ir/verifiers.cc b/shardy/dialect/sdy/ir/verifiers.cc index 3d2e9bf..1eab21c 100644 --- a/shardy/dialect/sdy/ir/verifiers.cc +++ b/shardy/dialect/sdy/ir/verifiers.cc @@ -852,6 +852,46 @@ LogicalResult PropagationBarrierOp::verify() { return success(); } +namespace { + +LogicalResult AllInnerAndOuterTypesMatchInNamedComputation( + NamedComputationOp op, TypeRange innerTypes, TypeRange outerTypes, + StringRef innerName, StringRef outerName) { + if (innerTypes.size() != outerTypes.size()) { + return op.emitError("number of ") + << innerName << "s must match the number of " << outerName + << "s: " << innerTypes.size() << " != " << outerTypes.size(); + } + + for (auto [i, types] : + llvm::enumerate(llvm::zip_equal(innerTypes, outerTypes))) { + auto [innerType, outerType] = types; + if (innerType != outerType) { + return op.emitError("expected the type of the ") + << i << "'th " << innerName + << " to match the type of the corresponding " << outerName << ": " + << innerType << " vs " << outerType; + } + } + + return success(); +} + +} // namespace + +LogicalResult NamedComputationOp::verify() { + if (failed(AllInnerAndOuterTypesMatchInNamedComputation( + *this, getBody().getArgumentTypes(), getOperandTypes(), + "block argument", "operand")) || + failed(AllInnerAndOuterTypesMatchInNamedComputation( + *this, getBodyTerminatorOpOperandTypes(*this), getResultTypes(), + "returned value", "result"))) { + return failure(); + } + + return success(); +} + LogicalResult SdyDialect::verifyRegionArgAttribute(Operation* op, unsigned regionIndex, unsigned argIndex, diff --git a/shardy/dialect/sdy/transforms/export/BUILD b/shardy/dialect/sdy/transforms/export/BUILD index deb699d..16aed54 100644 --- a/shardy/dialect/sdy/transforms/export/BUILD +++ b/shardy/dialect/sdy/transforms/export/BUILD @@ -36,6 +36,7 @@ cc_library( name = "passes", srcs = [ "export_pipeline.cc", + "insert_explicit_reshards.cc", "sharding_constraint_to_reshard.cc", "sink_data_flow_edges.cc", "update_non_divisible_input_output_shardings.cc", diff --git a/shardy/dialect/sdy/transforms/export/export_pipeline.cc b/shardy/dialect/sdy/transforms/export/export_pipeline.cc index 1dde661..1de197c 100644 --- a/shardy/dialect/sdy/transforms/export/export_pipeline.cc +++ b/shardy/dialect/sdy/transforms/export/export_pipeline.cc @@ -28,6 +28,7 @@ void addExportPipeline(OpPassManager& pm, StringRef dumpDirectory) { pm.addNestedPass(createShardingConstraintToReshardPass()); pm.addNestedPass( createUpdateNonDivisibleInputOutputShardingsPass()); + pm.addNestedPass(createInsertExplicitReshardsPass()); pm.addPass(mlir::sdy::createSaveModuleOpPass(dumpDirectory, "sdy_module_after_sdy_export")); } diff --git a/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc new file mode 100644 index 0000000..e92b13a --- /dev/null +++ b/shardy/dialect/sdy/transforms/export/insert_explicit_reshards.cc @@ -0,0 +1,42 @@ +/* Copyright 2024 The Shardy Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // IWYU pragma: keep +#include "mlir/Pass/Pass.h" // IWYU pragma: keep +#include "shardy/dialect/sdy/ir/dialect.h" // IWYU pragma: keep + +namespace mlir { +namespace sdy { + +#define GEN_PASS_DEF_INSERTEXPLICITRESHARDSPASS +#include "shardy/dialect/sdy/transforms/export/passes.h.inc" + +namespace { + +struct InsertExplicitReshardsPass + : public impl::InsertExplicitReshardsPassBase { + using InsertExplicitReshardsPassBase::InsertExplicitReshardsPassBase; + + void runOnOperation() final { + // Not ready yet. It is currently a no-op. + } +}; + +} // namespace + +} // namespace sdy +} // namespace mlir diff --git a/shardy/dialect/sdy/transforms/export/passes.td b/shardy/dialect/sdy/transforms/export/passes.td index c87a24f..1c0725c 100644 --- a/shardy/dialect/sdy/transforms/export/passes.td +++ b/shardy/dialect/sdy/transforms/export/passes.td @@ -43,3 +43,55 @@ def UpdateNonDivisibleInputOutputShardingsPass : Pass<"sdy-update-non-divisible- }]; let dependentDialects = ["mlir::sdy::SdyDialect"]; } + +def InsertExplicitReshardsPass : Pass<"sdy-insert-explicit-reshards", "func::FuncOp"> { + let summary = "Inserts explicit reshards to make all operations have compatible shardings."; + let description = [{ + A compatible sharding essentially means that the operation can accept the + sharded operands and produce a sharded result without requiring any reshard + communications (note that the operation might still require communication + such as all-reduce or halo-swaps). + + After propagation, some opeartions may still have incompatible shardings. + + Please note, when an axis (or sub-axis) is used to shard non-corresponding + dimensions (e.g. non-contracting dimensions in matmul) across multiple + tensors, or when an axis shards a dimension in one tensor but not the + corresponding dimension in the other tensor, it is said that the operation + has a sharding conflict. Hence, after this pass, the opeartions become + conflict-free. + + This pass injects reshard operations explicitly so that, for each operation, + corresponding dimensions become sharded in the same way across all operands + and results, and every axis (or sub-axis) can only be used to shard a single + dimension type. + + A clarifying example: + + Input: + ```mlir + mesh = <"x"=4, "y"=2> + %lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"y"},{"x"}\]>} + %rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>} + stablehlo.dot %lhs, %rhs {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>} + : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> + ``` + + Output: + ```mlir + sdy.mesh = <"x"=4, "y"=2> + %lhs : tensor<8x32xf32> {sdy.sharding=<@mesh, \[{"x"}, {"y"}\]>} + %rhs : tensor<32x16xf32> {sdy.sharding=<@mesh, \[{"y"}, {"x"}\]>} + %0 = sdy.reshard %rhs <@mesh, \[{"y"}, {}\]> : tensor<32x16xf32> + stablehlo.dot %lhs, %0 {sdy.sharding_per_value=<[<@mesh, \[{"x"}, {}\]>]>} + : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> + ``` + + In the example above, there is a conflict since `lhs` and `rhs` tensors + are both sharded on axis "x" on their non-contracting dimensions. Here, + `rhs` tensor is resharded, before the dot operation, explicitly to be + sharded only on its first dimension and on axis "x". This way, the dot + opearation becomes compatible. + }]; + let dependentDialects = ["mlir::sdy::SdyDialect"]; +} diff --git a/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir new file mode 100644 index 0000000..3eb7c42 --- /dev/null +++ b/shardy/dialect/sdy/transforms/export/test/insert_explicit_reshards.mlir @@ -0,0 +1,20 @@ +// RUN: sdy_opt %s -sdy-insert-explicit-reshards | FileCheck %s + +sdy.mesh @mesh = <["x"=4, "y"=2]> + +// CHECK-LABEL: func @dot_matrix_matrix_compatible +func.func @dot_matrix_matrix_compatible(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {}]>}) -> tensor<8x16xf32> { + // CHECK: stablehlo.dot %arg0, %arg1 {sdy.sharding_per_value = #sdy.sharding<@mesh, [{"x"}, {}]>} + // CHECK-NOT: sdy.reshard + %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding_per_value = #sdy.sharding<@mesh, [{"x"}, {}]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} + + +// CHECK-LABEL: func @dot_matrix_matrix_incompatible_same_non_contracting_dims +func.func @dot_matrix_matrix_incompatible_same_non_contracting_dims(%arg0: tensor<8x32xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}, %arg1: tensor<32x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"y"}, {"x"}]>}) -> tensor<8x16xf32> { + // CHECK: stablehlo.dot %arg0, %arg1 {sdy.sharding_per_value = #sdy.sharding<@mesh, [{"x"}, {}]>} + // CHECK-NOT: sdy.reshard + %0 = stablehlo.dot %arg0, %arg1 {sdy.sharding_per_value = #sdy.sharding<@mesh, [{"x"}, {}]>} : (tensor<8x32xf32>, tensor<32x16xf32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} diff --git a/shardy/dialect/sdy/transforms/import/import_pipeline.cc b/shardy/dialect/sdy/transforms/import/import_pipeline.cc index 5ff4027..7c89a85 100644 --- a/shardy/dialect/sdy/transforms/import/import_pipeline.cc +++ b/shardy/dialect/sdy/transforms/import/import_pipeline.cc @@ -36,6 +36,9 @@ void addImportPipeline(OpPassManager& pm, StringRef dumpDirectory) { pm.addNestedPass(createConstantSplitterPass()); pm.addNestedPass(createAddDataFlowEdgesPass()); pm.addNestedPass(createApplyShardingConstraintsPass()); + // The sharding group import pass must run after applying sharding + // constraints. This ensures we can detect sharding conflicts between group + // members which have pre-propagation shardings due to sharding constraints. pm.addPass(createShardingGroupImportPass()); pm.addPass(createImportMaximalShardingPass()); diff --git a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc index 348685f..74d66fe 100644 --- a/shardy/dialect/sdy/transforms/import/sharding_group_import.cc +++ b/shardy/dialect/sdy/transforms/import/sharding_group_import.cc @@ -17,6 +17,7 @@ limitations under the License. #include // IWYU pragma: keep #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/BuiltinOps.h" @@ -38,47 +39,10 @@ namespace { using llvm::DenseMap; using llvm::EquivalenceClasses; using llvm::SmallDenseMap; -using llvm::SmallVector; using ValueToShardingGroup = - llvm::DenseMap>; - -void unifyShardingGroups(ValueToShardingGroup& tensorToGroups) { - if (tensorToGroups.empty()) { - return; - } - // Merge the equivalence classes of group ids which had the same tensors - // within them. (unionSets uses the default comparator and will consider the - // minimum group_id as the representative element of the equivalence class). - EquivalenceClasses shardingGroupEquivalences; - for (auto& [_, groupsForTensor] : tensorToGroups) { - const int64_t canonicalId = groupsForTensor.front().getGroupId(); - for (ShardingGroupOp group : groupsForTensor) { - shardingGroupEquivalences.unionSets(canonicalId, group.getGroupId()); - } - } - - // After merging groups we reindex the group IDs so that they take values - // from the set {0,1,...,N-1} (N is the number of equivalence classes). - // The leader element of each equivalent class corresponds to the minimum - // group_id, so by looping over the group leaders in order their reindexed - // ids can be set to maintain the same relative ordering. - int64_t reindexId = 0; - SmallDenseMap reindexMap; - for (const auto& group : shardingGroupEquivalences) { - if (group.isLeader()) { - reindexMap[group.getData()] = reindexId++; - } - } - - // Update the graph to replace group_ids with their canonical id. - for (auto& [_, groupsForTensor] : tensorToGroups) { - for (ShardingGroupOp op : groupsForTensor) { - op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue( - op.getGroupId())]); - } - } -} + llvm::MapVector>; +using GroupIdToShardingGroups = SmallVector>; LogicalResult buildShardingGroupMappingAndValidateGroups( ModuleOp module, ValueToShardingGroup& tensorToGroups) { @@ -126,6 +90,83 @@ LogicalResult buildShardingGroupMappingAndValidateGroups( return failure(result.wasInterrupted()); } +GroupIdToShardingGroups unifyShardingGroups( + ValueToShardingGroup& tensorToGroups) { + // Merge the equivalence classes of group ids which had the same tensors + // within them. (unionSets uses the default comparator and will consider the + // minimum group_id as the representative element of the equivalence class). + EquivalenceClasses shardingGroupEquivalences; + for (auto& [_, groupsForTensor] : tensorToGroups) { + int64_t canonicalId = groupsForTensor.front().getGroupId(); + for (ShardingGroupOp group : groupsForTensor) { + shardingGroupEquivalences.unionSets(canonicalId, group.getGroupId()); + } + } + + // After merging groups we reindex the group IDs so that they take values + // from the set {0,1,...,N-1} (N is the number of equivalence classes). + // The leader element of each equivalent class corresponds to the minimum + // group_id, so by looping over the group leaders in order their reindexed + // ids can be set to maintain the same relative ordering. + int64_t reindexId = 0; + SmallDenseMap reindexMap; + for (const auto& group : shardingGroupEquivalences) { + if (group.isLeader()) { + reindexMap[group.getData()] = reindexId++; + } + } + + GroupIdToShardingGroups reindexGroups(reindexId); + // Update the graph to replace group_ids with their canonical id. + for (auto& [_, groupsForTensor] : tensorToGroups) { + for (ShardingGroupOp op : groupsForTensor) { + op.setGroupId(reindexMap[shardingGroupEquivalences.getLeaderValue( + op.getGroupId())]); + reindexGroups[op.getGroupId()].push_back(op); + } + } + return reindexGroups; +} + +// This function verifies that sharding groups with pre-existing shardings are +// compatible. Compatibility means all values in the group must have either no +// sharding or the same sharding. +LogicalResult validateCompatibilityAndApplyInitialShardingConstraints( + ModuleOp module, GroupIdToShardingGroups& groupIdToShardingGroups) { + SmallDenseMap groupIdToSharding; + // Tensors can have initial shardings defined in several ways (e.g., sharding + // constraints, function arguments, manual computations). These initial + // shardings only conflict with Sharding Groups if their value belongs to a + // group. Therefore, we only need to validate the consistency of shardings + // within ShardingGroupOps to ensure no conflicts. + for (const auto& shardingGroups : groupIdToShardingGroups) { + for (ShardingGroupOp shardingGroupOp : shardingGroups) { + TensorShardingAttr sharding = getSharding(shardingGroupOp.getInput()); + int64_t groupId = shardingGroupOp.getGroupId(); + if (!sharding) { + continue; + } + auto [it, inserted] = groupIdToSharding.try_emplace(groupId, sharding); + if (!inserted && it->second != sharding) { + shardingGroupOp.emitError( + "Inconsistent shardings prior to propagation for ShardingGroupOps " + "with canonicalized groupId: ") + << groupId; + return failure(); + } + } + } + + // Apply initial shardings to all values in the group. + for (auto& [groupId, sharding] : groupIdToSharding) { + for (ShardingGroupOp shardingGroupOp : groupIdToShardingGroups[groupId]) { + setSharding(shardingGroupOp.getInput(), sharding); + } + } + + return success(); +} + struct ShardingGroupImportPass : public impl::ShardingGroupImportPassBase { using ShardingGroupImportPassBase::ShardingGroupImportPassBase; @@ -134,12 +175,26 @@ struct ShardingGroupImportPass // Extract the sharding group ids and tensor -> {group_id} mapping from the // high level module and validate any sharding group constrainst are met. ValueToShardingGroup tensorToGroups; - if (failed(buildShardingGroupMappingAndValidateGroups(getOperation(), + ModuleOp module = getOperation(); + if (failed(buildShardingGroupMappingAndValidateGroups(module, tensorToGroups))) { signalPassFailure(); } + // If there are no sharding groups, the rest of the preprocessing steps + // are not necessary. + if (tensorToGroups.empty()) { + return; + } - unifyShardingGroups(tensorToGroups); + GroupIdToShardingGroups groupIdToReindexedTensors = + unifyShardingGroups(tensorToGroups); + // This pass assumes sharding constraints are already applied to values. + // Compatibility constraints are applied after group unification to detect + // conflicts within the unified groups. + if (failed(validateCompatibilityAndApplyInitialShardingConstraints( + module, groupIdToReindexedTensors))) { + signalPassFailure(); + } } }; diff --git a/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir b/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir index c3086e7..aa4dcca 100644 --- a/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir +++ b/shardy/dialect/sdy/transforms/import/test/import_pipeline.mlir @@ -53,3 +53,59 @@ func.func @main(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) { sdy.sharding_group %arg1 group_id = 3456 : tensor<8x8xf32> func.return } + +// ----- + +// Verifies that the `-apply-sharding-constraints` pass is applied before the +// `-sharding-group-import` pass. This is validated by asserting that members +// of a sharding group pick up the sharding of a group member with a sharding +// constraint (the constraint needs to be added to the value in order for it to +// be applied to other group members). +sdy.mesh @mesh = <["a"=2]> +// CHECK-LABEL: func.func @main +func.func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { + // CHECK: %0 = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} + %0 = stablehlo.add %arg0, %arg0 : tensor<16x16xf32> + %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"a"}]> : tensor<16x16xf32> + // CHECK: %2 = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"a"}]>]>} + %2 = stablehlo.add %arg0, %arg0 : tensor<16x16xf32> + sdy.sharding_group %0 group_id = 32 : tensor<16x16xf32> + sdy.sharding_group %2 group_id = 32 : tensor<16x16xf32> + return %1 : tensor<16x16xf32> +} + +// ----- + +// Verifies that the `-sdy-add-data-flow-edges` pass is applied before the +// `-sharding-group-import` pass. This is validated by adding a block argument +// of a while op to a sharding group which has a sharding constraint. This +// should be applied to other members of the group but can only happen if the +// `-sdy-add-data-flow-edges` pass is applied first. + +sdy.mesh @mesh = <["a"=2]> + +// CHECK: func.func @main +// CHECK-NEXT %arg0: tensor<16x16xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>} +func.func @main(%arg0: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = stablehlo.constant dense<0> : tensor + %inc = stablehlo.constant dense<1> : tensor + %comp = stablehlo.constant dense<32> : tensor + %1:2 = stablehlo.while(%iterArg = %arg0, %iterArg_2 = %0) : tensor<16x16xf32>, tensor + cond { + %2 = stablehlo.compare LT, %iterArg_2, %comp : (tensor, tensor) -> tensor + stablehlo.return %2 : tensor + } do { + %2 = stablehlo.add %iterArg_2, %inc : tensor + // Add a value with an explicit sharding to group_id=50 which will apply an + // initial sharding to the result of the WhileOp outside of the loop. + %3 = stablehlo.add %iterArg, %iterArg : tensor<16x16xf32> + %4 = sdy.sharding_constraint %3 <@mesh, [{"a"}, {}]> : tensor<16x16xf32> + sdy.sharding_group %3 group_id = 50 : tensor<16x16xf32> + stablehlo.return %3, %2 : tensor<16x16xf32>, tensor + } + + // CHECK: sdy.data_flow_edge %3#0 sharding=<@mesh, [{"a"}, {}]> : tensor<16x16xf32> + sdy.sharding_group %1#0 group_id = 50 : tensor<16x16xf32> + return %1#0 : tensor<16x16xf32> +} + diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir index 341d14d..6c0ba90 100644 --- a/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir +++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_constraints.mlir @@ -184,3 +184,76 @@ func.func @main(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { func.return %0: tensor<8x8xf32> } +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// Throw error for sharding groups which have incompatible shardings inferred +// from initial constraints. +func.func @error_for_incompatible_shardings_in_sharding_group( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}, {}]>}) { + // Sharding Group and Sharding Constraint compatibility checks happend after + // unification + canonicalization of group ids, which is why the group id + // below (555) corresponds to group id: 0 in the check-error. + sdy.sharding_group %arg0 group_id = 555 : tensor<8x8xf32> + // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}} + sdy.sharding_group %arg1 group_id = 555 : tensor<8x8xf32> + func.return +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// Throw error for sharding groups which have incompatible shardings inferred +// from initial constraints. +func.func @error_for_transitively_inferred_incompatible_shardings_in_unified_sharding_group( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) { + + %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32> + %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32> + + sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32> + sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32> + sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32> + sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32> + + // The shard group below will cause the above sharding groups to be merged + // by transitivity this implies that all of {%arg0, %arg1, 0, 1} should have + // the same sharding. Note that %0 and %1 are compatible by them selves but + // %arg0 and %arg1 are not due to their initial shardings. + sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32> + // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}} + sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32> + func.return +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +func.func @error_for_incompatible_shardings_in_manual_computation(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) { + %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh, [{"a"}, {}]>, <@mesh, [{"b"}, {}]>] out_shardings=[<@mesh, [{"b"}, {}]>] manual_axes={} (%arg2: tensor<8x8xf32>, %arg3: tensor<8x8xf32>) { + sdy.sharding_group %arg2 group_id = 8675 : tensor<8x8xf32> + // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}} + sdy.sharding_group %arg3 group_id = 8675 : tensor<8x8xf32> + sdy.return %arg2 : tensor<8x8xf32> + } : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> + func.return +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +func.func @error_for_incompatible_shardings_with_sharding_constraint(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) { + %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32> + %1 = sdy.sharding_constraint %0 <@mesh, [{}, {"b"}]> : tensor<8x8xf32> + sdy.sharding_group %arg0 group_id = 1000 : tensor<8x8xf32> + // expected-error@below {{Inconsistent shardings prior to propagation for ShardingGroupOps with canonicalized groupId: 0}} + sdy.sharding_group %1 group_id = 1000 : tensor<8x8xf32> + func.return +} + diff --git a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir index 7cd8589..9fd7e88 100644 --- a/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir +++ b/shardy/dialect/sdy/transforms/import/test/sharding_group_import.mlir @@ -79,3 +79,107 @@ func.func @sharding_groups_reindex_ordering_matches_min_element_ordering(%arg0: sdy.sharding_group %arg2 group_id = 123456 : tensor<4xf32> func.return } + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: set_existing_shardings_for_sharding_group_members +func.func @set_existing_shardings_for_sharding_group_members( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {"b"}]>}) { + // CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32> + %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32> + + sdy.sharding_group %arg0 group_id = 43210 : tensor<8x8xf32> + sdy.sharding_group %arg1 group_id = 43210 : tensor<8x8xf32> + sdy.sharding_group %0 group_id = 43210 : tensor<8x8xf32> + func.return +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: transitively_update_shardings_for_sharding_group_members +func.func @transitively_update_shardings_for_sharding_group_members( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, + %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) { + // CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32> + // CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32> + %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32> + %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32> + + sdy.sharding_group %arg0 group_id = 10 : tensor<8x8xf32> + sdy.sharding_group %0 group_id = 10 : tensor<8x8xf32> + sdy.sharding_group %0 group_id = 20 : tensor<8x8xf32> + sdy.sharding_group %1 group_id = 20 : tensor<8x8xf32> + sdy.sharding_group %1 group_id = 30 : tensor<8x8xf32> + sdy.sharding_group %arg1 group_id = 30 : tensor<8x8xf32> + func.return +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: set_existing_shards_for_disjoint_groups +// CHECK-SAMEL %arg1: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>} +// CHECK-SAMEL %arg3: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>} +func.func @set_existing_shards_for_disjoint_groups( + %arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}, + %arg1: tensor<8x8xf32>, + %arg2: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {"b"}]>}, + %arg3: tensor<8x8xf32>) { + // CHECK: %cst = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} dense<0.000000e+00> : tensor<8x8xf32> + %0 = stablehlo.constant dense<0.0> : tensor<8x8xf32> + // CHECK: %cst_0 = stablehlo.constant {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{}, {"b"}]>]>} dense<0.000000e+00> : tensor<8x8xf32> + %1 = stablehlo.constant dense<0.0> : tensor<8x8xf32> + // CHECK: %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<8x8xf32> + %2 = stablehlo.constant dense<0.0> : tensor<8x8xf32> + + sdy.sharding_group %arg0 group_id = 11111 : tensor<8x8xf32> + sdy.sharding_group %arg1 group_id = 11111 : tensor<8x8xf32> + sdy.sharding_group %0 group_id = 11111 : tensor<8x8xf32> + + sdy.sharding_group %arg2 group_id = 22222 : tensor<8x8xf32> + sdy.sharding_group %arg3 group_id = 22222 : tensor<8x8xf32> + sdy.sharding_group %1 group_id = 22222 : tensor<8x8xf32> + func.return +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +// CHECK-LABEL: set_existing_shardings_in_manual_computation_op +func.func @set_existing_shardings_in_manual_computation_op(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf32>) { + %0 = sdy.manual_computation(%arg0, %arg1) in_shardings=[<@mesh, [{"a"}, {}]>, <@mesh, [{"a"}, {}]>] out_shardings=[<@mesh, [{"a"}, {}]>] manual_axes={} (%arg2: tensor<8x8xf32>, %arg3: tensor<8x8xf32>) { + // CHECK: %1 = stablehlo.add %arg2, %arg2 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} : tensor<8x8xf32> + %1 = stablehlo.add %arg2, %arg2 : tensor<8x8xf32> + // CHECK: %2 = stablehlo.add %arg3, %arg3 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} : tensor<8x8xf32> + %2 = stablehlo.add %arg3, %arg3 : tensor<8x8xf32> + + sdy.sharding_group %1 group_id = 1000 : tensor<8x8xf32> + sdy.sharding_group %2 group_id = 1000 : tensor<8x8xf32> + sdy.sharding_group %arg2 group_id = 1000 : tensor<8x8xf32> + sdy.sharding_group %arg3 group_id = 1000 : tensor<8x8xf32> + sdy.return %1 : tensor<8x8xf32> + } : (tensor<8x8xf32>, tensor<8x8xf32>) -> tensor<8x8xf32> + func.return +} + +// ----- + +sdy.mesh @mesh = <["a"=2, "b"=2]> + +func.func @set_existing_shardings_in_groups_with_sharding_constraint(%arg0: tensor<8x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}, {}]>}) { + %0 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32> + %1 = sdy.sharding_constraint %0 <@mesh, [{"a"}, {}]> : tensor<8x8xf32> + // CHECK: %2 = stablehlo.add %arg0, %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}, {}]>]>} : tensor<8x8xf32> + %2 = stablehlo.add %arg0, %arg0 : tensor<8x8xf32> + sdy.sharding_group %arg0 group_id = 1000 : tensor<8x8xf32> + sdy.sharding_group %1 group_id = 1000 : tensor<8x8xf32> + sdy.sharding_group %2 group_id = 1000 : tensor<8x8xf32> + func.return +} diff --git a/shardy/dialect/sdy/transforms/propagation/auto_partitioner_registry.cc b/shardy/dialect/sdy/transforms/propagation/auto_partitioner_registry.cc index 8950400..bd2bab3 100644 --- a/shardy/dialect/sdy/transforms/propagation/auto_partitioner_registry.cc +++ b/shardy/dialect/sdy/transforms/propagation/auto_partitioner_registry.cc @@ -18,10 +18,11 @@ limitations under the License. #include #include +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Mutex.h" +#include "mlir/IR/DialectRegistry.h" #include "mlir/Pass/PassOptions.h" -#include "mlir/IR/DialectRegistry.h" // from @llvm-project namespace mlir { namespace sdy { @@ -41,23 +42,29 @@ void AutoPartitionerRegistry::setCallback( AutoPartitionerCallback callback, RegisterDependantDialectsCallback dialectsDependenciesCallback) { llvm::sys::ScopedLock scopedLock(*mutex); - // TODO(tomnatan): find a better way to fail in this case, and consider // allowing registring multiple callbacks with different keys (that are passed // by the user to sdy). - assert(!isRegistered() && "auto-partitioner callback already registered"); + + if (isRegistered()) { + llvm::report_fatal_error("auto-partitioner callback already registered"); + } *registeredCallback = callback; *registeredDependenciesCallback = dialectsDependenciesCallback; } void AutoPartitionerRegistry::addPasses(OpPassManager& pm) { // TODO(tomnatan): find a better way to fail in this case. - assert(isRegistered() && "auto-partitioner callback wasn't registered"); + if (!isRegistered()) { + llvm::report_fatal_error("auto-partitioner callback wasn't registered"); + } registeredCallback->value()(pm); } void AutoPartitionerRegistry::getDependentDialects(DialectRegistry& registry) { - assert(isRegistered() && "auto-partitioner callback wasn't registered"); + if (!isRegistered()) { + llvm::report_fatal_error("auto-partitioner callback wasn't registered"); + } registeredDependenciesCallback->value()(registry); } diff --git a/shardy/dialect/sdy/transforms/propagation/auto_partitioner_registry.h b/shardy/dialect/sdy/transforms/propagation/auto_partitioner_registry.h index fbb5e44..84abc31 100644 --- a/shardy/dialect/sdy/transforms/propagation/auto_partitioner_registry.h +++ b/shardy/dialect/sdy/transforms/propagation/auto_partitioner_registry.h @@ -18,8 +18,8 @@ limitations under the License. #include +#include "mlir/IR/DialectRegistry.h" #include "mlir/Pass/PassOptions.h" -#include "mlir/IR/DialectRegistry.h" // from @llvm-project namespace mlir { namespace sdy { diff --git a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc index 40ef9ee..b073e19 100644 --- a/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc +++ b/shardy/dialect/sdy/transforms/propagation/op_sharding_rule_registry.cc @@ -93,17 +93,25 @@ void addGatherScatterFactors(RankedTensorType inputType, for (auto [slicesDim, slicesDimSize] : llvm::enumerate(slicesType.getShape())) { if (llvm::is_contained(offsetDims, slicesDim)) { - // `dim` is an offset dimension. - // We must now look up the next non-collapsed/batching input dimension - // that corresponds to this slices offset dimension. + // `slicesDim` is an offset dimension. + // We look up the non-collapsed/batching input dimension that corresponds + // to this slices offset dimension in the input. while (llvm::is_contained(collapsedSliceDims, inputDim) || llvm::is_contained(inputBatchingDims, inputDim)) { ++inputDim; } assert(inputDim < inputType.getRank()); - if (inputType.getDimSize(inputDim) == slicesDimSize) { - // We only propagate through unsliced dimensions. - addFactorFn(inputDim, /*indicesDim=*/kNullDim, slicesDim, + int64_t inputDimSize = inputType.getDimSize(inputDim); + + // If this dimension is unsliced, we add a common factor for `inputDim` + // and `slicesDim`. Otherwise, we add a unique factor for `inputDim` and + // `slicesDim` respectively. + if (inputDimSize == slicesDimSize) { + addFactorFn(inputDim, /*indicesDim=*/kNullDim, slicesDim, inputDimSize); + } else { + addFactorFn(inputDim, /*indicesDim=*/kNullDim, /*slicesDim=*/kNullDim, + inputDimSize); + addFactorFn(/*inputDim=*/kNullDim, /*indicesDim=*/kNullDim, slicesDim, slicesDimSize); } ++inputDim; @@ -126,6 +134,13 @@ void addGatherScatterFactors(RankedTensorType inputType, ++batchDimPos; } } + + // We add factors for all collapsed slice dimensions. + for (int64_t collapsedSliceDim : collapsedSliceDims) { + addFactorFn(collapsedSliceDim, /*indicesDim=*/kNullDim, + /*slicesDim=*/kNullDim, + inputType.getDimSize(collapsedSliceDim)); + } } } // namespace diff --git a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir index 34ac5fc..b6841d3 100644 --- a/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir +++ b/shardy/dialect/sdy/transforms/propagation/test/op_sharding_rule_registry.mlir @@ -383,7 +383,7 @@ func.func @fft_truncated_result(%arg0: tensor<8x32x64xf32>) -> tensor<8x32x33xco // CHECK-LABEL: @gather func.func @gather(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> { - // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([l, m, k], [i, j, n])->([i, j, o, k]) {i=2, j=3, k=2, l=1, m=1, n=1, o=1}> + // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([n, k, m], [i, j, o])->([i, j, l, m]) {i=2, j=3, k=4, l=2, m=2, n=3, o=1}> %0 = "stablehlo.gather"(%arg0, %arg1) { dimension_numbers = #stablehlo.gather< offset_dims = [2, 3], @@ -398,7 +398,7 @@ func.func @gather(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi64>) -> tensor< // CHECK-LABEL: @gather_batching_dims func.func @gather_batching_dims(%arg0: tensor<5x3x7x4xf32>, %arg1: tensor<7x5x3x2xi64>) -> tensor<7x5x3x2xf32> { - // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([j, l, i, m], [i, j, k, n])->([i, j, k, o]) {i=7, j=5, k=3, l=1, m=1, n=1, o=1}> + // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([j, n, i, l], [i, j, k, o])->([i, j, k, m]) {i=7, j=5, k=3, l=4, m=2, n=3, o=1}> %0 = "stablehlo.gather"(%arg0, %arg1) { dimension_numbers = #stablehlo.gather< offset_dims = [3], @@ -415,7 +415,7 @@ func.func @gather_batching_dims(%arg0: tensor<5x3x7x4xf32>, %arg1: tensor<7x5x3x // CHECK-LABEL: @gather_index_vector_dim_before_batching_dim func.func @gather_index_vector_dim_before_batching_dim(%arg0: tensor<5x3x7x4xf32>, %arg1: tensor<7x2x5x3xi64>) -> tensor<7x5x3x2xf32> { - // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([j, l, i, m], [i, n, j, k])->([i, j, k, o]) {i=7, j=5, k=3, l=1, m=1, n=1, o=1}> + // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([j, n, i, l], [i, o, j, k])->([i, j, k, m]) {i=7, j=5, k=3, l=4, m=2, n=3, o=1}> %0 = "stablehlo.gather"(%arg0, %arg1) { dimension_numbers = #stablehlo.gather< offset_dims = [3], @@ -608,7 +608,7 @@ func.func @reverse(%arg0: tensor<4x32x8x2xf32>) -> tensor<4x32x8x2xf32> { // CHECK-LABEL: @scatter_single_input func.func @scatter_single_input(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi64>, %arg2: tensor<2x3x2x2xf32>) -> tensor<3x4x2xf32>{ - // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([l, m, k], [i, j, n], [i, j, o, k])->([p, q, k]) {i=2, j=3, k=2, l=1, m=1, n=1, o=1, p=1, q=1}> + // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([n, k, m], [i, j, o], [i, j, l, m])->([n, k, m]) {i=2, j=3, k=4, l=2, m=2, n=3, o=1}> %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): %1 = stablehlo.add %arg3, %arg4 : tensor @@ -625,9 +625,28 @@ func.func @scatter_single_input(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi6 return %0 : tensor<3x4x2xf32> } +// CHECK-LABEL: @scatter_inserted_window_dim_is_last_one +func.func @scatter_inserted_window_dim_is_last_one(%arg0: tensor<4x2x3xf32>, %arg1: tensor<2x3x2xi64>, %arg2: tensor<2x3x2x2xf32>) -> tensor<4x2x3xf32>{ + // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([k, m, n], [i, j, o], [i, j, l, m])->([k, m, n]) {i=2, j=3, k=4, l=2, m=2, n=3, o=1}> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = stablehlo.add %arg3, %arg4 : tensor + stablehlo.return %1 : tensor + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [2, 3], + inserted_window_dims = [2], + scatter_dims_to_operand_dims = [1, 0], + index_vector_dim = 2>, + indices_are_sorted = false, + unique_indices = false + } : (tensor<4x2x3xf32>, tensor<2x3x2xi64>, tensor<2x3x2x2xf32>) -> tensor<4x2x3xf32> + return %0 : tensor<4x2x3xf32> +} + // CHECK-LABEL: @scatter_batching_dims func.func @scatter_batching_dims(%arg0: tensor<5x3x7x4xf32>, %arg1: tensor<7x5x3x2xi64>, %arg2: tensor<7x5x3x2xf32>) -> tensor<5x3x7x4xf32> { - // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([j, l, i, m], [i, j, k, n], [i, j, k, o])->([j, p, i, q]) {i=7, j=5, k=3, l=1, m=1, n=1, o=1, p=1, q=1}> + // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([j, n, i, l], [i, j, k, o], [i, j, k, m])->([j, n, i, l]) {i=7, j=5, k=3, l=4, m=2, n=3, o=1}> %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ ^bb0(%arg3: tensor, %arg4: tensor): %1 = stablehlo.add %arg3, %arg4 : tensor @@ -653,7 +672,7 @@ func.func @scatter_multiple_input(%arg0: tensor<3x4x2xi32>, %arg3: tensor<2x3x2x2xi32>, %arg4: tensor<2x3x2x2xf32>) -> (tensor<3x4x2xi32>, tensor<3x4x2xf32>) { - // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([l, m, k], [n, o, k], [i, j, p], [i, j, q, k], [i, j, r, k])->([s, t, k], [u, v, k]) {i=2, j=3, k=2, l=1, m=1, n=1, o=1, p=1, q=1, r=1, s=1, t=1, u=1, v=1}> + // CHECK: sdy.sharding_rule = #sdy.op_sharding_rule<([n, k, m], [n, k, m], [i, j, o], [i, j, l, m], [i, j, l, m])->([n, k, m], [n, k, m]) {i=2, j=3, k=4, l=2, m=2, n=3, o=1}> %0:2 = "stablehlo.scatter"(%arg0, %arg1, %arg2, %arg3, %arg4) ({ ^bb0(%arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor): %1 = stablehlo.add %arg5, %arg7 : tensor diff --git a/third_party/llvm/generated.patch b/third_party/llvm/generated.patch index de92cb4..509398d 100644 --- a/third_party/llvm/generated.patch +++ b/third_party/llvm/generated.patch @@ -1,4095 +1 @@ Auto generated patch. Do not edit or delete it, even if empty. -diff -ruN --strip-trailing-cr a/llvm/docs/NVPTXUsage.rst b/llvm/docs/NVPTXUsage.rst ---- a/llvm/docs/NVPTXUsage.rst -+++ b/llvm/docs/NVPTXUsage.rst -@@ -127,6 +127,69 @@ - NVPTX Intrinsics - ================ - -+Address Space Conversion -+------------------------ -+ -+'``llvm.nvvm.ptr.*.to.gen``' Intrinsics -+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -+ -+Syntax: -+""""""" -+ -+These are overloaded intrinsics. You can use these on any pointer types. -+ -+.. code-block:: llvm -+ -+ declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) -+ declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) -+ declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) -+ declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) -+ -+Overview: -+""""""""" -+ -+The '``llvm.nvvm.ptr.*.to.gen``' intrinsics convert a pointer in a non-generic -+address space to a generic address space pointer. -+ -+Semantics: -+"""""""""" -+ -+These intrinsics modify the pointer value to be a valid generic address space -+pointer. -+ -+ -+'``llvm.nvvm.ptr.gen.to.*``' Intrinsics -+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -+ -+Syntax: -+""""""" -+ -+These are overloaded intrinsics. You can use these on any pointer types. -+ -+.. code-block:: llvm -+ -+ declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -+ declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) -+ declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) -+ declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) -+ -+Overview: -+""""""""" -+ -+The '``llvm.nvvm.ptr.gen.to.*``' intrinsics convert a pointer in the generic -+address space to a pointer in the target address space. Note that these -+intrinsics are only useful if the address space of the target address space of -+the pointer is known. It is not legal to use address space conversion -+intrinsics to convert a pointer from one non-generic address space to another -+non-generic address space. -+ -+Semantics: -+"""""""""" -+ -+These intrinsics modify the pointer value to be a valid pointer in the target -+non-generic address space. -+ -+ - Reading PTX Special Registers - ----------------------------- - -diff -ruN --strip-trailing-cr a/llvm/docs/ReleaseNotes.rst b/llvm/docs/ReleaseNotes.rst ---- a/llvm/docs/ReleaseNotes.rst -+++ b/llvm/docs/ReleaseNotes.rst -@@ -63,24 +63,6 @@ - * ``llvm.nvvm.bitcast.d2ll`` - * ``llvm.nvvm.bitcast.ll2d`` - --* Remove the following intrinsics which can be replaced with a funnel-shift: -- -- * ``llvm.nvvm.rotate.b32`` -- * ``llvm.nvvm.rotate.right.b64`` -- * ``llvm.nvvm.rotate.b64`` -- --* Remove the following intrinsics which can be replaced with an -- ``addrspacecast``: -- -- * ``llvm.nvvm.ptr.gen.to.global`` -- * ``llvm.nvvm.ptr.gen.to.shared`` -- * ``llvm.nvvm.ptr.gen.to.constant`` -- * ``llvm.nvvm.ptr.gen.to.local`` -- * ``llvm.nvvm.ptr.global.to.gen`` -- * ``llvm.nvvm.ptr.shared.to.gen`` -- * ``llvm.nvvm.ptr.constant.to.gen`` -- * ``llvm.nvvm.ptr.local.to.gen`` -- - Changes to LLVM infrastructure - ------------------------------ - -diff -ruN --strip-trailing-cr a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td ---- a/llvm/include/llvm/IR/IntrinsicsNVVM.td -+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td -@@ -30,18 +30,10 @@ - // * llvm.nvvm.max.ui --> select(x ule y, x, y) - // * llvm.nvvm.max.ull --> ibid. - // * llvm.nvvm.h2f --> llvm.convert.to.fp16.f32 --// * llvm.nvvm.bitcast.f2i --> bitcast --// * llvm.nvvm.bitcast.i2f --> ibid. --// * llvm.nvvm.bitcast.d2ll --> ibid. --// * llvm.nvvm.bitcast.ll2d --> ibid. --// * llvm.nvvm.ptr.gen.to.global --> addrspacecast --// * llvm.nvvm.ptr.gen.to.shared --> ibid. --// * llvm.nvvm.ptr.gen.to.constant --> ibid. --// * llvm.nvvm.ptr.gen.to.local --> ibid. --// * llvm.nvvm.ptr.global.to.gen --> ibid. --// * llvm.nvvm.ptr.shared.to.gen --> ibid. --// * llvm.nvvm.ptr.constant.to.gen --> ibid. --// * llvm.nvvm.ptr.local.to.gen --> ibid. -+// * llvm.nvvm.bitcast.f2i --> bitcast -+// * llvm.nvvm.bitcast.i2f --> ibid. -+// * llvm.nvvm.bitcast.d2ll --> ibid. -+// * llvm.nvvm.bitcast.ll2d --> ibid. - - def llvm_global_ptr_ty : LLVMQualPointerType<1>; // (global)ptr - def llvm_shared_ptr_ty : LLVMQualPointerType<3>; // (shared)ptr -@@ -1610,6 +1602,40 @@ - [IntrReadMem, IntrArgMemOnly, IntrNoCallback, IntrWillReturn, NoCapture>], - "llvm.nvvm.ldg.global.p">; - -+// Use for generic pointers -+// - These intrinsics are used to convert address spaces. -+// - The input pointer and output pointer must have the same type, except for -+// the address-space. (This restriction is not enforced here as there is -+// currently no way to describe it). -+// - This complements the llvm bitcast, which can be used to cast one type -+// of pointer to another type of pointer, while the address space remains -+// the same. -+def int_nvvm_ptr_local_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.local.to.gen">; -+def int_nvvm_ptr_shared_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.shared.to.gen">; -+def int_nvvm_ptr_global_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.global.to.gen">; -+def int_nvvm_ptr_constant_to_gen: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.constant.to.gen">; -+ -+def int_nvvm_ptr_gen_to_global: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.gen.to.global">; -+def int_nvvm_ptr_gen_to_shared: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.gen.to.shared">; -+def int_nvvm_ptr_gen_to_local: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.gen.to.local">; -+def int_nvvm_ptr_gen_to_constant: DefaultAttrsIntrinsic<[llvm_anyptr_ty], -+ [llvm_anyptr_ty], [IntrNoMem, IntrSpeculatable], -+ "llvm.nvvm.ptr.gen.to.constant">; -+ - // Used in nvvm internally to help address space opt and ptx code generation - // This is for params that are passed to kernel functions by pointer by-val. - def int_nvvm_ptr_gen_to_param: Intrinsic<[llvm_anyptr_ty], -@@ -4453,6 +4479,22 @@ - "llvm.nvvm.sust.p.3d.v4i32.trap">, - ClangBuiltin<"__nvvm_sust_p_3d_v4i32_trap">; - -+ -+def int_nvvm_rotate_b32 -+ : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty], -+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b32">, -+ ClangBuiltin<"__nvvm_rotate_b32">; -+ -+def int_nvvm_rotate_b64 -+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], -+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.b64">, -+ ClangBuiltin<"__nvvm_rotate_b64">; -+ -+def int_nvvm_rotate_right_b64 -+ : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i32_ty], -+ [IntrNoMem, IntrSpeculatable], "llvm.nvvm.rotate.right.b64">, -+ ClangBuiltin<"__nvvm_rotate_right_b64">; -+ - def int_nvvm_swap_lo_hi_b64 - : DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty], - [IntrNoMem, IntrSpeculatable], "llvm.nvvm.swap.lo.hi.b64">, -diff -ruN --strip-trailing-cr a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp ---- a/llvm/lib/IR/AutoUpgrade.cpp -+++ b/llvm/lib/IR/AutoUpgrade.cpp -@@ -1272,19 +1272,6 @@ - // nvvm.bitcast.{f2i,i2f,ll2d,d2ll} - Expand = - Name == "f2i" || Name == "i2f" || Name == "ll2d" || Name == "d2ll"; -- else if (Name.consume_front("rotate.")) -- // nvvm.rotate.{b32,b64,right.b64} -- Expand = Name == "b32" || Name == "b64" || Name == "right.b64"; -- else if (Name.consume_front("ptr.gen.to.")) -- // nvvm.ptr.gen.to.{local,shared,global,constant} -- Expand = Name.starts_with("local") || Name.starts_with("shared") || -- Name.starts_with("global") || Name.starts_with("constant"); -- else if (Name.consume_front("ptr.")) -- // nvvm.ptr.{local,shared,global,constant}.to.gen -- Expand = -- (Name.consume_front("local") || Name.consume_front("shared") || -- Name.consume_front("global") || Name.consume_front("constant")) && -- Name.starts_with(".to.gen"); - else - Expand = false; - -@@ -2271,117 +2258,6 @@ - } - } - --static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI, -- Function *F, IRBuilder<> &Builder) { -- Value *Rep = nullptr; -- -- if (Name == "abs.i" || Name == "abs.ll") { -- Value *Arg = CI->getArgOperand(0); -- Value *Neg = Builder.CreateNeg(Arg, "neg"); -- Value *Cmp = Builder.CreateICmpSGE( -- Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); -- Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); -- } else if (Name.starts_with("atomic.load.add.f32.p") || -- Name.starts_with("atomic.load.add.f64.p")) { -- Value *Ptr = CI->getArgOperand(0); -- Value *Val = CI->getArgOperand(1); -- Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), -- AtomicOrdering::SequentiallyConsistent); -- } else if (Name.consume_front("max.") && -- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -- Name == "ui" || Name == "ull")) { -- Value *Arg0 = CI->getArgOperand(0); -- Value *Arg1 = CI->getArgOperand(1); -- Value *Cmp = Name.starts_with("u") -- ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") -- : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); -- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); -- } else if (Name.consume_front("min.") && -- (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -- Name == "ui" || Name == "ull")) { -- Value *Arg0 = CI->getArgOperand(0); -- Value *Arg1 = CI->getArgOperand(1); -- Value *Cmp = Name.starts_with("u") -- ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") -- : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); -- Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); -- } else if (Name == "clz.ll") { -- // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. -- Value *Arg = CI->getArgOperand(0); -- Value *Ctlz = Builder.CreateCall( -- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, -- {Arg->getType()}), -- {Arg, Builder.getFalse()}, "ctlz"); -- Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); -- } else if (Name == "popc.ll") { -- // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an -- // i64. -- Value *Arg = CI->getArgOperand(0); -- Value *Popc = Builder.CreateCall( -- Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, -- {Arg->getType()}), -- Arg, "ctpop"); -- Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); -- } else if (Name == "h2f") { -- Rep = Builder.CreateCall( -- Intrinsic::getDeclaration(F->getParent(), Intrinsic::convert_from_fp16, -- {Builder.getFloatTy()}), -- CI->getArgOperand(0), "h2f"); -- } else if (Name.consume_front("bitcast.") && -- (Name == "f2i" || Name == "i2f" || Name == "ll2d" || -- Name == "d2ll")) { -- Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); -- } else if (Name == "rotate.b32") { -- Value *Arg = CI->getOperand(0); -- Value *ShiftAmt = CI->getOperand(1); -- Rep = Builder.CreateIntrinsic(Builder.getInt32Ty(), Intrinsic::fshl, -- {Arg, Arg, ShiftAmt}); -- } else if (Name == "rotate.b64") { -- Type *Int64Ty = Builder.getInt64Ty(); -- Value *Arg = CI->getOperand(0); -- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); -- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshl, -- {Arg, Arg, ZExtShiftAmt}); -- } else if (Name == "rotate.right.b64") { -- Type *Int64Ty = Builder.getInt64Ty(); -- Value *Arg = CI->getOperand(0); -- Value *ZExtShiftAmt = Builder.CreateZExt(CI->getOperand(1), Int64Ty); -- Rep = Builder.CreateIntrinsic(Int64Ty, Intrinsic::fshr, -- {Arg, Arg, ZExtShiftAmt}); -- } else if ((Name.consume_front("ptr.gen.to.") && -- (Name.starts_with("local") || Name.starts_with("shared") || -- Name.starts_with("global") || Name.starts_with("constant"))) || -- (Name.consume_front("ptr.") && -- (Name.consume_front("local") || Name.consume_front("shared") || -- Name.consume_front("global") || -- Name.consume_front("constant")) && -- Name.starts_with(".to.gen"))) { -- Rep = Builder.CreateAddrSpaceCast(CI->getArgOperand(0), CI->getType()); -- } else { -- Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); -- if (IID != Intrinsic::not_intrinsic && -- !F->getReturnType()->getScalarType()->isBFloatTy()) { -- rename(F); -- Function *NewFn = Intrinsic::getDeclaration(F->getParent(), IID); -- SmallVector Args; -- for (size_t I = 0; I < NewFn->arg_size(); ++I) { -- Value *Arg = CI->getArgOperand(I); -- Type *OldType = Arg->getType(); -- Type *NewType = NewFn->getArg(I)->getType(); -- Args.push_back( -- (OldType->isIntegerTy() && NewType->getScalarType()->isBFloatTy()) -- ? Builder.CreateBitCast(Arg, NewType) -- : Arg); -- } -- Rep = Builder.CreateCall(NewFn, Args); -- if (F->getReturnType()->isIntegerTy()) -- Rep = Builder.CreateBitCast(Rep, F->getReturnType()); -- } -- } -- -- return Rep; --} -- - static Value *upgradeX86IntrinsicCall(StringRef Name, CallBase *CI, Function *F, - IRBuilder<> &Builder) { - LLVMContext &C = F->getContext(); -@@ -4332,8 +4208,85 @@ - - if (!IsX86 && Name == "stackprotectorcheck") { - Rep = nullptr; -+ } else if (IsNVVM && (Name == "abs.i" || Name == "abs.ll")) { -+ Value *Arg = CI->getArgOperand(0); -+ Value *Neg = Builder.CreateNeg(Arg, "neg"); -+ Value *Cmp = Builder.CreateICmpSGE( -+ Arg, llvm::Constant::getNullValue(Arg->getType()), "abs.cond"); -+ Rep = Builder.CreateSelect(Cmp, Arg, Neg, "abs"); -+ } else if (IsNVVM && (Name.starts_with("atomic.load.add.f32.p") || -+ Name.starts_with("atomic.load.add.f64.p"))) { -+ Value *Ptr = CI->getArgOperand(0); -+ Value *Val = CI->getArgOperand(1); -+ Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(), -+ AtomicOrdering::SequentiallyConsistent); -+ } else if (IsNVVM && Name.consume_front("max.") && -+ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -+ Name == "ui" || Name == "ull")) { -+ Value *Arg0 = CI->getArgOperand(0); -+ Value *Arg1 = CI->getArgOperand(1); -+ Value *Cmp = Name.starts_with("u") -+ ? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond") -+ : Builder.CreateICmpSGE(Arg0, Arg1, "max.cond"); -+ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max"); -+ } else if (IsNVVM && Name.consume_front("min.") && -+ (Name == "s" || Name == "i" || Name == "ll" || Name == "us" || -+ Name == "ui" || Name == "ull")) { -+ Value *Arg0 = CI->getArgOperand(0); -+ Value *Arg1 = CI->getArgOperand(1); -+ Value *Cmp = Name.starts_with("u") -+ ? Builder.CreateICmpULE(Arg0, Arg1, "min.cond") -+ : Builder.CreateICmpSLE(Arg0, Arg1, "min.cond"); -+ Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min"); -+ } else if (IsNVVM && Name == "clz.ll") { -+ // llvm.nvvm.clz.ll returns an i32, but llvm.ctlz.i64 returns an i64. -+ Value *Arg = CI->getArgOperand(0); -+ Value *Ctlz = Builder.CreateCall( -+ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctlz, -+ {Arg->getType()}), -+ {Arg, Builder.getFalse()}, "ctlz"); -+ Rep = Builder.CreateTrunc(Ctlz, Builder.getInt32Ty(), "ctlz.trunc"); -+ } else if (IsNVVM && Name == "popc.ll") { -+ // llvm.nvvm.popc.ll returns an i32, but llvm.ctpop.i64 returns an -+ // i64. -+ Value *Arg = CI->getArgOperand(0); -+ Value *Popc = Builder.CreateCall( -+ Intrinsic::getDeclaration(F->getParent(), Intrinsic::ctpop, -+ {Arg->getType()}), -+ Arg, "ctpop"); -+ Rep = Builder.CreateTrunc(Popc, Builder.getInt32Ty(), "ctpop.trunc"); - } else if (IsNVVM) { -- Rep = upgradeNVVMIntrinsicCall(Name, CI, F, Builder); -+ if (Name == "h2f") { -+ Rep = -+ Builder.CreateCall(Intrinsic::getDeclaration( -+ F->getParent(), Intrinsic::convert_from_fp16, -+ {Builder.getFloatTy()}), -+ CI->getArgOperand(0), "h2f"); -+ } else if (Name.consume_front("bitcast.") && -+ (Name == "f2i" || Name == "i2f" || Name == "ll2d" || -+ Name == "d2ll")) { -+ Rep = Builder.CreateBitCast(CI->getArgOperand(0), CI->getType()); -+ } else { -+ Intrinsic::ID IID = shouldUpgradeNVPTXBF16Intrinsic(Name); -+ if (IID != Intrinsic::not_intrinsic && -+ !F->getReturnType()->getScalarType()->isBFloatTy()) { -+ rename(F); -+ NewFn = Intrinsic::getDeclaration(F->getParent(), IID); -+ SmallVector Args; -+ for (size_t I = 0; I < NewFn->arg_size(); ++I) { -+ Value *Arg = CI->getArgOperand(I); -+ Type *OldType = Arg->getType(); -+ Type *NewType = NewFn->getArg(I)->getType(); -+ Args.push_back((OldType->isIntegerTy() && -+ NewType->getScalarType()->isBFloatTy()) -+ ? Builder.CreateBitCast(Arg, NewType) -+ : Arg); -+ } -+ Rep = Builder.CreateCall(NewFn, Args); -+ if (F->getReturnType()->isIntegerTy()) -+ Rep = Builder.CreateBitCast(Rep, F->getReturnType()); -+ } -+ } - } else if (IsX86) { - Rep = upgradeX86IntrinsicCall(Name, CI, F, Builder); - } else if (IsARM) { -diff -ruN --strip-trailing-cr a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp ---- a/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -+++ b/llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp -@@ -292,7 +292,6 @@ - static const LLT S224 = LLT::scalar(224); - static const LLT S256 = LLT::scalar(256); - static const LLT S512 = LLT::scalar(512); --static const LLT S1024 = LLT::scalar(1024); - static const LLT MaxScalar = LLT::scalar(MaxRegisterSize); - - static const LLT V2S8 = LLT::fixed_vector(2, 8); -@@ -333,8 +332,8 @@ - static const LLT V2S128 = LLT::fixed_vector(2, 128); - static const LLT V4S128 = LLT::fixed_vector(4, 128); - --static std::initializer_list AllScalarTypes = { -- S32, S64, S96, S128, S160, S224, S256, S512, S1024}; -+static std::initializer_list AllScalarTypes = {S32, S64, S96, S128, -+ S160, S224, S256, S512}; - - static std::initializer_list AllS16Vectors{ - V2S16, V4S16, V6S16, V8S16, V10S16, V12S16, V16S16, V2S128, V4S128}; -@@ -890,11 +889,10 @@ - .clampScalar(0, S16, S64); - - getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}) -- .legalIf(isRegisterClassType(0)) -+ .legalIf(isRegisterType(0)) - // s1 and s16 are special cases because they have legal operations on - // them, but don't really occupy registers in the normal way. - .legalFor({S1, S16}) -- .clampNumElements(0, V16S32, V32S32) - .moreElementsIf(isSmallOddVector(0), oneMoreElement(0)) - .clampScalarOrElt(0, S32, MaxScalar) - .widenScalarToNextPow2(0, 32) -diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td ---- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td -@@ -174,6 +174,10 @@ - def hasSHFL : Predicate<"!(Subtarget->getSmVersion() >= 70" - "&& Subtarget->getPTXVersion() >= 64)">; - -+def useShortPtrLocal : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_LOCAL) == 32">; -+def useShortPtrShared : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_SHARED) == 32">; -+def useShortPtrConst : Predicate<"TM.is64Bit() && TM.getPointerSizeInBits(ADDRESS_SPACE_CONST) == 32">; -+ - def useFP16Math: Predicate<"Subtarget->allowFP16Math()">; - def hasBF16Math: Predicate<"Subtarget->hasBF16Math()">; - -@@ -1661,6 +1665,167 @@ - "brev.b64 \t$dst, $a;", - [(set Int64Regs:$dst, (bitreverse Int64Regs:$a))]>; - -+// -+// Rotate: Use ptx shf instruction if available. -+// -+ -+// 32 bit r2 = rotl r1, n -+// => -+// r2 = shf.l r1, r1, n -+def ROTL32imm_hw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), -+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 imm:$amt)))]>, -+ Requires<[hasHWROT32]>; -+ -+def ROTL32reg_hw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -+ Requires<[hasHWROT32]>; -+ -+// 32 bit r2 = rotr r1, n -+// => -+// r2 = shf.r r1, r1, n -+def ROTR32imm_hw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, i32imm:$amt), -+ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 imm:$amt)))]>, -+ Requires<[hasHWROT32]>; -+ -+def ROTR32reg_hw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -+ "shf.r.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -+ Requires<[hasHWROT32]>; -+ -+// 32-bit software rotate by immediate. $amt2 should equal 32 - $amt1. -+def ROT32imm_sw : -+ NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$src, i32imm:$amt1, i32imm:$amt2), -+ "{{\n\t" -+ ".reg .b32 %lhs;\n\t" -+ ".reg .b32 %rhs;\n\t" -+ "shl.b32 \t%lhs, $src, $amt1;\n\t" -+ "shr.b32 \t%rhs, $src, $amt2;\n\t" -+ "add.u32 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ []>; -+ -+def SUB_FRM_32 : SDNodeXFormgetTargetConstant(32 - N->getZExtValue(), SDLoc(N), MVT::i32); -+}]>; -+ -+def : Pat<(rotl (i32 Int32Regs:$src), (i32 imm:$amt)), -+ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, -+ Requires<[noHWROT32]>; -+def : Pat<(rotr (i32 Int32Regs:$src), (i32 imm:$amt)), -+ (ROT32imm_sw Int32Regs:$src, (SUB_FRM_32 node:$amt), imm:$amt)>, -+ Requires<[noHWROT32]>; -+ -+// 32-bit software rotate left by register. -+def ROTL32reg_sw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -+ "{{\n\t" -+ ".reg .b32 %lhs;\n\t" -+ ".reg .b32 %rhs;\n\t" -+ ".reg .b32 %amt2;\n\t" -+ "shl.b32 \t%lhs, $src, $amt;\n\t" -+ "sub.s32 \t%amt2, 32, $amt;\n\t" -+ "shr.b32 \t%rhs, $src, %amt2;\n\t" -+ "add.u32 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ [(set Int32Regs:$dst, (rotl (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -+ Requires<[noHWROT32]>; -+ -+// 32-bit software rotate right by register. -+def ROTR32reg_sw : -+ NVPTXInst<(outs Int32Regs:$dst), (ins Int32Regs:$src, Int32Regs:$amt), -+ "{{\n\t" -+ ".reg .b32 %lhs;\n\t" -+ ".reg .b32 %rhs;\n\t" -+ ".reg .b32 %amt2;\n\t" -+ "shr.b32 \t%lhs, $src, $amt;\n\t" -+ "sub.s32 \t%amt2, 32, $amt;\n\t" -+ "shl.b32 \t%rhs, $src, %amt2;\n\t" -+ "add.u32 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ [(set Int32Regs:$dst, (rotr (i32 Int32Regs:$src), (i32 Int32Regs:$amt)))]>, -+ Requires<[noHWROT32]>; -+ -+// 64-bit software rotate by immediate. $amt2 should equal 64 - $amt1. -+def ROT64imm_sw : -+ NVPTXInst<(outs Int64Regs:$dst), -+ (ins Int64Regs:$src, i32imm:$amt1, i32imm:$amt2), -+ "{{\n\t" -+ ".reg .b64 %lhs;\n\t" -+ ".reg .b64 %rhs;\n\t" -+ "shl.b64 \t%lhs, $src, $amt1;\n\t" -+ "shr.b64 \t%rhs, $src, $amt2;\n\t" -+ "add.u64 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ []>; -+ -+def SUB_FRM_64 : SDNodeXFormgetTargetConstant(64-N->getZExtValue(), SDLoc(N), MVT::i32); -+}]>; -+ -+def : Pat<(rotl Int64Regs:$src, (i32 imm:$amt)), -+ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>; -+def : Pat<(rotr Int64Regs:$src, (i32 imm:$amt)), -+ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>; -+ -+// 64-bit software rotate left by register. -+def ROTL64reg_sw : -+ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), -+ "{{\n\t" -+ ".reg .b64 %lhs;\n\t" -+ ".reg .b64 %rhs;\n\t" -+ ".reg .u32 %amt2;\n\t" -+ "and.b32 \t%amt2, $amt, 63;\n\t" -+ "shl.b64 \t%lhs, $src, %amt2;\n\t" -+ "sub.u32 \t%amt2, 64, %amt2;\n\t" -+ "shr.b64 \t%rhs, $src, %amt2;\n\t" -+ "add.u64 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ [(set Int64Regs:$dst, (rotl Int64Regs:$src, (i32 Int32Regs:$amt)))]>; -+ -+def ROTR64reg_sw : -+ NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src, Int32Regs:$amt), -+ "{{\n\t" -+ ".reg .b64 %lhs;\n\t" -+ ".reg .b64 %rhs;\n\t" -+ ".reg .u32 %amt2;\n\t" -+ "and.b32 \t%amt2, $amt, 63;\n\t" -+ "shr.b64 \t%lhs, $src, %amt2;\n\t" -+ "sub.u32 \t%amt2, 64, %amt2;\n\t" -+ "shl.b64 \t%rhs, $src, %amt2;\n\t" -+ "add.u64 \t$dst, %lhs, %rhs;\n\t" -+ "}}", -+ [(set Int64Regs:$dst, (rotr Int64Regs:$src, (i32 Int32Regs:$amt)))]>; -+ -+// -+// Funnnel shift in clamp mode -+// -+ -+// Create SDNodes so they can be used in the DAG code, e.g. -+// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) -+def FUN_SHFL_CLAMP : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; -+def FUN_SHFR_CLAMP : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; -+ -+def FUNSHFLCLAMP : -+ NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+ "shf.l.clamp.b32 \t$dst, $lo, $hi, $amt;", -+ [(set Int32Regs:$dst, -+ (FUN_SHFL_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; -+ -+def FUNSHFRCLAMP : -+ NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+ "shf.r.clamp.b32 \t$dst, $lo, $hi, $amt;", -+ [(set Int32Regs:$dst, -+ (FUN_SHFR_CLAMP (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>; - - // - // BFE - bit-field extract -@@ -3492,42 +3657,6 @@ - def: Pat<(v2i16 (scalar_to_vector (i16 Int16Regs:$a))), - (CVT_u32_u16 Int16Regs:$a, CvtNONE)>; - --// --// Funnel-Shift --// -- --// Create SDNodes so they can be used in the DAG code, e.g. --// NVPTXISelLowering (LowerShiftLeftParts and LowerShiftRightParts) --def fshl_clamp : SDNode<"NVPTXISD::FUN_SHFL_CLAMP", SDTIntShiftDOp, []>; --def fshr_clamp : SDNode<"NVPTXISD::FUN_SHFR_CLAMP", SDTIntShiftDOp, []>; -- --// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so --// no side effects. --let hasSideEffects = false in { -- multiclass ShfInst { -- def _i -- : NVPTXInst<(outs Int32Regs:$dst), -- (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", -- [(set Int32Regs:$dst, -- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 imm:$amt)))]>, -- Requires<[hasHWROT32]>; -- -- def _r -- : NVPTXInst<(outs Int32Regs:$dst), -- (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -- "shf." # mode # ".b32 \t$dst, $lo, $hi, $amt;", -- [(set Int32Regs:$dst, -- (op (i32 Int32Regs:$lo), (i32 Int32Regs:$hi), (i32 Int32Regs:$amt)))]>, -- Requires<[hasHWROT32]>; -- } -- -- defm SHF_L_CLAMP : ShfInst<"l.clamp", fshl_clamp>; -- defm SHF_R_CLAMP : ShfInst<"r.clamp", fshr_clamp>; -- defm SHF_L_WRAP : ShfInst<"l.wrap", fshl>; -- defm SHF_R_WRAP : ShfInst<"r.wrap", fshr>; --} -- - // Count leading zeros - let hasSideEffects = false in { - def CLZr32 : NVPTXInst<(outs Int32Regs:$d), (ins Int32Regs:$a), -diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td ---- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td -@@ -2537,45 +2537,59 @@ - : VLDG_G_ELE_V4<"v4.f32 \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];", Float32Regs>; - - --multiclass NG_TO_G { -+multiclass NG_TO_G { - def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), -- "cvta." # Str # ".u32 \t$result, $src;", []>; -+ !strconcat("cvta.", Str, ".u32 \t$result, $src;"), -+ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; - def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), -- "cvta." # Str # ".u64 \t$result, $src;", []>; -+ !strconcat("cvta.", Str, ".u64 \t$result, $src;"), -+ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; -+ def _6432 : NVPTXInst<(outs Int64Regs:$result), (ins Int32Regs:$src), -+ "{{ .reg .b64 %tmp;\n\t" -+ #" cvt.u64.u32 \t%tmp, $src;\n\t" -+ #" cvta." # Str # ".u64 \t$result, %tmp; }}", -+ [(set Int64Regs:$result, (Intrin Int32Regs:$src))]>, -+ Requires<[ShortPtr]>; - } - --multiclass G_TO_NG { -+multiclass G_TO_NG { - def "" : NVPTXInst<(outs Int32Regs:$result), (ins Int32Regs:$src), -- "cvta.to." # Str # ".u32 \t$result, $src;", []>; -+ !strconcat("cvta.to.", Str, ".u32 \t$result, $src;"), -+ [(set Int32Regs:$result, (Intrin Int32Regs:$src))]>; - def _64 : NVPTXInst<(outs Int64Regs:$result), (ins Int64Regs:$src), -- "cvta.to." # Str # ".u64 \t$result, $src;", []>; -+ !strconcat("cvta.to.", Str, ".u64 \t$result, $src;"), -+ [(set Int64Regs:$result, (Intrin Int64Regs:$src))]>; -+ def _3264 : NVPTXInst<(outs Int32Regs:$result), (ins Int64Regs:$src), -+ "{{ .reg .b64 %tmp;\n\t" -+ #" cvta.to." # Str # ".u64 \t%tmp, $src;\n\t" -+ #" cvt.u32.u64 \t$result, %tmp; }}", -+ [(set Int32Regs:$result, (Intrin Int64Regs:$src))]>, -+ Requires<[ShortPtr]>; - } - --defm cvta_local : NG_TO_G<"local">; --defm cvta_shared : NG_TO_G<"shared">; --defm cvta_global : NG_TO_G<"global">; --defm cvta_const : NG_TO_G<"const">; -- --defm cvta_to_local : G_TO_NG<"local">; --defm cvta_to_shared : G_TO_NG<"shared">; --defm cvta_to_global : G_TO_NG<"global">; --defm cvta_to_const : G_TO_NG<"const">; -- --// nvvm.ptr.param.to.gen --defm cvta_param : NG_TO_G<"param">; -- --def : Pat<(int_nvvm_ptr_param_to_gen Int32Regs:$src), -- (cvta_param Int32Regs:$src)>; -- --def : Pat<(int_nvvm_ptr_param_to_gen Int64Regs:$src), -- (cvta_param_64 Int64Regs:$src)>; -+defm cvta_local : NG_TO_G<"local", int_nvvm_ptr_local_to_gen, useShortPtrLocal>; -+defm cvta_shared : NG_TO_G<"shared", int_nvvm_ptr_shared_to_gen, useShortPtrShared>; -+defm cvta_global : NG_TO_G<"global", int_nvvm_ptr_global_to_gen, False>; -+defm cvta_const : NG_TO_G<"const", int_nvvm_ptr_constant_to_gen, useShortPtrConst>; -+defm cvta_param : NG_TO_G<"param", int_nvvm_ptr_param_to_gen, False>; -+ -+defm cvta_to_local : G_TO_NG<"local", int_nvvm_ptr_gen_to_local, useShortPtrLocal>; -+defm cvta_to_shared : G_TO_NG<"shared", int_nvvm_ptr_gen_to_shared, useShortPtrShared>; -+defm cvta_to_global : G_TO_NG<"global", int_nvvm_ptr_gen_to_global, False>; -+defm cvta_to_const : G_TO_NG<"const", int_nvvm_ptr_gen_to_constant, useShortPtrConst>; - - // nvvm.ptr.gen.to.param --def : Pat<(int_nvvm_ptr_gen_to_param Int32Regs:$src), -- (IMOV32rr Int32Regs:$src)>; -+def nvvm_ptr_gen_to_param : NVPTXInst<(outs Int32Regs:$result), -+ (ins Int32Regs:$src), -+ "mov.u32 \t$result, $src;", -+ [(set Int32Regs:$result, -+ (int_nvvm_ptr_gen_to_param Int32Regs:$src))]>; -+def nvvm_ptr_gen_to_param_64 : NVPTXInst<(outs Int64Regs:$result), -+ (ins Int64Regs:$src), -+ "mov.u64 \t$result, $src;", -+ [(set Int64Regs:$result, -+ (int_nvvm_ptr_gen_to_param Int64Regs:$src))]>; - --def : Pat<(int_nvvm_ptr_gen_to_param Int64Regs:$src), -- (IMOV64rr Int64Regs:$src)>; - - // nvvm.move intrinsicc - def nvvm_move_i16 : NVPTXInst<(outs Int16Regs:$r), (ins Int16Regs:$s), -@@ -2618,6 +2632,24 @@ - [(set Int64Regs:$r, - (int_nvvm_move_ptr texternalsym:$s))]>;*/ - -+ -+// MoveParam %r1, param -+// ptr_local_to_gen %r2, %r1 -+// ptr_gen_to_local %r3, %r2 -+// -> -+// mov %r1, param -+ -+// @TODO: Revisit this. There is a type -+// contradiction between iPTRAny and iPTR for the addr defs, so the move_sym -+// instructions are not currently defined. However, we can use the ptr -+// variants and the asm printer will do the right thing. -+def : Pat<(i64 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen -+ (MoveParam texternalsym:$src)))), -+ (nvvm_move_ptr64 texternalsym:$src)>; -+def : Pat<(i32 (int_nvvm_ptr_gen_to_local (int_nvvm_ptr_local_to_gen -+ (MoveParam texternalsym:$src)))), -+ (nvvm_move_ptr32 texternalsym:$src)>; -+ - def texsurf_handles - : NVPTXInst<(outs Int64Regs:$result), (ins imem:$src), - "mov.u64 \t$result, $src;", []>; -@@ -2701,9 +2733,134 @@ - def : Pat<(int_nvvm_read_ptx_sreg_envreg31), (MOV_SPECIAL ENVREG31)>; - - -+// rotate builtin support -+ -+def ROTATE_B32_HW_IMM -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$src, i32imm:$amt), -+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, -+ (int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)))]>, -+ Requires<[hasHWROT32]> ; -+ -+def ROTATE_B32_HW_REG -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$src, Int32Regs:$amt), -+ "shf.l.wrap.b32 \t$dst, $src, $src, $amt;", -+ [(set Int32Regs:$dst, -+ (int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt))]>, -+ Requires<[hasHWROT32]> ; -+ -+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, (i32 imm:$amt)), -+ (ROT32imm_sw Int32Regs:$src, imm:$amt, (SUB_FRM_32 node:$amt))>, -+ Requires<[noHWROT32]> ; -+ -+def : Pat<(int_nvvm_rotate_b32 Int32Regs:$src, Int32Regs:$amt), -+ (ROTL32reg_sw Int32Regs:$src, Int32Regs:$amt)>, -+ Requires<[noHWROT32]> ; -+ -+let hasSideEffects = false in { -+ def GET_LO_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), -+ !strconcat("{{\n\t", -+ ".reg .b32 %dummy;\n\t", -+ "mov.b64 \t{$dst,%dummy}, $src;\n\t", -+ "}}"), -+ []> ; -+ -+ def GET_HI_INT64 : NVPTXInst<(outs Int32Regs:$dst), (ins Int64Regs:$src), -+ !strconcat("{{\n\t", -+ ".reg .b32 %dummy;\n\t", -+ "mov.b64 \t{%dummy,$dst}, $src;\n\t", -+ "}}"), -+ []> ; -+} -+ -+let hasSideEffects = false in { -+ def PACK_TWO_INT32 -+ : NVPTXInst<(outs Int64Regs:$dst), (ins Int32Regs:$lo, Int32Regs:$hi), -+ "mov.b64 \t$dst, {{$lo, $hi}};", []> ; -+} -+ - def : Pat<(int_nvvm_swap_lo_hi_b64 Int64Regs:$src), -- (V2I32toI64 (I64toI32H Int64Regs:$src), -- (I64toI32L Int64Regs:$src))> ; -+ (PACK_TWO_INT32 (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src))> ; -+ -+// Funnel shift, requires >= sm_32. Does not trap if amt is out of range, so -+// no side effects. -+let hasSideEffects = false in { -+ def SHF_L_WRAP_B32_IMM -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -+ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -+ Requires<[hasHWROT32]>; -+ -+ def SHF_L_WRAP_B32_REG -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+ "shf.l.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -+ Requires<[hasHWROT32]>; -+ -+ def SHF_R_WRAP_B32_IMM -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, i32imm:$amt), -+ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -+ Requires<[hasHWROT32]>; -+ -+ def SHF_R_WRAP_B32_REG -+ : NVPTXInst<(outs Int32Regs:$dst), -+ (ins Int32Regs:$lo, Int32Regs:$hi, Int32Regs:$amt), -+ "shf.r.wrap.b32 \t$dst, $lo, $hi, $amt;",[]>, -+ Requires<[hasHWROT32]>; -+} -+ -+// HW version of rotate 64 -+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), -+ (PACK_TWO_INT32 -+ (SHF_L_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src), imm:$amt), -+ (SHF_L_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), -+ (GET_HI_INT64 Int64Regs:$src), imm:$amt))>, -+ Requires<[hasHWROT32]>; -+ -+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), -+ (PACK_TWO_INT32 -+ (SHF_L_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt), -+ (SHF_L_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), -+ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt))>, -+ Requires<[hasHWROT32]>; -+ -+ -+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), -+ (PACK_TWO_INT32 -+ (SHF_R_WRAP_B32_IMM (GET_LO_INT64 Int64Regs:$src), -+ (GET_HI_INT64 Int64Regs:$src), imm:$amt), -+ (SHF_R_WRAP_B32_IMM (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src), imm:$amt))>, -+ Requires<[hasHWROT32]>; -+ -+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), -+ (PACK_TWO_INT32 -+ (SHF_R_WRAP_B32_REG (GET_LO_INT64 Int64Regs:$src), -+ (GET_HI_INT64 Int64Regs:$src), Int32Regs:$amt), -+ (SHF_R_WRAP_B32_REG (GET_HI_INT64 Int64Regs:$src), -+ (GET_LO_INT64 Int64Regs:$src), Int32Regs:$amt))>, -+ Requires<[hasHWROT32]>; -+ -+// SW version of rotate 64 -+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, (i32 imm:$amt)), -+ (ROT64imm_sw Int64Regs:$src, imm:$amt, (SUB_FRM_64 node:$amt))>, -+ Requires<[noHWROT32]>; -+def : Pat<(int_nvvm_rotate_b64 Int64Regs:$src, Int32Regs:$amt), -+ (ROTL64reg_sw Int64Regs:$src, Int32Regs:$amt)>, -+ Requires<[noHWROT32]>; -+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, (i32 imm:$amt)), -+ (ROT64imm_sw Int64Regs:$src, (SUB_FRM_64 node:$amt), imm:$amt)>, -+ Requires<[noHWROT32]>; -+def : Pat<(int_nvvm_rotate_right_b64 Int64Regs:$src, Int32Regs:$amt), -+ (ROTR64reg_sw Int64Regs:$src, Int32Regs:$amt)>, -+ Requires<[noHWROT32]>; -+ - - //----------------------------------- - // Texture Intrinsics -diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp ---- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp -@@ -1109,21 +1109,11 @@ - AddrSpaceCastSDNode *CastN = cast(N); - unsigned SrcAddrSpace = CastN->getSrcAddressSpace(); - unsigned DstAddrSpace = CastN->getDestAddressSpace(); -- SDLoc DL(N); - assert(SrcAddrSpace != DstAddrSpace && - "addrspacecast must be between different address spaces"); - - if (DstAddrSpace == ADDRESS_SPACE_GENERIC) { - // Specific to generic -- -- if (TM.is64Bit() && TM.getPointerSizeInBits(SrcAddrSpace) == 32) { -- SDValue CvtNone = -- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); -- SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u64_u32, DL, MVT::i64, -- Src, CvtNone); -- Src = SDValue(Cvt, 0); -- } -- - unsigned Opc; - switch (SrcAddrSpace) { - default: report_fatal_error("Bad address space in addrspacecast"); -@@ -1131,16 +1121,26 @@ - Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global; - break; - case ADDRESS_SPACE_SHARED: -- Opc = TM.is64Bit() ? NVPTX::cvta_shared_64 : NVPTX::cvta_shared; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -+ ? NVPTX::cvta_shared_6432 -+ : NVPTX::cvta_shared_64) -+ : NVPTX::cvta_shared; - break; - case ADDRESS_SPACE_CONST: -- Opc = TM.is64Bit() ? NVPTX::cvta_const_64 : NVPTX::cvta_const; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -+ ? NVPTX::cvta_const_6432 -+ : NVPTX::cvta_const_64) -+ : NVPTX::cvta_const; - break; - case ADDRESS_SPACE_LOCAL: -- Opc = TM.is64Bit() ? NVPTX::cvta_local_64 : NVPTX::cvta_local; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32 -+ ? NVPTX::cvta_local_6432 -+ : NVPTX::cvta_local_64) -+ : NVPTX::cvta_local; - break; - } -- ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src)); -+ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), -+ Src)); - return; - } else { - // Generic to specific -@@ -1153,28 +1153,30 @@ - Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global; - break; - case ADDRESS_SPACE_SHARED: -- Opc = TM.is64Bit() ? NVPTX::cvta_to_shared_64 : NVPTX::cvta_to_shared; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -+ ? NVPTX::cvta_to_shared_3264 -+ : NVPTX::cvta_to_shared_64) -+ : NVPTX::cvta_to_shared; - break; - case ADDRESS_SPACE_CONST: -- Opc = TM.is64Bit() ? NVPTX::cvta_to_const_64 : NVPTX::cvta_to_const; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -+ ? NVPTX::cvta_to_const_3264 -+ : NVPTX::cvta_to_const_64) -+ : NVPTX::cvta_to_const; - break; - case ADDRESS_SPACE_LOCAL: -- Opc = TM.is64Bit() ? NVPTX::cvta_to_local_64 : NVPTX::cvta_to_local; -+ Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32 -+ ? NVPTX::cvta_to_local_3264 -+ : NVPTX::cvta_to_local_64) -+ : NVPTX::cvta_to_local; - break; - case ADDRESS_SPACE_PARAM: -- Opc = TM.is64Bit() ? NVPTX::IMOV64rr : NVPTX::IMOV32rr; -+ Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64 -+ : NVPTX::nvvm_ptr_gen_to_param; - break; - } -- -- SDNode *CVTA = CurDAG->getMachineNode(Opc, DL, N->getValueType(0), Src); -- if (TM.is64Bit() && TM.getPointerSizeInBits(DstAddrSpace) == 32) { -- SDValue CvtNone = -- CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL, MVT::i32); -- CVTA = CurDAG->getMachineNode(NVPTX::CVT_u32_u64, DL, MVT::i32, -- SDValue(CVTA, 0), CvtNone); -- } -- -- ReplaceNode(N, CVTA); -+ ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0), -+ Src)); - return; - } - } -diff -ruN --strip-trailing-cr a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp ---- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp -@@ -594,13 +594,20 @@ - setOperationAction(ISD::BITREVERSE, MVT::i32, Legal); - setOperationAction(ISD::BITREVERSE, MVT::i64, Legal); - -- setOperationAction({ISD::ROTL, ISD::ROTR}, -- {MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}, -- Expand); -- -- if (STI.hasHWROT32()) -- setOperationAction({ISD::FSHL, ISD::FSHR}, MVT::i32, Legal); -+ // TODO: we may consider expanding ROTL/ROTR on older GPUs. Currently on GPUs -+ // that don't have h/w rotation we lower them to multi-instruction assembly. -+ // See ROT*_sw in NVPTXIntrInfo.td -+ setOperationAction(ISD::ROTL, MVT::i64, Legal); -+ setOperationAction(ISD::ROTR, MVT::i64, Legal); -+ setOperationAction(ISD::ROTL, MVT::i32, Legal); -+ setOperationAction(ISD::ROTR, MVT::i32, Legal); - -+ setOperationAction(ISD::ROTL, MVT::i16, Expand); -+ setOperationAction(ISD::ROTL, MVT::v2i16, Expand); -+ setOperationAction(ISD::ROTR, MVT::i16, Expand); -+ setOperationAction(ISD::ROTR, MVT::v2i16, Expand); -+ setOperationAction(ISD::ROTL, MVT::i8, Expand); -+ setOperationAction(ISD::ROTR, MVT::i8, Expand); - setOperationAction(ISD::BSWAP, MVT::i16, Expand); - - setOperationAction(ISD::BR_JT, MVT::Other, Custom); -diff -ruN --strip-trailing-cr a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll ---- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -+++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll -@@ -31,19 +31,6 @@ - declare i64 @llvm.nvvm.bitcast.d2ll(double) - declare double @llvm.nvvm.bitcast.ll2d(i64) - --declare i32 @llvm.nvvm.rotate.b32(i32, i32) --declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) --declare i64 @llvm.nvvm.rotate.b64(i64, i32) -- --declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) --declare ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr) --declare ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr) --declare ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr) --declare ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1)) --declare ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3)) --declare ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4)) --declare ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5)) -- - ; CHECK-LABEL: @simple_upgrade - define void @simple_upgrade(i32 %a, i64 %b, i16 %c) { - ; CHECK: call i32 @llvm.bitreverse.i32(i32 %a) -@@ -152,42 +139,4 @@ - %r4 = call double @llvm.nvvm.bitcast.ll2d(i64 %b) - - ret void --} -- --; CHECK-LABEL: @rotate --define void @rotate(i32 %a, i64 %b) { --; CHECK: call i32 @llvm.fshl.i32(i32 %a, i32 %a, i32 6) --; CHECK: call i64 @llvm.fshr.i64(i64 %b, i64 %b, i64 7) --; CHECK: call i64 @llvm.fshl.i64(i64 %b, i64 %b, i64 8) --; -- %r1 = call i32 @llvm.nvvm.rotate.b32(i32 %a, i32 6) -- %r2 = call i64 @llvm.nvvm.rotate.right.b64(i64 %b, i32 7) -- %r3 = call i64 @llvm.nvvm.rotate.b64(i64 %b, i32 8) -- ret void --} -- --; CHECK-LABEL: @addrspacecast --define void @addrspacecast(ptr %p0) { --; CHECK: %1 = addrspacecast ptr %p0 to ptr addrspace(1) --; CHECK: %2 = addrspacecast ptr addrspace(1) %1 to ptr --; CHECK: %3 = addrspacecast ptr %2 to ptr addrspace(3) --; CHECK: %4 = addrspacecast ptr addrspace(3) %3 to ptr --; CHECK: %5 = addrspacecast ptr %4 to ptr addrspace(4) --; CHECK: %6 = addrspacecast ptr addrspace(4) %5 to ptr --; CHECK: %7 = addrspacecast ptr %6 to ptr addrspace(5) --; CHECK: %8 = addrspacecast ptr addrspace(5) %7 to ptr --; -- %p1 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %p0) -- %p2 = call ptr @llvm.nvvm.ptr.global.to.gen.p0.p1(ptr addrspace(1) %p1) -- -- %p3 = call ptr addrspace(3) @llvm.nvvm.ptr.gen.to.shared.p3.p0(ptr %p2) -- %p4 = call ptr @llvm.nvvm.ptr.shared.to.gen.p0.p3(ptr addrspace(3) %p3) -- -- %p5 = call ptr addrspace(4) @llvm.nvvm.ptr.gen.to.constant.p4.p0(ptr %p4) -- %p6 = call ptr @llvm.nvvm.ptr.constant.to.gen.p0.p4(ptr addrspace(4) %p5) -- -- %p7 = call ptr addrspace(5) @llvm.nvvm.ptr.gen.to.local.p5.p0(ptr %p6) -- %p8 = call ptr @llvm.nvvm.ptr.local.to.gen.p0.p5(ptr addrspace(5) %p7) -- -- ret void --} -+} -\ No newline at end of file -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/freeze.ll b/llvm/test/CodeGen/AMDGPU/freeze.ll ---- a/llvm/test/CodeGen/AMDGPU/freeze.ll -+++ b/llvm/test/CodeGen/AMDGPU/freeze.ll -@@ -1,1856 +0,0 @@ --; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py --; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-SDAG %s --; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1010 < %s | FileCheck -check-prefixes=GFX10,GFX10-GISEL %s --; RUN: llc -global-isel=0 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-SDAG %s --; RUN: llc -global-isel=1 -mtriple=amdgcn-mesa-mesa3d -mcpu=gfx1100 -amdgpu-enable-delay-alu=0 < %s | FileCheck -check-prefixes=GFX11,GFX11-GISEL %s -- --define void @freeze_v2i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v2i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v2i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <2 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <2 x i32> %a -- store <2 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v3i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v3i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx3 v[4:6], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx3 v[2:3], v[4:6], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v3i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b96 v[4:6], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b96 v[2:3], v[4:6], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <3 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <3 x i32> %a -- store <3 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v4i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v4i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v4i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <4 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <4 x i32> %a -- store <4 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v5i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v5i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dword v8, v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v8, off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v5i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dword v8, v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v8, off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v5i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b32 v8, v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v8, off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v5i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <5 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <5 x i32> %a -- store <5 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v6i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v6i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v6i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx2 v[8:9], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[8:9], off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v6i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b64 v[8:9], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[8:9], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v6i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <6 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <6 x i32> %a -- store <6 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v7i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v7i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v7i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx3 v[8:10], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[8:10], off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v7i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v7i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b96 v[8:10], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[8:10], off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <7 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <7 x i32> %a -- store <7 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v8i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v8i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v8i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v8i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v8i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <8 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <8 x i32> %a -- store <8 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v9i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v9i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x2 --; GFX10-SDAG-NEXT: global_load_dword v12, v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v12, off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v9i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x2 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dword v12, v[0:1], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v12, off offset:32 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v9i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x2 --; GFX11-SDAG-NEXT: global_load_b32 v12, v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v12, off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v9i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x2 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:32 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <9 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <9 x i32> %a -- store <9 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v10i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v10i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: s_clause 0x2 --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-NEXT: global_load_dwordx2 v[12:13], v[0:1], off offset:32 --; GFX10-NEXT: s_waitcnt vmcnt(2) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_waitcnt vmcnt(1) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx2 v[2:3], v[12:13], off offset:32 --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v10i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: s_clause 0x2 --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off offset:32 --; GFX11-NEXT: s_waitcnt vmcnt(2) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_waitcnt vmcnt(1) --; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off offset:32 --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <10 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <10 x i32> %a -- store <10 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v11i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v11i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: s_clause 0x2 --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-NEXT: global_load_dwordx3 v[12:14], v[0:1], off offset:32 --; GFX10-NEXT: s_waitcnt vmcnt(2) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_waitcnt vmcnt(1) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx3 v[2:3], v[12:14], off offset:32 --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v11i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: s_clause 0x2 --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-NEXT: global_load_b96 v[12:14], v[0:1], off offset:32 --; GFX11-NEXT: s_waitcnt vmcnt(2) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_waitcnt vmcnt(1) --; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b96 v[2:3], v[12:14], off offset:32 --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <11 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <11 x i32> %a -- store <11 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v12i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_v12i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: s_clause 0x2 --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-NEXT: s_waitcnt vmcnt(2) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_waitcnt vmcnt(1) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_v12i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: s_clause 0x2 --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-NEXT: s_waitcnt vmcnt(2) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_waitcnt vmcnt(1) --; GFX11-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load <12 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <12 x i32> %a -- store <12 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} --define void @freeze_v13i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v13i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x3 --; GFX10-SDAG-NEXT: global_load_dword v16, v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v16, off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v13i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x3 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dword v16, v[0:1], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v16, off offset:48 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v13i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x3 --; GFX11-SDAG-NEXT: global_load_b32 v16, v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v16, off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v13i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x3 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:48 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <13 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <13 x i32> %a -- store <13 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v14i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v14i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x3 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v14i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x3 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx2 v[16:17], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[16:17], off offset:48 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v14i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x3 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b64 v[16:17], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[16:17], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v14i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x3 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <14 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <14 x i32> %a -- store <14 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v15i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v15i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x3 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v15i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x3 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx3 v[16:18], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[16:18], off offset:48 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v15i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x3 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v15i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x3 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b96 v[16:18], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[16:18], off offset:48 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <15 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <15 x i32> %a -- store <15 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v16i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v16i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x3 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v16i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x3 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v16i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x3 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v16i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x3 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <16 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <16 x i32> %a -- store <16 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v17i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v17i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x4 --; GFX10-SDAG-NEXT: global_load_dword v20, v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v20, off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v17i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x4 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dword v20, v[0:1], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v20, off offset:64 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v17i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x4 --; GFX11-SDAG-NEXT: global_load_b32 v20, v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v20, off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v17i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x4 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:64 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <17 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <17 x i32> %a -- store <17 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v18i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v18i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x4 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v18i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x4 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx2 v[20:21], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[20:21], off offset:64 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v18i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x4 --; GFX11-SDAG-NEXT: global_load_b64 v[20:21], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[20:21], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v18i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x4 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <18 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <18 x i32> %a -- store <18 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v19i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v19i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x4 --; GFX10-SDAG-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v19i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x4 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx3 v[20:22], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[20:22], off offset:64 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v19i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x4 --; GFX11-SDAG-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v19i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x4 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b96 v[20:22], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[20:22], off offset:64 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <19 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <19 x i32> %a -- store <19 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v20i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v20i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x4 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v20i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x4 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v20i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x4 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v20i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x4 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <20 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <20 x i32> %a -- store <20 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v21i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v21i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x5 --; GFX10-SDAG-NEXT: global_load_dword v24, v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dword v[2:3], v24, off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v21i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x5 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dword v24, v[0:1], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dword v[2:3], v24, off offset:80 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v21i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x5 --; GFX11-SDAG-NEXT: global_load_b32 v24, v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b32 v[2:3], v24, off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v21i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x5 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b32 v0, v[0:1], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b32 v[2:3], v0, off offset:80 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <21 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <21 x i32> %a -- store <21 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v22i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v22i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x5 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v22i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x5 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dwordx2 v[24:25], v[0:1], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[24:25], off offset:80 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v22i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x5 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b64 v[24:25], v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[24:25], off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v22i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x5 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <22 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <22 x i32> %a -- store <22 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v30i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v30i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x7 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 --; GFX10-SDAG-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX10-SDAG-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v30i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x7 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 --; GFX10-GISEL-NEXT: global_load_dwordx2 v[32:33], v[0:1], off offset:112 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx2 v[2:3], v[32:33], off offset:112 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v30i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x7 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 --; GFX11-SDAG-NEXT: global_load_b64 v[32:33], v[0:1], off offset:112 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX11-SDAG-NEXT: global_store_b64 v[2:3], v[32:33], off offset:112 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v30i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x7 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 --; GFX11-GISEL-NEXT: global_load_b64 v[0:1], v[0:1], off offset:112 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b64 v[2:3], v[0:1], off offset:112 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <30 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <30 x i32> %a -- store <30 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v31i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v31i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x7 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 --; GFX10-SDAG-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX10-SDAG-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v31i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x7 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 --; GFX10-GISEL-NEXT: global_load_dwordx3 v[32:34], v[0:1], off offset:112 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx3 v[2:3], v[32:34], off offset:112 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v31i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x7 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 --; GFX11-SDAG-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX11-SDAG-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v31i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x7 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 --; GFX11-GISEL-NEXT: global_load_b96 v[32:34], v[0:1], off offset:112 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b96 v[2:3], v[32:34], off offset:112 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <31 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <31 x i32> %a -- store <31 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_v32i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_v32i32: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x7 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:96 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:112 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:64 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:80 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:32 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:48 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[28:31], v[0:1], off --; GFX10-SDAG-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:96 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:112 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:64 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:80 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:32 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:48 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[28:31], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:16 --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_v32i32: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x7 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[12:15], v[0:1], off offset:32 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[16:19], v[0:1], off offset:48 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[20:23], v[0:1], off offset:64 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[24:27], v[0:1], off offset:80 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[28:31], v[0:1], off offset:96 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[32:35], v[0:1], off offset:112 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[12:15], off offset:32 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[16:19], off offset:48 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[20:23], off offset:64 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[24:27], off offset:80 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[28:31], off offset:96 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[32:35], off offset:112 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_v32i32: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x7 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:96 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off offset:112 --; GFX11-SDAG-NEXT: global_load_b128 v[12:15], v[0:1], off offset:64 --; GFX11-SDAG-NEXT: global_load_b128 v[16:19], v[0:1], off offset:80 --; GFX11-SDAG-NEXT: global_load_b128 v[20:23], v[0:1], off offset:32 --; GFX11-SDAG-NEXT: global_load_b128 v[24:27], v[0:1], off offset:48 --; GFX11-SDAG-NEXT: global_load_b128 v[28:31], v[0:1], off --; GFX11-SDAG-NEXT: global_load_b128 v[32:35], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(7) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:96 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(6) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off offset:112 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(5) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[12:15], off offset:64 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(4) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[16:19], off offset:80 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(3) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[20:23], off offset:32 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(2) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[24:27], off offset:48 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[28:31], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[32:35], off offset:16 --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_v32i32: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x7 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: global_load_b128 v[12:15], v[0:1], off offset:32 --; GFX11-GISEL-NEXT: global_load_b128 v[16:19], v[0:1], off offset:48 --; GFX11-GISEL-NEXT: global_load_b128 v[20:23], v[0:1], off offset:64 --; GFX11-GISEL-NEXT: global_load_b128 v[24:27], v[0:1], off offset:80 --; GFX11-GISEL-NEXT: global_load_b128 v[28:31], v[0:1], off offset:96 --; GFX11-GISEL-NEXT: global_load_b128 v[32:35], v[0:1], off offset:112 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(7) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(6) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(5) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[12:15], off offset:32 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(4) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[16:19], off offset:48 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(3) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[20:23], off offset:64 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(2) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[24:27], off offset:80 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[28:31], off offset:96 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[32:35], off offset:112 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load <32 x i32>, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze <32 x i32> %a -- store <32 x i32> %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_i32(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_i32: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dword v0, v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dword v[2:3], v0, off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_i32: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b32 v0, v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b32 v[2:3], v0, off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load i32, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze i32 %a -- store i32 %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_i64(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_i64: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx2 v[0:1], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx2 v[2:3], v[0:1], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_i64: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b64 v[0:1], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b64 v[2:3], v[0:1], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load i64, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze i64 %a -- store i64 %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_float(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_float: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dword v0, v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dword v[2:3], v0, off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_float: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b32 v0, v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b32 v[2:3], v0, off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load float, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze float %a -- store float %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_i128(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-LABEL: freeze_i128: --; GFX10: ; %bb.0: --; GFX10-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-NEXT: s_waitcnt vmcnt(0) --; GFX10-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-LABEL: freeze_i128: --; GFX11: ; %bb.0: --; GFX11-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-NEXT: s_waitcnt vmcnt(0) --; GFX11-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-NEXT: s_setpc_b64 s[30:31] -- %a = load i128, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze i128 %a -- store i128 %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -- --define void @freeze_i256(ptr addrspace(1) %ptra, ptr addrspace(1) %ptrb) { --; GFX10-SDAG-LABEL: freeze_i256: --; GFX10-SDAG: ; %bb.0: --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-SDAG-NEXT: s_clause 0x1 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[4:7], v[0:1], off offset:16 --; GFX10-SDAG-NEXT: global_load_dwordx4 v[8:11], v[0:1], off --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[4:7], off offset:16 --; GFX10-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX10-SDAG-NEXT: global_store_dwordx4 v[2:3], v[8:11], off --; GFX10-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX10-GISEL-LABEL: freeze_i256: --; GFX10-GISEL: ; %bb.0: --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX10-GISEL-NEXT: s_clause 0x1 --; GFX10-GISEL-NEXT: global_load_dwordx4 v[4:7], v[0:1], off --; GFX10-GISEL-NEXT: global_load_dwordx4 v[8:11], v[0:1], off offset:16 --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[4:7], off --; GFX10-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX10-GISEL-NEXT: global_store_dwordx4 v[2:3], v[8:11], off offset:16 --; GFX10-GISEL-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-SDAG-LABEL: freeze_i256: --; GFX11-SDAG: ; %bb.0: --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-SDAG-NEXT: s_clause 0x1 --; GFX11-SDAG-NEXT: global_load_b128 v[4:7], v[0:1], off offset:16 --; GFX11-SDAG-NEXT: global_load_b128 v[8:11], v[0:1], off --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(1) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[4:7], off offset:16 --; GFX11-SDAG-NEXT: s_waitcnt vmcnt(0) --; GFX11-SDAG-NEXT: global_store_b128 v[2:3], v[8:11], off --; GFX11-SDAG-NEXT: s_setpc_b64 s[30:31] --; --; GFX11-GISEL-LABEL: freeze_i256: --; GFX11-GISEL: ; %bb.0: --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0) --; GFX11-GISEL-NEXT: s_clause 0x1 --; GFX11-GISEL-NEXT: global_load_b128 v[4:7], v[0:1], off --; GFX11-GISEL-NEXT: global_load_b128 v[8:11], v[0:1], off offset:16 --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(1) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[4:7], off --; GFX11-GISEL-NEXT: s_waitcnt vmcnt(0) --; GFX11-GISEL-NEXT: global_store_b128 v[2:3], v[8:11], off offset:16 --; GFX11-GISEL-NEXT: s_setpc_b64 s[30:31] -- %a = load i256, ptr addrspace(1) %ptra, align 4 -- %freeze = freeze i256 %a -- store i256 %freeze, ptr addrspace(1) %ptrb, align 4 -- ret void --} -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-unmerge-values.mir -@@ -171,9 +171,11 @@ - ; GCN-LABEL: name: test_unmerge_values_s_s64_s_s64_s64_s_s192 - ; GCN: liveins: $sgpr0_sgpr1_sgpr2_sgpr3 - ; GCN-NEXT: {{ $}} -- ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr(s192) = G_IMPLICIT_DEF -- ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr(s64), [[UV1:%[0-9]+]]:sgpr(s64), [[UV2:%[0-9]+]]:sgpr(s64) = G_UNMERGE_VALUES [[DEF]](s192) -- ; GCN-NEXT: S_ENDPGM 0, implicit [[UV]](s64), implicit [[UV1]](s64), implicit [[UV2]](s64) -+ ; GCN-NEXT: [[DEF:%[0-9]+]]:sgpr_192 = IMPLICIT_DEF -+ ; GCN-NEXT: [[COPY:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub0_sub1 -+ ; GCN-NEXT: [[COPY1:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub2_sub3 -+ ; GCN-NEXT: [[COPY2:%[0-9]+]]:sreg_64 = COPY [[DEF]].sub4_sub5 -+ ; GCN-NEXT: S_ENDPGM 0, implicit [[COPY]], implicit [[COPY1]], implicit [[COPY2]] - %0:sgpr(s192) = G_IMPLICIT_DEF - %1:sgpr(s64), %2:sgpr(s64), %3:sgpr(s64) = G_UNMERGE_VALUES %0 - S_ENDPGM 0, implicit %1, implicit %2, implicit %3 -@@ -292,11 +294,11 @@ - ; GCN-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:sgpr_384(<12 x s32>) = G_CONCAT_VECTORS [[COPY]](<3 x s32>), [[COPY1]](<3 x s32>), [[COPY2]](<3 x s32>), [[COPY3]](<3 x s32>) - ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub0_sub1_sub2(<12 x s32>) - ; GCN-NEXT: [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>) = COPY [[CONCAT_VECTORS]].sub3_sub4_sub5(<12 x s32>) -- ; GCN-NEXT: [[COPY4:%[0-9]+]]:sgpr_96(<3 x s32>), [[COPY5:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) -- ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[COPY4]](<3 x s32>) -- ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[COPY5]](<3 x s32>) -- ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV]](<3 x s32>) -- ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV1]](<3 x s32>) -+ ; GCN-NEXT: [[UV:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV1:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV2:%[0-9]+]]:sgpr_96(<3 x s32>), [[UV3:%[0-9]+]]:sgpr_96(<3 x s32>) = G_UNMERGE_VALUES [[CONCAT_VECTORS]](<12 x s32>) -+ ; GCN-NEXT: $sgpr0_sgpr1_sgpr2 = COPY [[UV]](<3 x s32>) -+ ; GCN-NEXT: $sgpr4_sgpr5_sgpr6 = COPY [[UV1]](<3 x s32>) -+ ; GCN-NEXT: $sgpr8_sgpr9_sgpr10 = COPY [[UV2]](<3 x s32>) -+ ; GCN-NEXT: $sgpr12_sgpr13_sgpr14 = COPY [[UV3]](<3 x s32>) - %0:sgpr(<3 x s32>) = COPY $sgpr0_sgpr1_sgpr2 - %1:sgpr(<3 x s32>) = COPY $sgpr4_sgpr5_sgpr6 - %2:sgpr(<3 x s32>) = COPY $sgpr8_sgpr9_sgpr10 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-freeze.mir -@@ -171,8 +171,12 @@ - - ; CHECK-LABEL: name: test_freeze_s448 - ; CHECK: [[COPY:%[0-9]+]]:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 -- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s512) = G_FREEZE [[COPY]] -- ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[FREEZE]](s512) -+ ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[COPY]](s512) -+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(s448) = G_FREEZE [[TRUNC]] -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s64), [[UV1:%[0-9]+]]:_(s64), [[UV2:%[0-9]+]]:_(s64), [[UV3:%[0-9]+]]:_(s64), [[UV4:%[0-9]+]]:_(s64), [[UV5:%[0-9]+]]:_(s64), [[UV6:%[0-9]+]]:_(s64) = G_UNMERGE_VALUES [[FREEZE]](s448) -+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(s64) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[MV:%[0-9]+]]:_(s512) = G_MERGE_VALUES [[UV]](s64), [[UV1]](s64), [[UV2]](s64), [[UV3]](s64), [[UV4]](s64), [[UV5]](s64), [[UV6]](s64), [[DEF]](s64) -+ ; CHECK-NEXT: $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 = COPY [[MV]](s512) - %0:_(s512) = COPY $vgpr0_vgpr1_vgpr2_vgpr3_vgpr4_vgpr5_vgpr6_vgpr7_vgpr8_vgpr9_vgpr10_vgpr11_vgpr12_vgpr13_vgpr14_vgpr15 - %1:_(s448) = G_TRUNC %0 - %2:_(s448) = G_FREEZE %1 -@@ -395,12 +399,14 @@ - bb.0: - - ; CHECK-LABEL: name: test_freeze_v33s32 -- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF - ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] -- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<32 x s32>) -- ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE1]](s32) -+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(s32) = G_FREEZE [[DEF1]] -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE]](<16 x s32>) -+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[FREEZE1]](<16 x s32>) -+ ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<33 x s32>) = G_BUILD_VECTOR [[UV]](s32), [[UV1]](s32), [[UV2]](s32), [[UV3]](s32), [[UV4]](s32), [[UV5]](s32), [[UV6]](s32), [[UV7]](s32), [[UV8]](s32), [[UV9]](s32), [[UV10]](s32), [[UV11]](s32), [[UV12]](s32), [[UV13]](s32), [[UV14]](s32), [[UV15]](s32), [[UV16]](s32), [[UV17]](s32), [[UV18]](s32), [[UV19]](s32), [[UV20]](s32), [[UV21]](s32), [[UV22]](s32), [[UV23]](s32), [[UV24]](s32), [[UV25]](s32), [[UV26]](s32), [[UV27]](s32), [[UV28]](s32), [[UV29]](s32), [[UV30]](s32), [[UV31]](s32), [[FREEZE2]](s32) - ; CHECK-NEXT: S_NOP 0, implicit [[BUILD_VECTOR]](<33 x s32>) - %0:_(<33 x s32>) = G_IMPLICIT_DEF - %1:_(<33 x s32>) = G_FREEZE %0 -@@ -413,10 +419,12 @@ - bb.0: - - ; CHECK-LABEL: name: test_freeze_v64s32 -- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -- ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<32 x s32>) = G_FREEZE [[DEF]] -- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<32 x s32>), [[FREEZE1]](<32 x s32>) -+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[FREEZE:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE1:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE2:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[FREEZE3:%[0-9]+]]:_(<16 x s32>) = G_FREEZE [[DEF]] -+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[FREEZE]](<16 x s32>), [[FREEZE1]](<16 x s32>), [[FREEZE2]](<16 x s32>), [[FREEZE3]](<16 x s32>) - ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>) - %0:_(<64 x s32>) = G_IMPLICIT_DEF - %1:_(<64 x s32>) = G_FREEZE %0 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-implicit-def.mir -@@ -135,9 +135,8 @@ - bb.0: - - ; CHECK-LABEL: name: test_implicit_def_s448 -- ; CHECK: [[DEF:%[0-9]+]]:_(s512) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[TRUNC:%[0-9]+]]:_(s448) = G_TRUNC [[DEF]](s512) -- ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[TRUNC]](s448), 0 -+ ; CHECK: [[DEF:%[0-9]+]]:_(s448) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:_(s32) = G_EXTRACT [[DEF]](s448), 0 - ; CHECK-NEXT: $vgpr0 = COPY [[EXTRACT]](s32) - %0:_(s448) = G_IMPLICIT_DEF - %1:_(s32) = G_EXTRACT %0, 0 -@@ -297,6 +296,18 @@ - ... - - --- -+name: test_implicit_def_v17s32 -+body: | -+ bb.0: -+ -+ ; CHECK-LABEL: name: test_implicit_def_v17s32 -+ ; CHECK: [[DEF:%[0-9]+]]:_(<17 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: S_NOP 0, implicit [[DEF]](<17 x s32>) -+ %0:_(<17 x s32>) = G_IMPLICIT_DEF -+ S_NOP 0, implicit %0 -+... -+ -+--- - name: test_implicit_def_v32s32 - body: | - bb.0: -@@ -317,9 +328,9 @@ - ; CHECK-LABEL: name: test_implicit_def_v33s32 - ; CHECK: liveins: $vgpr0_vgpr1 - ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF - ; CHECK-NEXT: [[DEF1:%[0-9]+]]:_(s32) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) - ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 - ; CHECK-NEXT: G_STORE [[UV]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) - ; CHECK-NEXT: G_STORE [[DEF1]](s32), [[COPY]](p1) :: (volatile store (s32), addrspace 1) -@@ -337,9 +348,10 @@ - bb.0: - - ; CHECK-LABEL: name: test_implicit_def_v64s32 -- ; CHECK: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -- ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<32 x s32>), [[DEF]](<32 x s32>) -- ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[DEF]](<32 x s32>) -+ ; CHECK: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[CONCAT_VECTORS1:%[0-9]+]]:_(<32 x s32>) = G_CONCAT_VECTORS [[DEF]](<16 x s32>), [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: S_NOP 0, implicit [[CONCAT_VECTORS]](<64 x s32>), implicit [[CONCAT_VECTORS1]](<32 x s32>) - %0:_(<64 x s32>) = G_IMPLICIT_DEF - %1:_(<32 x s32>), %2:_(<32 x s32>) = G_UNMERGE_VALUES %0 - S_NOP 0, implicit %0, implicit %1 -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-insert-vector-elt.mir -@@ -190,11 +190,13 @@ - ; CHECK-LABEL: name: insert_vector_elt_64_65_v64s32 - ; CHECK: liveins: $sgpr0_sgpr1, $vgpr0_vgpr1, $vgpr2_vgpr3 - ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF - ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(p1) = COPY $vgpr0_vgpr1 - ; CHECK-NEXT: [[COPY1:%[0-9]+]]:_(p1) = COPY $vgpr2_vgpr3 -- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>), [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>), [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<4 x s32>), [[UV1:%[0-9]+]]:_(<4 x s32>), [[UV2:%[0-9]+]]:_(<4 x s32>), [[UV3:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(<4 x s32>), [[UV5:%[0-9]+]]:_(<4 x s32>), [[UV6:%[0-9]+]]:_(<4 x s32>), [[UV7:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV8:%[0-9]+]]:_(<4 x s32>), [[UV9:%[0-9]+]]:_(<4 x s32>), [[UV10:%[0-9]+]]:_(<4 x s32>), [[UV11:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV12:%[0-9]+]]:_(<4 x s32>), [[UV13:%[0-9]+]]:_(<4 x s32>), [[UV14:%[0-9]+]]:_(<4 x s32>), [[UV15:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) - ; CHECK-NEXT: G_STORE [[UV]](<4 x s32>), [[COPY]](p1) :: (store (<4 x s32>), align 4, addrspace 1) - ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 16 - ; CHECK-NEXT: [[PTR_ADD:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C]](s64) -@@ -241,8 +243,10 @@ - ; CHECK-NEXT: [[C14:%[0-9]+]]:_(s64) = G_CONSTANT i64 240 - ; CHECK-NEXT: [[PTR_ADD14:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY]], [[C14]](s64) - ; CHECK-NEXT: G_STORE [[UV15]](<4 x s32>), [[PTR_ADD14]](p1) :: (store (<4 x s32>) into unknown-address + 240, align 4, addrspace 1) -- ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>), [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>), [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(<4 x s32>), [[UV17:%[0-9]+]]:_(<4 x s32>), [[UV18:%[0-9]+]]:_(<4 x s32>), [[UV19:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV20:%[0-9]+]]:_(<4 x s32>), [[UV21:%[0-9]+]]:_(<4 x s32>), [[UV22:%[0-9]+]]:_(<4 x s32>), [[UV23:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV24:%[0-9]+]]:_(<4 x s32>), [[UV25:%[0-9]+]]:_(<4 x s32>), [[UV26:%[0-9]+]]:_(<4 x s32>), [[UV27:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV28:%[0-9]+]]:_(<4 x s32>), [[UV29:%[0-9]+]]:_(<4 x s32>), [[UV30:%[0-9]+]]:_(<4 x s32>), [[UV31:%[0-9]+]]:_(<4 x s32>) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) - ; CHECK-NEXT: G_STORE [[UV16]](<4 x s32>), [[COPY1]](p1) :: (store (<4 x s32>), align 4, addrspace 1) - ; CHECK-NEXT: [[PTR_ADD15:%[0-9]+]]:_(p1) = G_PTR_ADD [[COPY1]], [[C]](s64) - ; CHECK-NEXT: G_STORE [[UV17]](<4 x s32>), [[PTR_ADD15]](p1) :: (store (<4 x s32>) into unknown-address + 16, align 4, addrspace 1) -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/legalize-phi.mir -@@ -673,86 +673,88 @@ - ; CHECK-NEXT: successors: %bb.1(0x40000000), %bb.2(0x40000000) - ; CHECK-NEXT: liveins: $vgpr0_vgpr1_vgpr2_vgpr3, $vgpr4 - ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<32 x s32>) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[DEF:%[0-9]+]]:_(<16 x s32>) = G_IMPLICIT_DEF - ; CHECK-NEXT: [[COPY:%[0-9]+]]:_(s32) = COPY $vgpr4 - ; CHECK-NEXT: [[C:%[0-9]+]]:_(s32) = G_CONSTANT i32 0 - ; CHECK-NEXT: [[ICMP:%[0-9]+]]:_(s1) = G_ICMP intpred(eq), [[COPY]](s32), [[C]] -- ; CHECK-NEXT: [[UV:%[0-9]+]]:_(<16 x s32>), [[UV1:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV2:%[0-9]+]]:_(<16 x s32>), [[UV3:%[0-9]+]]:_(<16 x s32>) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) - ; CHECK-NEXT: G_BRCOND [[ICMP]](s1), %bb.1 - ; CHECK-NEXT: G_BR %bb.2 - ; CHECK-NEXT: {{ $}} - ; CHECK-NEXT: bb.1: - ; CHECK-NEXT: successors: %bb.2(0x80000000) - ; CHECK-NEXT: {{ $}} -- ; CHECK-NEXT: [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32), [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32), [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32), [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32), [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32), [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32), [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32), [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32), [[UV128:%[0-9]+]]:_(s32), [[UV129:%[0-9]+]]:_(s32), [[UV130:%[0-9]+]]:_(s32), [[UV131:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<32 x s32>) -- ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] -- ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] -- ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] -- ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] -- ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] -- ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] -- ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] -- ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] -- ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] -- ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] -- ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] -- ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] -- ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] -- ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] -- ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] -- ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] -- ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] -- ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] -- ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] -- ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] -- ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] -- ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] -- ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] -- ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] -- ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] -- ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] -- ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] -- ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] -- ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] -- ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] -- ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] -- ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] -- ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] -- ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] -- ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] -- ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] -- ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] -- ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] -- ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] -- ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] -- ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] -- ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] -- ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] -- ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] -- ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] -- ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] -- ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] -- ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] -- ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] -- ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] -- ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] -- ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] -- ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] -- ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] -- ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] -- ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] -- ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] -- ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] -- ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] -- ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] -- ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV64]], [[UV128]] -- ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV65]], [[UV129]] -- ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV66]], [[UV130]] -- ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV67]], [[UV131]] -+ ; CHECK-NEXT: [[UV:%[0-9]+]]:_(s32), [[UV1:%[0-9]+]]:_(s32), [[UV2:%[0-9]+]]:_(s32), [[UV3:%[0-9]+]]:_(s32), [[UV4:%[0-9]+]]:_(s32), [[UV5:%[0-9]+]]:_(s32), [[UV6:%[0-9]+]]:_(s32), [[UV7:%[0-9]+]]:_(s32), [[UV8:%[0-9]+]]:_(s32), [[UV9:%[0-9]+]]:_(s32), [[UV10:%[0-9]+]]:_(s32), [[UV11:%[0-9]+]]:_(s32), [[UV12:%[0-9]+]]:_(s32), [[UV13:%[0-9]+]]:_(s32), [[UV14:%[0-9]+]]:_(s32), [[UV15:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV16:%[0-9]+]]:_(s32), [[UV17:%[0-9]+]]:_(s32), [[UV18:%[0-9]+]]:_(s32), [[UV19:%[0-9]+]]:_(s32), [[UV20:%[0-9]+]]:_(s32), [[UV21:%[0-9]+]]:_(s32), [[UV22:%[0-9]+]]:_(s32), [[UV23:%[0-9]+]]:_(s32), [[UV24:%[0-9]+]]:_(s32), [[UV25:%[0-9]+]]:_(s32), [[UV26:%[0-9]+]]:_(s32), [[UV27:%[0-9]+]]:_(s32), [[UV28:%[0-9]+]]:_(s32), [[UV29:%[0-9]+]]:_(s32), [[UV30:%[0-9]+]]:_(s32), [[UV31:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV32:%[0-9]+]]:_(s32), [[UV33:%[0-9]+]]:_(s32), [[UV34:%[0-9]+]]:_(s32), [[UV35:%[0-9]+]]:_(s32), [[UV36:%[0-9]+]]:_(s32), [[UV37:%[0-9]+]]:_(s32), [[UV38:%[0-9]+]]:_(s32), [[UV39:%[0-9]+]]:_(s32), [[UV40:%[0-9]+]]:_(s32), [[UV41:%[0-9]+]]:_(s32), [[UV42:%[0-9]+]]:_(s32), [[UV43:%[0-9]+]]:_(s32), [[UV44:%[0-9]+]]:_(s32), [[UV45:%[0-9]+]]:_(s32), [[UV46:%[0-9]+]]:_(s32), [[UV47:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV48:%[0-9]+]]:_(s32), [[UV49:%[0-9]+]]:_(s32), [[UV50:%[0-9]+]]:_(s32), [[UV51:%[0-9]+]]:_(s32), [[UV52:%[0-9]+]]:_(s32), [[UV53:%[0-9]+]]:_(s32), [[UV54:%[0-9]+]]:_(s32), [[UV55:%[0-9]+]]:_(s32), [[UV56:%[0-9]+]]:_(s32), [[UV57:%[0-9]+]]:_(s32), [[UV58:%[0-9]+]]:_(s32), [[UV59:%[0-9]+]]:_(s32), [[UV60:%[0-9]+]]:_(s32), [[UV61:%[0-9]+]]:_(s32), [[UV62:%[0-9]+]]:_(s32), [[UV63:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV64:%[0-9]+]]:_(s32), [[UV65:%[0-9]+]]:_(s32), [[UV66:%[0-9]+]]:_(s32), [[UV67:%[0-9]+]]:_(s32), [[UV68:%[0-9]+]]:_(s32), [[UV69:%[0-9]+]]:_(s32), [[UV70:%[0-9]+]]:_(s32), [[UV71:%[0-9]+]]:_(s32), [[UV72:%[0-9]+]]:_(s32), [[UV73:%[0-9]+]]:_(s32), [[UV74:%[0-9]+]]:_(s32), [[UV75:%[0-9]+]]:_(s32), [[UV76:%[0-9]+]]:_(s32), [[UV77:%[0-9]+]]:_(s32), [[UV78:%[0-9]+]]:_(s32), [[UV79:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV80:%[0-9]+]]:_(s32), [[UV81:%[0-9]+]]:_(s32), [[UV82:%[0-9]+]]:_(s32), [[UV83:%[0-9]+]]:_(s32), [[UV84:%[0-9]+]]:_(s32), [[UV85:%[0-9]+]]:_(s32), [[UV86:%[0-9]+]]:_(s32), [[UV87:%[0-9]+]]:_(s32), [[UV88:%[0-9]+]]:_(s32), [[UV89:%[0-9]+]]:_(s32), [[UV90:%[0-9]+]]:_(s32), [[UV91:%[0-9]+]]:_(s32), [[UV92:%[0-9]+]]:_(s32), [[UV93:%[0-9]+]]:_(s32), [[UV94:%[0-9]+]]:_(s32), [[UV95:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV96:%[0-9]+]]:_(s32), [[UV97:%[0-9]+]]:_(s32), [[UV98:%[0-9]+]]:_(s32), [[UV99:%[0-9]+]]:_(s32), [[UV100:%[0-9]+]]:_(s32), [[UV101:%[0-9]+]]:_(s32), [[UV102:%[0-9]+]]:_(s32), [[UV103:%[0-9]+]]:_(s32), [[UV104:%[0-9]+]]:_(s32), [[UV105:%[0-9]+]]:_(s32), [[UV106:%[0-9]+]]:_(s32), [[UV107:%[0-9]+]]:_(s32), [[UV108:%[0-9]+]]:_(s32), [[UV109:%[0-9]+]]:_(s32), [[UV110:%[0-9]+]]:_(s32), [[UV111:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[UV112:%[0-9]+]]:_(s32), [[UV113:%[0-9]+]]:_(s32), [[UV114:%[0-9]+]]:_(s32), [[UV115:%[0-9]+]]:_(s32), [[UV116:%[0-9]+]]:_(s32), [[UV117:%[0-9]+]]:_(s32), [[UV118:%[0-9]+]]:_(s32), [[UV119:%[0-9]+]]:_(s32), [[UV120:%[0-9]+]]:_(s32), [[UV121:%[0-9]+]]:_(s32), [[UV122:%[0-9]+]]:_(s32), [[UV123:%[0-9]+]]:_(s32), [[UV124:%[0-9]+]]:_(s32), [[UV125:%[0-9]+]]:_(s32), [[UV126:%[0-9]+]]:_(s32), [[UV127:%[0-9]+]]:_(s32) = G_UNMERGE_VALUES [[DEF]](<16 x s32>) -+ ; CHECK-NEXT: [[ADD:%[0-9]+]]:_(s32) = G_ADD [[UV]], [[UV64]] -+ ; CHECK-NEXT: [[ADD1:%[0-9]+]]:_(s32) = G_ADD [[UV1]], [[UV65]] -+ ; CHECK-NEXT: [[ADD2:%[0-9]+]]:_(s32) = G_ADD [[UV2]], [[UV66]] -+ ; CHECK-NEXT: [[ADD3:%[0-9]+]]:_(s32) = G_ADD [[UV3]], [[UV67]] -+ ; CHECK-NEXT: [[ADD4:%[0-9]+]]:_(s32) = G_ADD [[UV4]], [[UV68]] -+ ; CHECK-NEXT: [[ADD5:%[0-9]+]]:_(s32) = G_ADD [[UV5]], [[UV69]] -+ ; CHECK-NEXT: [[ADD6:%[0-9]+]]:_(s32) = G_ADD [[UV6]], [[UV70]] -+ ; CHECK-NEXT: [[ADD7:%[0-9]+]]:_(s32) = G_ADD [[UV7]], [[UV71]] -+ ; CHECK-NEXT: [[ADD8:%[0-9]+]]:_(s32) = G_ADD [[UV8]], [[UV72]] -+ ; CHECK-NEXT: [[ADD9:%[0-9]+]]:_(s32) = G_ADD [[UV9]], [[UV73]] -+ ; CHECK-NEXT: [[ADD10:%[0-9]+]]:_(s32) = G_ADD [[UV10]], [[UV74]] -+ ; CHECK-NEXT: [[ADD11:%[0-9]+]]:_(s32) = G_ADD [[UV11]], [[UV75]] -+ ; CHECK-NEXT: [[ADD12:%[0-9]+]]:_(s32) = G_ADD [[UV12]], [[UV76]] -+ ; CHECK-NEXT: [[ADD13:%[0-9]+]]:_(s32) = G_ADD [[UV13]], [[UV77]] -+ ; CHECK-NEXT: [[ADD14:%[0-9]+]]:_(s32) = G_ADD [[UV14]], [[UV78]] -+ ; CHECK-NEXT: [[ADD15:%[0-9]+]]:_(s32) = G_ADD [[UV15]], [[UV79]] -+ ; CHECK-NEXT: [[ADD16:%[0-9]+]]:_(s32) = G_ADD [[UV16]], [[UV80]] -+ ; CHECK-NEXT: [[ADD17:%[0-9]+]]:_(s32) = G_ADD [[UV17]], [[UV81]] -+ ; CHECK-NEXT: [[ADD18:%[0-9]+]]:_(s32) = G_ADD [[UV18]], [[UV82]] -+ ; CHECK-NEXT: [[ADD19:%[0-9]+]]:_(s32) = G_ADD [[UV19]], [[UV83]] -+ ; CHECK-NEXT: [[ADD20:%[0-9]+]]:_(s32) = G_ADD [[UV20]], [[UV84]] -+ ; CHECK-NEXT: [[ADD21:%[0-9]+]]:_(s32) = G_ADD [[UV21]], [[UV85]] -+ ; CHECK-NEXT: [[ADD22:%[0-9]+]]:_(s32) = G_ADD [[UV22]], [[UV86]] -+ ; CHECK-NEXT: [[ADD23:%[0-9]+]]:_(s32) = G_ADD [[UV23]], [[UV87]] -+ ; CHECK-NEXT: [[ADD24:%[0-9]+]]:_(s32) = G_ADD [[UV24]], [[UV88]] -+ ; CHECK-NEXT: [[ADD25:%[0-9]+]]:_(s32) = G_ADD [[UV25]], [[UV89]] -+ ; CHECK-NEXT: [[ADD26:%[0-9]+]]:_(s32) = G_ADD [[UV26]], [[UV90]] -+ ; CHECK-NEXT: [[ADD27:%[0-9]+]]:_(s32) = G_ADD [[UV27]], [[UV91]] -+ ; CHECK-NEXT: [[ADD28:%[0-9]+]]:_(s32) = G_ADD [[UV28]], [[UV92]] -+ ; CHECK-NEXT: [[ADD29:%[0-9]+]]:_(s32) = G_ADD [[UV29]], [[UV93]] -+ ; CHECK-NEXT: [[ADD30:%[0-9]+]]:_(s32) = G_ADD [[UV30]], [[UV94]] -+ ; CHECK-NEXT: [[ADD31:%[0-9]+]]:_(s32) = G_ADD [[UV31]], [[UV95]] -+ ; CHECK-NEXT: [[ADD32:%[0-9]+]]:_(s32) = G_ADD [[UV32]], [[UV96]] -+ ; CHECK-NEXT: [[ADD33:%[0-9]+]]:_(s32) = G_ADD [[UV33]], [[UV97]] -+ ; CHECK-NEXT: [[ADD34:%[0-9]+]]:_(s32) = G_ADD [[UV34]], [[UV98]] -+ ; CHECK-NEXT: [[ADD35:%[0-9]+]]:_(s32) = G_ADD [[UV35]], [[UV99]] -+ ; CHECK-NEXT: [[ADD36:%[0-9]+]]:_(s32) = G_ADD [[UV36]], [[UV100]] -+ ; CHECK-NEXT: [[ADD37:%[0-9]+]]:_(s32) = G_ADD [[UV37]], [[UV101]] -+ ; CHECK-NEXT: [[ADD38:%[0-9]+]]:_(s32) = G_ADD [[UV38]], [[UV102]] -+ ; CHECK-NEXT: [[ADD39:%[0-9]+]]:_(s32) = G_ADD [[UV39]], [[UV103]] -+ ; CHECK-NEXT: [[ADD40:%[0-9]+]]:_(s32) = G_ADD [[UV40]], [[UV104]] -+ ; CHECK-NEXT: [[ADD41:%[0-9]+]]:_(s32) = G_ADD [[UV41]], [[UV105]] -+ ; CHECK-NEXT: [[ADD42:%[0-9]+]]:_(s32) = G_ADD [[UV42]], [[UV106]] -+ ; CHECK-NEXT: [[ADD43:%[0-9]+]]:_(s32) = G_ADD [[UV43]], [[UV107]] -+ ; CHECK-NEXT: [[ADD44:%[0-9]+]]:_(s32) = G_ADD [[UV44]], [[UV108]] -+ ; CHECK-NEXT: [[ADD45:%[0-9]+]]:_(s32) = G_ADD [[UV45]], [[UV109]] -+ ; CHECK-NEXT: [[ADD46:%[0-9]+]]:_(s32) = G_ADD [[UV46]], [[UV110]] -+ ; CHECK-NEXT: [[ADD47:%[0-9]+]]:_(s32) = G_ADD [[UV47]], [[UV111]] -+ ; CHECK-NEXT: [[ADD48:%[0-9]+]]:_(s32) = G_ADD [[UV48]], [[UV112]] -+ ; CHECK-NEXT: [[ADD49:%[0-9]+]]:_(s32) = G_ADD [[UV49]], [[UV113]] -+ ; CHECK-NEXT: [[ADD50:%[0-9]+]]:_(s32) = G_ADD [[UV50]], [[UV114]] -+ ; CHECK-NEXT: [[ADD51:%[0-9]+]]:_(s32) = G_ADD [[UV51]], [[UV115]] -+ ; CHECK-NEXT: [[ADD52:%[0-9]+]]:_(s32) = G_ADD [[UV52]], [[UV116]] -+ ; CHECK-NEXT: [[ADD53:%[0-9]+]]:_(s32) = G_ADD [[UV53]], [[UV117]] -+ ; CHECK-NEXT: [[ADD54:%[0-9]+]]:_(s32) = G_ADD [[UV54]], [[UV118]] -+ ; CHECK-NEXT: [[ADD55:%[0-9]+]]:_(s32) = G_ADD [[UV55]], [[UV119]] -+ ; CHECK-NEXT: [[ADD56:%[0-9]+]]:_(s32) = G_ADD [[UV56]], [[UV120]] -+ ; CHECK-NEXT: [[ADD57:%[0-9]+]]:_(s32) = G_ADD [[UV57]], [[UV121]] -+ ; CHECK-NEXT: [[ADD58:%[0-9]+]]:_(s32) = G_ADD [[UV58]], [[UV122]] -+ ; CHECK-NEXT: [[ADD59:%[0-9]+]]:_(s32) = G_ADD [[UV59]], [[UV123]] -+ ; CHECK-NEXT: [[ADD60:%[0-9]+]]:_(s32) = G_ADD [[UV60]], [[UV124]] -+ ; CHECK-NEXT: [[ADD61:%[0-9]+]]:_(s32) = G_ADD [[UV61]], [[UV125]] -+ ; CHECK-NEXT: [[ADD62:%[0-9]+]]:_(s32) = G_ADD [[UV62]], [[UV126]] -+ ; CHECK-NEXT: [[ADD63:%[0-9]+]]:_(s32) = G_ADD [[UV63]], [[UV127]] - ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD]](s32), [[ADD1]](s32), [[ADD2]](s32), [[ADD3]](s32), [[ADD4]](s32), [[ADD5]](s32), [[ADD6]](s32), [[ADD7]](s32), [[ADD8]](s32), [[ADD9]](s32), [[ADD10]](s32), [[ADD11]](s32), [[ADD12]](s32), [[ADD13]](s32), [[ADD14]](s32), [[ADD15]](s32) - ; CHECK-NEXT: [[BUILD_VECTOR1:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD16]](s32), [[ADD17]](s32), [[ADD18]](s32), [[ADD19]](s32), [[ADD20]](s32), [[ADD21]](s32), [[ADD22]](s32), [[ADD23]](s32), [[ADD24]](s32), [[ADD25]](s32), [[ADD26]](s32), [[ADD27]](s32), [[ADD28]](s32), [[ADD29]](s32), [[ADD30]](s32), [[ADD31]](s32) - ; CHECK-NEXT: [[BUILD_VECTOR2:%[0-9]+]]:_(<16 x s32>) = G_BUILD_VECTOR [[ADD32]](s32), [[ADD33]](s32), [[ADD34]](s32), [[ADD35]](s32), [[ADD36]](s32), [[ADD37]](s32), [[ADD38]](s32), [[ADD39]](s32), [[ADD40]](s32), [[ADD41]](s32), [[ADD42]](s32), [[ADD43]](s32), [[ADD44]](s32), [[ADD45]](s32), [[ADD46]](s32), [[ADD47]](s32) -@@ -760,10 +762,10 @@ - ; CHECK-NEXT: G_BR %bb.2 - ; CHECK-NEXT: {{ $}} - ; CHECK-NEXT: bb.2: -- ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 -- ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV1]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 -- ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV2]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 -- ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[UV3]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[PHI:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[PHI1:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR1]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[PHI2:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR2]](<16 x s32>), %bb.1 -+ ; CHECK-NEXT: [[PHI3:%[0-9]+]]:_(<16 x s32>) = G_PHI [[DEF]](<16 x s32>), %bb.0, [[BUILD_VECTOR3]](<16 x s32>), %bb.1 - ; CHECK-NEXT: [[CONCAT_VECTORS:%[0-9]+]]:_(<64 x s32>) = G_CONCAT_VECTORS [[PHI]](<16 x s32>), [[PHI1]](<16 x s32>), [[PHI2]](<16 x s32>), [[PHI3]](<16 x s32>) - ; CHECK-NEXT: S_SETPC_B64 undef $sgpr30_sgpr31, implicit [[CONCAT_VECTORS]](<64 x s32>) - bb.0: -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir ---- a/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -+++ b/llvm/test/CodeGen/AMDGPU/GlobalISel/regbankselect.mir -@@ -42,6 +42,8 @@ - ret void - } - -+ define void @non_power_of_2() { ret void } -+ - define amdgpu_kernel void @load_constant_v4i16_from_8_align8(ptr addrspace(4) %ptr0) { - ret void - } -@@ -185,6 +187,23 @@ - ... - - --- -+name: non_power_of_2 -+legalized: true -+ -+body: | -+ bb.0: -+ ; CHECK-LABEL: name: non_power_of_2 -+ ; CHECK: [[DEF:%[0-9]+]]:sgpr(s448) = G_IMPLICIT_DEF -+ ; CHECK-NEXT: [[EXTRACT:%[0-9]+]]:sgpr(s32) = G_EXTRACT [[DEF]](s448), 0 -+ ; CHECK-NEXT: $sgpr0 = COPY [[EXTRACT]](s32) -+ ; CHECK-NEXT: SI_RETURN_TO_EPILOG $sgpr0 -+ %0:_(s448) = G_IMPLICIT_DEF -+ %1:_(s32) = G_EXTRACT %0:_(s448), 0 -+ $sgpr0 = COPY %1:_(s32) -+ SI_RETURN_TO_EPILOG $sgpr0 -+... -+ -+--- - name: load_constant_v4i16_from_8_align8 - legalized: true - -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll ---- a/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -+++ b/llvm/test/CodeGen/NVPTX/intrin-nocapture.ll -@@ -0,0 +1,21 @@ -+; RUN: opt < %s -O3 -S | FileCheck %s -+ -+; Address space intrinsics were erroneously marked NoCapture, leading to bad -+; optimizations (such as the store below being eliminated as dead code). This -+; test makes sure we don't regress. -+ -+declare void @foo(ptr addrspace(1)) -+ -+declare ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr) -+ -+; CHECK: @bar -+define void @bar() { -+ %t1 = alloca i32 -+; CHECK: call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr nonnull %t1) -+; CHECK-NEXT: store i32 10, ptr %t1 -+ %t2 = call ptr addrspace(1) @llvm.nvvm.ptr.gen.to.global.p1.p0(ptr %t1) -+ store i32 10, ptr %t1 -+ call void @foo(ptr addrspace(1) %t2) -+ ret void -+} -+ -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate_64.ll b/llvm/test/CodeGen/NVPTX/rotate_64.ll ---- a/llvm/test/CodeGen/NVPTX/rotate_64.ll -+++ b/llvm/test/CodeGen/NVPTX/rotate_64.ll -@@ -1,38 +1,25 @@ --; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 - ; RUN: llc < %s -march=nvptx64 | FileCheck %s - ; RUN: %if ptxas %{ llc < %s -march=nvptx64 | %ptxas-verify %} - - declare i64 @llvm.nvvm.rotate.b64(i64, i32) - declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) - -+; CHECK: rotate64 - define i64 @rotate64(i64 %a, i32 %b) { --; CHECK-LABEL: rotate64( --; CHECK: { --; CHECK-NEXT: .reg .b64 %rd<5>; --; CHECK-EMPTY: --; CHECK-NEXT: // %bb.0: --; CHECK-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; --; CHECK-NEXT: shr.u64 %rd2, %rd1, 61; --; CHECK-NEXT: shl.b64 %rd3, %rd1, 3; --; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; --; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; --; CHECK-NEXT: ret; -+; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 3; -+; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 61; -+; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; -+; CHECK: ret - %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 3) - ret i64 %val - } - -+; CHECK: rotateright64 - define i64 @rotateright64(i64 %a, i32 %b) { --; CHECK-LABEL: rotateright64( --; CHECK: { --; CHECK-NEXT: .reg .b64 %rd<5>; --; CHECK-EMPTY: --; CHECK-NEXT: // %bb.0: --; CHECK-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; --; CHECK-NEXT: shl.b64 %rd2, %rd1, 61; --; CHECK-NEXT: shr.u64 %rd3, %rd1, 3; --; CHECK-NEXT: or.b64 %rd4, %rd3, %rd2; --; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd4; --; CHECK-NEXT: ret; -+; CHECK: shl.b64 [[LHS:%.*]], [[RD1:%.*]], 61; -+; CHECK: shr.b64 [[RHS:%.*]], [[RD1]], 3; -+; CHECK: add.u64 [[RD2:%.*]], [[LHS]], [[RHS]]; -+; CHECK: ret - %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 3) - ret i64 %val - } -diff -ruN --strip-trailing-cr a/llvm/test/CodeGen/NVPTX/rotate.ll b/llvm/test/CodeGen/NVPTX/rotate.ll ---- a/llvm/test/CodeGen/NVPTX/rotate.ll -+++ b/llvm/test/CodeGen/NVPTX/rotate.ll -@@ -9,29 +9,26 @@ - declare i64 @llvm.nvvm.rotate.b64(i64, i32) - declare i64 @llvm.nvvm.rotate.right.b64(i64, i32) - --declare i64 @llvm.fshl.i64(i64, i64, i64) --declare i64 @llvm.fshr.i64(i64, i64, i64) --declare i32 @llvm.fshl.i32(i32, i32, i32) --declare i32 @llvm.fshr.i32(i32, i32, i32) -- -- - ; SM20: rotate32 - ; SM35: rotate32 - define i32 @rotate32(i32 %a, i32 %b) { - ; SM20-LABEL: rotate32( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<9>; -+; SM20-NEXT: .reg .b32 %r<4>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u32 %r1, [rotate32_param_0]; - ; SM20-NEXT: ld.param.u32 %r2, [rotate32_param_1]; --; SM20-NEXT: and.b32 %r3, %r2, 31; --; SM20-NEXT: shl.b32 %r4, %r1, %r3; --; SM20-NEXT: neg.s32 %r5, %r2; --; SM20-NEXT: and.b32 %r6, %r5, 31; --; SM20-NEXT: shr.u32 %r7, %r1, %r6; --; SM20-NEXT: or.b32 %r8, %r4, %r7; --; SM20-NEXT: st.param.b32 [func_retval0+0], %r8; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b32 %lhs; -+; SM20-NEXT: .reg .b32 %rhs; -+; SM20-NEXT: .reg .b32 %amt2; -+; SM20-NEXT: shl.b32 %lhs, %r1, %r2; -+; SM20-NEXT: sub.s32 %amt2, 32, %r2; -+; SM20-NEXT: shr.b32 %rhs, %r1, %amt2; -+; SM20-NEXT: add.u32 %r3, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b32 [func_retval0+0], %r3; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotate32( -@@ -53,36 +50,45 @@ - define i64 @rotate64(i64 %a, i32 %b) { - ; SM20-LABEL: rotate64( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b32 %r<2>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; - ; SM20-NEXT: ld.param.u32 %r1, [rotate64_param_1]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM20-NEXT: neg.s32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; --; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: .reg .u32 %amt2; -+; SM20-NEXT: and.b32 %amt2, %r1, 63; -+; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; -+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotate64( - ; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b32 %r<6>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotate64_param_0]; --; SM35-NEXT: ld.param.u32 %r1, [rotate64_param_1]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM35-NEXT: neg.s32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; --; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b32 %dummy; -+; SM35-NEXT: mov.b64 {%dummy,%r1}, %rd1; -+; SM35-NEXT: } -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b32 %dummy; -+; SM35-NEXT: mov.b64 {%r2,%dummy}, %rd1; -+; SM35-NEXT: } -+; SM35-NEXT: ld.param.u32 %r3, [rotate64_param_1]; -+; SM35-NEXT: shf.l.wrap.b32 %r4, %r2, %r1, %r3; -+; SM35-NEXT: shf.l.wrap.b32 %r5, %r1, %r2, %r3; -+; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.nvvm.rotate.b64(i64 %a, i32 %b) - ret i64 %val -@@ -93,36 +99,45 @@ - define i64 @rotateright64(i64 %a, i32 %b) { - ; SM20-LABEL: rotateright64( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b32 %r<2>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; - ; SM20-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; --; SM20-NEXT: neg.s32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; --; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: .reg .u32 %amt2; -+; SM20-NEXT: and.b32 %amt2, %r1, 63; -+; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; -+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotateright64( - ; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b32 %r<6>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotateright64_param_0]; --; SM35-NEXT: ld.param.u32 %r1, [rotateright64_param_1]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; --; SM35-NEXT: neg.s32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; --; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b32 %dummy; -+; SM35-NEXT: mov.b64 {%r1,%dummy}, %rd1; -+; SM35-NEXT: } -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b32 %dummy; -+; SM35-NEXT: mov.b64 {%dummy,%r2}, %rd1; -+; SM35-NEXT: } -+; SM35-NEXT: ld.param.u32 %r3, [rotateright64_param_1]; -+; SM35-NEXT: shf.r.wrap.b32 %r4, %r2, %r1, %r3; -+; SM35-NEXT: shf.r.wrap.b32 %r5, %r1, %r2, %r3; -+; SM35-NEXT: mov.b64 %rd2, {%r5, %r4}; -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.nvvm.rotate.right.b64(i64 %a, i32 %b) - ret i64 %val -@@ -133,14 +148,18 @@ - define i32 @rotl0(i32 %x) { - ; SM20-LABEL: rotl0( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; -+; SM20-NEXT: .reg .b32 %r<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u32 %r1, [rotl0_param_0]; --; SM20-NEXT: shr.u32 %r2, %r1, 24; --; SM20-NEXT: shl.b32 %r3, %r1, 8; --; SM20-NEXT: or.b32 %r4, %r3, %r2; --; SM20-NEXT: st.param.b32 [func_retval0+0], %r4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b32 %lhs; -+; SM20-NEXT: .reg .b32 %rhs; -+; SM20-NEXT: shl.b32 %lhs, %r1, 8; -+; SM20-NEXT: shr.b32 %rhs, %r1, 24; -+; SM20-NEXT: add.u32 %r2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b32 [func_retval0+0], %r2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotl0( -@@ -158,40 +177,51 @@ - ret i32 %t2 - } - -+declare i64 @llvm.fshl.i64(i64, i64, i64) -+declare i64 @llvm.fshr.i64(i64, i64, i64) -+ - ; SM35: rotl64 - define i64 @rotl64(i64 %a, i64 %n) { - ; SM20-LABEL: rotl64( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b32 %r<2>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; - ; SM20-NEXT: ld.param.u32 %r1, [rotl64_param_1]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM20-NEXT: neg.s32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shr.u64 %rd3, %rd1, %r4; --; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: .reg .u32 %amt2; -+; SM20-NEXT: and.b32 %amt2, %r1, 63; -+; SM20-NEXT: shl.b64 %lhs, %rd1, %amt2; -+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM20-NEXT: shr.b64 %rhs, %rd1, %amt2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotl64( - ; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b32 %r<2>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_param_0]; - ; SM35-NEXT: ld.param.u32 %r1, [rotl64_param_1]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM35-NEXT: neg.s32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shr.u64 %rd3, %rd1, %r4; --; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b64 %lhs; -+; SM35-NEXT: .reg .b64 %rhs; -+; SM35-NEXT: .reg .u32 %amt2; -+; SM35-NEXT: and.b32 %amt2, %r1, 63; -+; SM35-NEXT: shl.b64 %lhs, %rd1, %amt2; -+; SM35-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM35-NEXT: shr.b64 %rhs, %rd1, %amt2; -+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM35-NEXT: } -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 %n) - ret i64 %val -@@ -201,26 +231,34 @@ - define i64 @rotl64_imm(i64 %a) { - ; SM20-LABEL: rotl64_imm( - ; SM20: { --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; --; SM20-NEXT: shr.u64 %rd2, %rd1, 62; --; SM20-NEXT: shl.b64 %rd3, %rd1, 2; --; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: shl.b64 %lhs, %rd1, 2; -+; SM20-NEXT: shr.b64 %rhs, %rd1, 62; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotl64_imm( - ; SM35: { --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotl64_imm_param_0]; --; SM35-NEXT: shr.u64 %rd2, %rd1, 62; --; SM35-NEXT: shl.b64 %rd3, %rd1, 2; --; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b64 %lhs; -+; SM35-NEXT: .reg .b64 %rhs; -+; SM35-NEXT: shl.b64 %lhs, %rd1, 2; -+; SM35-NEXT: shr.b64 %rhs, %rd1, 62; -+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM35-NEXT: } -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.fshl.i64(i64 %a, i64 %a, i64 66) - ret i64 %val -@@ -230,36 +268,44 @@ - define i64 @rotr64(i64 %a, i64 %n) { - ; SM20-LABEL: rotr64( - ; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b32 %r<2>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; - ; SM20-NEXT: ld.param.u32 %r1, [rotr64_param_1]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shr.u64 %rd2, %rd1, %r2; --; SM20-NEXT: neg.s32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shl.b64 %rd3, %rd1, %r4; --; SM20-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: .reg .u32 %amt2; -+; SM20-NEXT: and.b32 %amt2, %r1, 63; -+; SM20-NEXT: shr.b64 %lhs, %rd1, %amt2; -+; SM20-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM20-NEXT: shl.b64 %rhs, %rd1, %amt2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotr64( - ; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b32 %r<2>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_param_0]; - ; SM35-NEXT: ld.param.u32 %r1, [rotr64_param_1]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shr.u64 %rd2, %rd1, %r2; --; SM35-NEXT: neg.s32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shl.b64 %rd3, %rd1, %r4; --; SM35-NEXT: or.b64 %rd4, %rd2, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b64 %lhs; -+; SM35-NEXT: .reg .b64 %rhs; -+; SM35-NEXT: .reg .u32 %amt2; -+; SM35-NEXT: and.b32 %amt2, %r1, 63; -+; SM35-NEXT: shr.b64 %lhs, %rd1, %amt2; -+; SM35-NEXT: sub.u32 %amt2, 64, %amt2; -+; SM35-NEXT: shl.b64 %rhs, %rd1, %amt2; -+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM35-NEXT: } -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 %n) - ret i64 %val -@@ -269,180 +315,35 @@ - define i64 @rotr64_imm(i64 %a) { - ; SM20-LABEL: rotr64_imm( - ; SM20: { --; SM20-NEXT: .reg .b64 %rd<5>; -+; SM20-NEXT: .reg .b64 %rd<3>; - ; SM20-EMPTY: - ; SM20-NEXT: // %bb.0: - ; SM20-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; --; SM20-NEXT: shl.b64 %rd2, %rd1, 62; --; SM20-NEXT: shr.u64 %rd3, %rd1, 2; --; SM20-NEXT: or.b64 %rd4, %rd3, %rd2; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM20-NEXT: { -+; SM20-NEXT: .reg .b64 %lhs; -+; SM20-NEXT: .reg .b64 %rhs; -+; SM20-NEXT: shl.b64 %lhs, %rd1, 62; -+; SM20-NEXT: shr.b64 %rhs, %rd1, 2; -+; SM20-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM20-NEXT: } -+; SM20-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM20-NEXT: ret; - ; - ; SM35-LABEL: rotr64_imm( - ; SM35: { --; SM35-NEXT: .reg .b64 %rd<5>; -+; SM35-NEXT: .reg .b64 %rd<3>; - ; SM35-EMPTY: - ; SM35-NEXT: // %bb.0: - ; SM35-NEXT: ld.param.u64 %rd1, [rotr64_imm_param_0]; --; SM35-NEXT: shl.b64 %rd2, %rd1, 62; --; SM35-NEXT: shr.u64 %rd3, %rd1, 2; --; SM35-NEXT: or.b64 %rd4, %rd3, %rd2; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd4; -+; SM35-NEXT: { -+; SM35-NEXT: .reg .b64 %lhs; -+; SM35-NEXT: .reg .b64 %rhs; -+; SM35-NEXT: shl.b64 %lhs, %rd1, 62; -+; SM35-NEXT: shr.b64 %rhs, %rd1, 2; -+; SM35-NEXT: add.u64 %rd2, %lhs, %rhs; -+; SM35-NEXT: } -+; SM35-NEXT: st.param.b64 [func_retval0+0], %rd2; - ; SM35-NEXT: ret; - %val = tail call i64 @llvm.fshr.i64(i64 %a, i64 %a, i64 66) - ret i64 %val - } -- --define i32 @funnel_shift_right_32(i32 %a, i32 %b, i32 %c) { --; SM20-LABEL: funnel_shift_right_32( --; SM20: { --; SM20-NEXT: .reg .b32 %r<11>; --; SM20-EMPTY: --; SM20-NEXT: // %bb.0: --; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; --; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_2]; --; SM20-NEXT: and.b32 %r3, %r2, 31; --; SM20-NEXT: ld.param.u32 %r4, [funnel_shift_right_32_param_1]; --; SM20-NEXT: shr.u32 %r5, %r4, %r3; --; SM20-NEXT: shl.b32 %r6, %r1, 1; --; SM20-NEXT: not.b32 %r7, %r2; --; SM20-NEXT: and.b32 %r8, %r7, 31; --; SM20-NEXT: shl.b32 %r9, %r6, %r8; --; SM20-NEXT: or.b32 %r10, %r9, %r5; --; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; --; SM20-NEXT: ret; --; --; SM35-LABEL: funnel_shift_right_32( --; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-EMPTY: --; SM35-NEXT: // %bb.0: --; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_32_param_0]; --; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_right_32_param_1]; --; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_right_32_param_2]; --; SM35-NEXT: shf.r.wrap.b32 %r4, %r1, %r2, %r3; --; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; --; SM35-NEXT: ret; -- %val = call i32 @llvm.fshr.i32(i32 %a, i32 %b, i32 %c) -- ret i32 %val --} -- --define i32 @funnel_shift_left_32(i32 %a, i32 %b, i32 %c) { --; SM20-LABEL: funnel_shift_left_32( --; SM20: { --; SM20-NEXT: .reg .b32 %r<11>; --; SM20-EMPTY: --; SM20-NEXT: // %bb.0: --; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; --; SM20-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_2]; --; SM20-NEXT: and.b32 %r3, %r2, 31; --; SM20-NEXT: shl.b32 %r4, %r1, %r3; --; SM20-NEXT: ld.param.u32 %r5, [funnel_shift_left_32_param_1]; --; SM20-NEXT: shr.u32 %r6, %r5, 1; --; SM20-NEXT: not.b32 %r7, %r2; --; SM20-NEXT: and.b32 %r8, %r7, 31; --; SM20-NEXT: shr.u32 %r9, %r6, %r8; --; SM20-NEXT: or.b32 %r10, %r4, %r9; --; SM20-NEXT: st.param.b32 [func_retval0+0], %r10; --; SM20-NEXT: ret; --; --; SM35-LABEL: funnel_shift_left_32( --; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-EMPTY: --; SM35-NEXT: // %bb.0: --; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_32_param_0]; --; SM35-NEXT: ld.param.u32 %r2, [funnel_shift_left_32_param_1]; --; SM35-NEXT: ld.param.u32 %r3, [funnel_shift_left_32_param_2]; --; SM35-NEXT: shf.l.wrap.b32 %r4, %r1, %r2, %r3; --; SM35-NEXT: st.param.b32 [func_retval0+0], %r4; --; SM35-NEXT: ret; -- %val = call i32 @llvm.fshl.i32(i32 %a, i32 %b, i32 %c) -- ret i32 %val --} -- --define i64 @funnel_shift_right_64(i64 %a, i64 %b, i64 %c) { --; SM20-LABEL: funnel_shift_right_64( --; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<7>; --; SM20-EMPTY: --; SM20-NEXT: // %bb.0: --; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; --; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; --; SM20-NEXT: shr.u64 %rd3, %rd2, %r2; --; SM20-NEXT: shl.b64 %rd4, %rd1, 1; --; SM20-NEXT: not.b32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shl.b64 %rd5, %rd4, %r4; --; SM20-NEXT: or.b64 %rd6, %rd5, %rd3; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; --; SM20-NEXT: ret; --; --; SM35-LABEL: funnel_shift_right_64( --; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<7>; --; SM35-EMPTY: --; SM35-NEXT: // %bb.0: --; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_right_64_param_0]; --; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_right_64_param_2]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: ld.param.u64 %rd2, [funnel_shift_right_64_param_1]; --; SM35-NEXT: shr.u64 %rd3, %rd2, %r2; --; SM35-NEXT: shl.b64 %rd4, %rd1, 1; --; SM35-NEXT: not.b32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shl.b64 %rd5, %rd4, %r4; --; SM35-NEXT: or.b64 %rd6, %rd5, %rd3; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; --; SM35-NEXT: ret; -- %val = call i64 @llvm.fshr.i64(i64 %a, i64 %b, i64 %c) -- ret i64 %val --} -- --define i64 @funnel_shift_left_64(i64 %a, i64 %b, i64 %c) { --; SM20-LABEL: funnel_shift_left_64( --; SM20: { --; SM20-NEXT: .reg .b32 %r<5>; --; SM20-NEXT: .reg .b64 %rd<7>; --; SM20-EMPTY: --; SM20-NEXT: // %bb.0: --; SM20-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; --; SM20-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; --; SM20-NEXT: and.b32 %r2, %r1, 63; --; SM20-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM20-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; --; SM20-NEXT: shr.u64 %rd4, %rd3, 1; --; SM20-NEXT: not.b32 %r3, %r1; --; SM20-NEXT: and.b32 %r4, %r3, 63; --; SM20-NEXT: shr.u64 %rd5, %rd4, %r4; --; SM20-NEXT: or.b64 %rd6, %rd2, %rd5; --; SM20-NEXT: st.param.b64 [func_retval0+0], %rd6; --; SM20-NEXT: ret; --; --; SM35-LABEL: funnel_shift_left_64( --; SM35: { --; SM35-NEXT: .reg .b32 %r<5>; --; SM35-NEXT: .reg .b64 %rd<7>; --; SM35-EMPTY: --; SM35-NEXT: // %bb.0: --; SM35-NEXT: ld.param.u64 %rd1, [funnel_shift_left_64_param_0]; --; SM35-NEXT: ld.param.u32 %r1, [funnel_shift_left_64_param_2]; --; SM35-NEXT: and.b32 %r2, %r1, 63; --; SM35-NEXT: shl.b64 %rd2, %rd1, %r2; --; SM35-NEXT: ld.param.u64 %rd3, [funnel_shift_left_64_param_1]; --; SM35-NEXT: shr.u64 %rd4, %rd3, 1; --; SM35-NEXT: not.b32 %r3, %r1; --; SM35-NEXT: and.b32 %r4, %r3, 63; --; SM35-NEXT: shr.u64 %rd5, %rd4, %r4; --; SM35-NEXT: or.b64 %rd6, %rd2, %rd5; --; SM35-NEXT: st.param.b64 [func_retval0+0], %rd6; --; SM35-NEXT: ret; -- %val = call i64 @llvm.fshl.i64(i64 %a, i64 %b, i64 %c) -- ret i64 %val --} -- -diff -ruN --strip-trailing-cr a/llvm/test/DebugInfo/NVPTX/debug-info.ll b/llvm/test/DebugInfo/NVPTX/debug-info.ll ---- a/llvm/test/DebugInfo/NVPTX/debug-info.ll -+++ b/llvm/test/DebugInfo/NVPTX/debug-info.ll -@@ -25,10 +25,6 @@ - ; CHECK-DAG: .reg .b64 %rd<8>; - ; CHECK: .loc [[DEBUG_INFO_CU:[0-9]+]] 5 0 - ; CHECK: ld.param.u32 %r{{.+}}, [{{.+}}]; --; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; --; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; --; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; --; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; - ; CHECK: .loc [[BUILTUIN_VARS_H:[0-9]+]] 78 180 - ; CHECK: mov.u32 %r{{.+}}, %ctaid.x; - ; CHECK: .loc [[BUILTUIN_VARS_H]] 89 180 -@@ -42,6 +38,10 @@ - ; CHECK: .loc [[DEBUG_INFO_CU]] 7 7 - ; CHECK: @%p{{.+}} bra [[BB:\$L__.+]]; - ; CHECK: ld.param.f32 %f{{.+}}, [{{.+}}]; -+; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -+; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; -+; CHECK: ld.param.u64 %rd{{.+}}, [{{.+}}]; -+; CHECK: cvta.to.global.u64 %rd{{.+}}, %rd{{.+}}; - ; CHECK: .loc [[DEBUG_INFO_CU]] 8 13 - ; CHECK: mul.wide.u32 %rd{{.+}}, %r{{.+}}, 4; - ; CHECK: add.s64 %rd{{.+}}, %rd{{.+}}, %rd{{.+}}; -@@ -2661,22 +2661,22 @@ - ; CHECK-NEXT:.b32 4579 // DW_AT_type - ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8aa:0x18 DW_TAG_inlined_subroutine - ; CHECK-NEXT:.b32 707 // DW_AT_abstract_origin --; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc --; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc -+; CHECK-NEXT:.b64 $L__tmp0 // DW_AT_low_pc -+; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_high_pc - ; CHECK-NEXT:.b8 1 // DW_AT_call_file - ; CHECK-NEXT:.b8 6 // DW_AT_call_line - ; CHECK-NEXT:.b8 11 // DW_AT_call_column - ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8c2:0x18 DW_TAG_inlined_subroutine - ; CHECK-NEXT:.b32 1466 // DW_AT_abstract_origin --; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc --; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc -+; CHECK-NEXT:.b64 $L__tmp1 // DW_AT_low_pc -+; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_high_pc - ; CHECK-NEXT:.b8 1 // DW_AT_call_file - ; CHECK-NEXT:.b8 6 // DW_AT_call_line - ; CHECK-NEXT:.b8 24 // DW_AT_call_column - ; CHECK-NEXT:.b8 25 // Abbrev [25] 0x8da:0x18 DW_TAG_inlined_subroutine - ; CHECK-NEXT:.b32 2060 // DW_AT_abstract_origin --; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_low_pc --; CHECK-NEXT:.b64 $L__tmp4 // DW_AT_high_pc -+; CHECK-NEXT:.b64 $L__tmp2 // DW_AT_low_pc -+; CHECK-NEXT:.b64 $L__tmp3 // DW_AT_high_pc - ; CHECK-NEXT:.b8 1 // DW_AT_call_file - ; CHECK-NEXT:.b8 6 // DW_AT_call_line - ; CHECK-NEXT:.b8 37 // DW_AT_call_column diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index af35fe7..7b11086 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "9830156f623c56062bf6df1b4c4b4bd8ab5bd57c" - LLVM_SHA256 = "85bb9a61cfdaf0d3386890dc7b4bbaa17eecf4b70b60c314307f2ca3919b9035" + LLVM_COMMIT = "29b92d07746fac26cd64c914bc9c5c3833974f6d" + LLVM_SHA256 = "3e8e93e3749454af4b64f7f34b792a4748b62fc533bca1703d33b2b04e34eb70" tf_http_archive( name = name, diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 8b13789..b997031 100755 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -1 +1,617 @@ +diff --ruN a/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir b/stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +--- stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir ++++ stablehlo/stablehlo/tests/transforms/stablehlo_create_compatibility_expander.mlir +@@ -69,7 +69,7 @@ + index_vector_dim = 3 + >, + slice_sizes = array, +- indices_are_sorted = true ++ indices_are_sorted = false + } : (tensor<3x2x4x7x9xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> + func.return %0 : tensor<4x3x5x8xi32> + } +@@ -77,9 +77,9 @@ + // ----- + + // CHECK-LABEL: @gather_with_batching_no_index_vector_dim ++// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +-// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> + // CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ + // CHECK-SAME: dimension_numbers = #stablehlo.gather< +@@ -102,7 +102,7 @@ + index_vector_dim = 3 + >, + slice_sizes = array, +- indices_are_sorted = true ++ indices_are_sorted = false + }> : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>) -> tensor<4x3x5x8xi32> + func.return %0 : tensor<4x3x5x8xi32> + } +@@ -133,9 +133,305 @@ + index_vector_dim = 3 + >, + slice_sizes = array, +- indices_are_sorted = true ++ indices_are_sorted = false + }> : (tensor<0x2x9xi32>, tensor<0x3x5x1xi32>) -> tensor<0x3x5x8xi32> + func.return %0 : tensor<0x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<3x4x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 1 : tensor<3x4x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<3x4x5x1xi32>, tensor<3x4x5x1xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x4xi32>) -> tensor<3x4x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<3x4x5x8xi32> ++func.func @gather_batching_dims_indices_become_unsorted(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [0, 1], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<3x2x4x7x9xi32>, tensor<3x4x5x2xi32>) -> tensor<3x4x5x8xi32> ++ func.return %0 : tensor<3x4x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_become_unsorted_2 ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> ++func.func @gather_batching_dims_indices_become_unsorted_2(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [2, 3], ++ operand_batching_dims = [0, 1], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [2, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<3x2x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> ++ func.return %0 : tensor<2x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_remain_sorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = true, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> ++func.func @gather_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [2, 3], ++ operand_batching_dims = [0, 1], ++ start_indices_batching_dims = [0, 2], ++ start_index_map = [2, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = true ++ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> ++ func.return %0 : tensor<2x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_indices_remain_unsorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>) -> tensor<2x3x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<2x3x5x8xi32> ++func.func @gather_batching_dims_indices_remain_unsorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [2, 3], ++ operand_batching_dims = [0, 1], ++ start_indices_batching_dims = [0, 2], ++ start_index_map = [2, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x8xi32> ++ func.return %0 : tensor<2x3x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dims_does_not_overflow_indices_type ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x127x5x1xi8> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x127x5x1xi8> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<4x127x5x1xi8>, tensor<4x127x5x1xi8>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x4xi8> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x4xi8>) -> tensor<4x127x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x127x5x8xi32> ++func.func @gather_batching_dims_does_not_overflow_indices_type(%arg0: tensor<127x2x4x7x9xi32>, %arg1: tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<127x2x4x7x9xi32>, tensor<4x127x5x2xi8>) -> tensor<4x127x5x8xi32> ++ func.return %0 : tensor<4x127x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_signless_indices_type ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5x2xi8>) -> tensor<4x128x5x2xi32> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x2xi32>) -> tensor<4x128x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x4xi32>) -> tensor<4x128x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> ++func.func @gather_batching_dim_overflows_signless_indices_type(%arg0: tensor<128x2x4x7x9xi32>, %arg1: tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<128x2x4x7x9xi32>, tensor<4x128x5x2xi8>) -> tensor<4x128x5x8xi32> ++ func.return %0 : tensor<4x128x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_unsigned_indices_type ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<256x4x5x2xui8>) -> tensor<256x4x5x2xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<256x4x5x1xi32> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<256x4x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim0]], %[[iota_dim1]], %[[convert]], dim = 3 : (tensor<256x4x5x1xi32>, tensor<256x4x5x1xi32>, tensor<256x4x5x2xi32>) -> tensor<256x4x5x4xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x4xi32>) -> tensor<256x4x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<256x4x5x8xi32> ++func.func @gather_batching_dim_overflows_unsigned_indices_type(%arg0: tensor<256x2x4x7x9xi32>, %arg1: tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [0, 1], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<256x2x4x7x9xi32>, tensor<256x4x5x2xui8>) -> tensor<256x4x5x8xi32> ++ func.return %0 : tensor<256x4x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_indices_type_and_i32 ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x2xi64> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x2147483648x5x1xi64> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x2147483648x5x1xi64> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[convert]], dim = 3 : (tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x1xi64>, tensor<4x2147483648x5x2xi64>) -> tensor<4x2147483648x5x4xi64> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], ++// CHECK-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x4xi64>) -> tensor<4x2147483648x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x2147483648x5x8xi32> ++func.func @gather_batching_dim_overflows_indices_type_and_i32(%arg0: tensor<2147483648x2x4x7x9xi32>, %arg1: tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<2147483648x2x4x7x9xi32>, tensor<4x2147483648x5x2xi8>) -> tensor<4x2147483648x5x8xi32> ++ func.return %0 : tensor<4x2147483648x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_dynamic_size ++// CHECK: operand_batching_dims = [0, 2] ++// CHECK: start_indices_batching_dims = [1, 0] ++func.func @gather_batching_dim_dynamic_size(%arg0: tensor, %arg1: tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1, 3], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1, 3], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor, tensor<4x?x5x2xi8>) -> tensor<4x?x5x8xi32> ++ func.return %0 : tensor<4x?x5x8xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @gather_batching_dim_overflows_and_no_index_vector_dim ++// CHECK-NEXT: %[[convert:.*]] = stablehlo.convert %arg1 : (tensor<4x128x5xi8>) -> tensor<4x128x5xi32> ++// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %[[convert]] : (tensor<4x128x5xi32>) -> tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x128x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>, tensor<4x128x5x1xi32>) -> tensor<4x128x5x3xi32> ++// CHECK-NEXT: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ ++// CHECK-SAME: dimension_numbers = #stablehlo.gather< ++// CHECK-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2], ++// CHECK-SAME: start_index_map = [0, 2, 1], index_vector_dim = 3>, ++// CHECK-SAME: indices_are_sorted = false, ++// CHECK-SAME: slice_sizes = array ++// CHECK-SAME: }> : (tensor<128x2x4x9xi32>, tensor<4x128x5x3xi32>) -> tensor<4x128x5x8xi32> ++// CHECK-NEXT: return %[[gather]] : tensor<4x128x5x8xi32> ++func.func @gather_batching_dim_overflows_and_no_index_vector_dim(%arg0: tensor<128x2x4x9xi32>, %arg1: tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> { ++ %0 = "stablehlo.gather"(%arg0, %arg1) { ++ dimension_numbers = #stablehlo.gather< ++ offset_dims = [3], ++ collapsed_slice_dims = [1], ++ operand_batching_dims = [0, 2], ++ start_indices_batching_dims = [1, 0], ++ start_index_map = [1], ++ index_vector_dim = 3 ++ >, ++ slice_sizes = array, ++ indices_are_sorted = false ++ } : (tensor<128x2x4x9xi32>, tensor<4x128x5xi8>) -> tensor<4x128x5x8xi32> ++ func.return %0 : tensor<4x128x5x8xi32> + } + + // ----- +@@ -156,7 +452,7 @@ + // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ +- indices_are_sorted = true, ++ indices_are_sorted = false, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1, 3], +@@ -176,9 +472,9 @@ + // ----- + + // CHECK-LABEL: @scatter_with_batching_no_index_vector_dim ++// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 1 : tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 0 : tensor<4x3x5x1xi32> +-// CHECK-NEXT: %[[reshape:.*]] = stablehlo.reshape %arg1 : (tensor<4x3x5xi32>) -> tensor<4x3x5x1xi32> + // CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %[[reshape]], dim = 3 : (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>) -> tensor<4x3x5x3xi32> + // CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ + // CHECK-SAME: indices_are_sorted = false, +@@ -192,7 +488,7 @@ + // CHECK-NO-DOWNGRADE: input_batching_dims = [0, 2] + // CHECK-NO-DOWNGRADE: scatter_indices_batching_dims = [1, 0] + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ +- indices_are_sorted = true, ++ indices_are_sorted = false, + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [3], + inserted_window_dims = [1], +@@ -208,3 +504,60 @@ + }) : (tensor<3x2x4x9xi32>, tensor<4x3x5xi32>, tensor<4x3x5x8xi32>) -> tensor<3x2x4x9xi32> + func.return %0 : tensor<3x2x4x9xi32> + } ++ ++// ----- ++ ++// CHECK-LABEL: @scatter_batching_dims_indices_remain_sorted ++// CHECK-NEXT: %[[iota_dim1:.*]] = stablehlo.iota dim = 0 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[iota_dim0:.*]] = stablehlo.iota dim = 2 : tensor<2x3x5x1xi32> ++// CHECK-NEXT: %[[concat:.*]] = stablehlo.concatenate %[[iota_dim1]], %[[iota_dim0]], %arg1, dim = 3 : (tensor<2x3x5x1xi32>, tensor<2x3x5x1xi32>, tensor<2x3x5x2xi32>) -> tensor<2x3x5x4xi32> ++// CHECK-NEXT: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ ++// CHECK-SAME: indices_are_sorted = true, ++// CHECK-SAME: dimension_numbers = #stablehlo.scatter< ++// CHECK-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], ++// CHECK-SAME: scatter_dims_to_operand_dims = [0, 1, 2, 3], index_vector_dim = 3>, ++// CHECK-SAME: unique_indices = false}> ++// CHECK: (tensor<2x5x4x7x9xi32>, tensor<2x3x5x4xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> ++// CHECK-NEXT: return %[[scatter]] : tensor<2x5x4x7x9xi32> ++func.func @scatter_batching_dims_indices_remain_sorted(%arg0: tensor<2x5x4x7x9xi32>, %arg1: tensor<2x3x5x2xi32>, %arg2: tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> { ++ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ ++ indices_are_sorted = true, ++ scatter_dimension_numbers = #stablehlo.scatter< ++ update_window_dims = [3], ++ inserted_window_dims = [2, 3], ++ input_batching_dims = [0, 1], ++ scatter_indices_batching_dims = [0, 2], ++ scatter_dims_to_operand_dims = [2, 3], ++ index_vector_dim = 3 ++ >, ++ unique_indices = false ++ }> ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ stablehlo.return %arg4 : tensor ++ }) : (tensor<2x5x4x7x9xi32>, tensor<2x3x5x2xi32>, tensor<2x3x5x8xi32>) -> tensor<2x5x4x7x9xi32> ++ func.return %0 : tensor<2x5x4x7x9xi32> ++} ++ ++// ----- ++ ++// CHECK-LABEL: @scatter_batching_dim_dynamic_scatter_indices ++// CHECK: input_batching_dims = [0, 2] ++// CHECK: scatter_indices_batching_dims = [1, 0] ++func.func @scatter_batching_dim_dynamic_scatter_indices(%arg0: tensor, %arg1: tensor<4x?x5x2xi32>, %arg2: tensor<4x?x5x8xi32>) -> tensor { ++ %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) <{ ++ indices_are_sorted = false, ++ scatter_dimension_numbers = #stablehlo.scatter< ++ update_window_dims = [3], ++ inserted_window_dims = [1, 3], ++ input_batching_dims = [0, 2], ++ scatter_indices_batching_dims = [1, 0], ++ scatter_dims_to_operand_dims = [1, 3], ++ index_vector_dim = 3 ++ >, ++ unique_indices = false ++ }> ({ ++ ^bb0(%arg3: tensor, %arg4: tensor): ++ stablehlo.return %arg4 : tensor ++ }) : (tensor, tensor<4x?x5x2xi32>, tensor<4x?x5x8xi32>) -> tensor ++ func.return %0 : tensor ++} +diff --ruN a/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp b/stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +--- stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp ++++ stablehlo/stablehlo/transforms/StablehloCreateCompatibilityExpander.cpp +@@ -22,8 +22,11 @@ + #include "llvm/ADT/STLExtras.h" + #include "llvm/ADT/SmallVector.h" + #include "llvm/Support/ErrorHandling.h" ++#include "llvm/Support/MathExtras.h" + #include "mlir/Dialect/Func/IR/FuncOps.h" ++#include "mlir/IR/Builders.h" + #include "mlir/IR/BuiltinAttributes.h" ++#include "mlir/IR/BuiltinTypeInterfaces.h" + #include "mlir/IR/BuiltinTypes.h" + #include "mlir/IR/Diagnostics.h" + #include "mlir/IR/PatternMatch.h" +@@ -75,6 +78,42 @@ + return result; + } + ++bool fitsInIntegralType(int64_t size, IntegerType type) { ++ if (type.isUnsigned()) { ++ return llvm::isUIntN(type.getWidth(), size); ++ } else { ++ return llvm::isIntN(type.getWidth(), size); ++ } ++} ++ ++// If `type` is an integer type in which `size` doesn't fit, promote it to i32 ++// or i64 (depending on `size`). ++Type promoteTypeForSize(Type type, int64_t size, OpBuilder &builder) { ++ // Gather/Scatter should have an integer type, but we check just in case. ++ auto intType = dyn_cast(type); ++ if (!intType || fitsInIntegralType(size, intType)) { ++ return type; ++ } ++ if (fitsInIntegralType(size, builder.getI32Type())) { ++ return builder.getI32Type(); ++ } ++ return builder.getI64Type(); ++} ++ ++// If `indices_batching_dims` and `updated_index_map` are both sorted, then the ++// `indices_are_sorted` property is preserved. ++// ++// This is because each concatenated iota is monotonically increasing, sorted ++// indices batching dims mean their order corresponds to the order of batching ++// dims in the operand, and a sorted updated start index map means the order of ++// the index vector dim corresponds to the order of operand dims. ++bool getUpdatedIndicesAreSorted(bool indices_are_sorted, ++ ArrayRef indices_batching_dims, ++ ArrayRef updated_index_map) { ++ return indices_are_sorted && llvm::is_sorted(indices_batching_dims) && ++ llvm::is_sorted(updated_index_map); ++} ++ + // Returns an updated indices tensor such that an `IotaOp` is prepended for each + // dim in `indicesBatchingDims` with a `ConcatenateOp`. + // +@@ -85,16 +124,31 @@ + PatternRewriter &rewriter) { + Location loc = indices.getLoc(); + auto indicesType = cast(indices.getType()); ++ Type elementType = indicesType.getElementType(); ++ ++ // The batching dim sizes might not fit in the existing element type, ++ // in which case we need to promote it. ++ for (int64_t batchingDim : indicesBatchingDims) { ++ elementType = promoteTypeForSize( ++ elementType, indicesType.getDimSize(batchingDim), rewriter); ++ } ++ if (elementType != indicesType.getElementType()) { ++ indicesType = RankedTensorType::get(indicesType.getShape(), elementType); ++ indices = rewriter.create(loc, indicesType, indices); ++ } ++ + bool indexVectorDimOnLastDim = indexVectorDim == indicesType.getRank(); +- + SmallVector iotaShape(indicesType.getShape()); + if (indexVectorDimOnLastDim) { + iotaShape.push_back(1); + } else { + iotaShape[indexVectorDim] = 1; + } +- auto iotaType = +- RankedTensorType::get(iotaShape, indicesType.getElementType()); ++ auto iotaType = RankedTensorType::get(iotaShape, elementType); ++ ++ if (indexVectorDimOnLastDim) { ++ indices = rewriter.create(loc, iotaType, indices); ++ } + + SmallVector indicesToConcat; + indicesToConcat.reserve(indicesBatchingDims.size() + 1); +@@ -102,12 +156,7 @@ + indicesToConcat.push_back( + rewriter.create(loc, iotaType, batchingDim)); + } +- if (indexVectorDimOnLastDim) { +- indicesToConcat.push_back( +- rewriter.create(loc, iotaType, indices)); +- } else { +- indicesToConcat.push_back(indices); +- } ++ indicesToConcat.push_back(indices); + return rewriter.create(loc, indicesToConcat, indexVectorDim); + } + +@@ -125,9 +174,17 @@ + PatternRewriter &rewriter) const override { + GatherDimensionNumbersAttr dimNumbers = op.getDimensionNumbers(); + ArrayRef operandBatchingDims = dimNumbers.getOperandBatchingDims(); ++ ArrayRef startIndicesBatchingDims = ++ dimNumbers.getStartIndicesBatchingDims(); + if (operandBatchingDims.empty()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "gather op has no batching dims"; ++ }); ++ } ++ ++ if (!op.getStartIndices().getType().hasStaticShape()) { ++ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { ++ diag << "gather op has start indices with dynamic shape, can't expand"; + }); + } + +@@ -136,16 +193,18 @@ + SmallVector newStartIndexMap = + llvm::to_vector(llvm::concat( + operandBatchingDims, dimNumbers.getStartIndexMap())); +- Value newIndices = createConcatIndices( +- op.getStartIndices(), dimNumbers.getIndexVectorDim(), +- dimNumbers.getStartIndicesBatchingDims(), rewriter); ++ Value newIndices = createConcatIndices(op.getStartIndices(), ++ dimNumbers.getIndexVectorDim(), ++ startIndicesBatchingDims, rewriter); + rewriter.replaceOpWithNewOp( + op, op.getOperand(), newIndices, + GatherDimensionNumbersAttr::get( + op.getContext(), dimNumbers.getOffsetDims(), newCollapsedSliceDims, + /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, + newStartIndexMap, dimNumbers.getIndexVectorDim()), +- op.getSliceSizes(), /*indicesAreSorted=*/false); ++ op.getSliceSizes(), ++ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), ++ startIndicesBatchingDims, newStartIndexMap)); + + return success(); + } +@@ -161,9 +220,17 @@ + PatternRewriter &rewriter) const override { + ScatterDimensionNumbersAttr dimNumbers = op.getScatterDimensionNumbers(); + ArrayRef inputBatchingDims = dimNumbers.getInputBatchingDims(); ++ ArrayRef scatterIndicesBatchingDims = ++ dimNumbers.getScatterIndicesBatchingDims(); + if (inputBatchingDims.empty()) { + return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { + diag << "scatter op has no batching dims"; ++ }); ++ } ++ ++ if (!op.getScatterIndices().getType().hasStaticShape()) { ++ return rewriter.notifyMatchFailure(op, [](Diagnostic &diag) { ++ diag << "gather op has start indices with dynamic shape, can't expand"; + }); + } + +@@ -174,7 +241,7 @@ + inputBatchingDims, dimNumbers.getScatterDimsToOperandDims())); + Value newIndices = createConcatIndices( + op.getScatterIndices(), dimNumbers.getIndexVectorDim(), +- dimNumbers.getScatterIndicesBatchingDims(), rewriter); ++ scatterIndicesBatchingDims, rewriter); + auto newScatterOp = rewriter.create( + op.getLoc(), op->getResultTypes(), op.getInputs(), newIndices, + op.getUpdates(), +@@ -183,7 +250,10 @@ + newInsertedWindowDims, + /*inputBatchingDims=*/{}, /*scatterIndicesBatchingDims=*/{}, + newScatterDimsToOperandDims, dimNumbers.getIndexVectorDim()), +- /*indicesAreSorted=*/false, op.getUniqueIndices()); ++ getUpdatedIndicesAreSorted(op.getIndicesAreSorted(), ++ scatterIndicesBatchingDims, ++ newScatterDimsToOperandDims), ++ op.getUniqueIndices()); + + newScatterOp.getUpdateComputation().takeBody(op.getUpdateComputation()); + rewriter.replaceOp(op, newScatterOp.getResults()); diff --git a/third_party/stablehlo/workspace.bzl b/third_party/stablehlo/workspace.bzl index 2e87599..0a9d3d0 100644 --- a/third_party/stablehlo/workspace.bzl +++ b/third_party/stablehlo/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") def repo(): # - STABLEHLO_COMMIT = "ca13d31b5ed0b2053dde0a624480ad765e219ebf" - STABLEHLO_SHA256 = "123462093f087f2576bb6a6cc471370eed2d43c291f881ff359fd4ca812003db" + STABLEHLO_COMMIT = "9d9290dc2308c1850cea69ea05f8c94017e484ee" + STABLEHLO_SHA256 = "29803fc8a3a96f9e5469c7ab51f2ff4292dc2419c17bd0466f5d15a448cf6815" # tf_http_archive(