From 4f74fd010bdf4b2a6211c8615754f97855bf30d3 Mon Sep 17 00:00:00 2001 From: Denys Mentiei Date: Fri, 7 Jun 2024 16:35:19 +0200 Subject: [PATCH] AdvancedInterfaceVariableScalarReplacementPass added It is a full rewrite of the existing InterfaceVariableScalarReplacementPass that covers all the corner cases and adds the ability to skip matrices when scalarizing. The plan is to contribute this back to replace the original pass; until that happens, it lives as a separate pass. --- include/spirv-tools/optimizer.hpp | 16 + source/opt/CMakeLists.txt | 26 + source/opt/adv_interface_var_sroa.cpp | 1103 ++++++++++++++++++++++ source/opt/adv_interface_var_sroa.h | 240 +++++ source/opt/optimizer.cpp | 22 + source/opt/passes.h | 3 + test/opt/CMakeLists.txt | 5 +- test/opt/adv_interface_var_sroa_test.cpp | 548 +++++++++++ 8 files changed, 1962 insertions(+), 1 deletion(-) create mode 100644 source/opt/adv_interface_var_sroa.cpp create mode 100644 source/opt/adv_interface_var_sroa.h create mode 100644 test/opt/adv_interface_var_sroa_test.cpp diff --git a/include/spirv-tools/optimizer.hpp b/include/spirv-tools/optimizer.hpp index 6df8ad211a..ac92005525 100644 --- a/include/spirv-tools/optimizer.hpp +++ b/include/spirv-tools/optimizer.hpp @@ -1024,6 +1024,22 @@ Optimizer::PassToken CreateAndroidDriverPatchPass(); Optimizer::PassToken CreateReduceConstArrayToStructPass(); // UE Change End: Added support for reducing const arrays to structs +// UE Change Begin: Convert-Composite-To-Op-Access-Chain-Pass +Optimizer::PassToken CreateConvertCompositeToOpAccessChainPass(); +// UE Change End: Convert-Composite-To-Op-Access-Chain-Pass + +// UE Change Begin: Interface variable scalar replacement pass rewrite +// Create an adv-interface-variable-scalar-replacement pass that replaces array +// or matrix interface variables with a series of scalar or vector interface +// variables. For example, it replaces `float3 foo[2]` with `float3 foo0, foo1`.
+// If |process_matrices| is true, matrix interface variables will be replaced by +// scalars. +// It handles more cases than existing interface-variable-scalar-replacement +// pass, and hopefully will replace that soon. +Optimizer::PassToken CreateAdvancedInterfaceVariableScalarReplacementPass( + bool process_matrices); +// UE Change End: Interface variable scalar replacement pass rewrite + } // namespace spvtools #endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_ diff --git a/source/opt/CMakeLists.txt b/source/opt/CMakeLists.txt index 2df47e78ae..dc786c3283 100644 --- a/source/opt/CMakeLists.txt +++ b/source/opt/CMakeLists.txt @@ -13,10 +13,15 @@ # limitations under the License. set(SPIRV_TOOLS_OPT_SOURCES fix_func_call_arguments.h +# UE Change Begin: Interface variable scalar replacement pass rewrite + adv_interface_var_sroa.h +# UE Change End: Interface variable scalar replacement pass rewrite aggressive_dead_code_elim_pass.h amd_ext_to_khr.h analyze_live_input_pass.h +# UE Change Begin: Added support for Android driver patch pass to fix platform specific issues android_driver_patch_pass.h +# UE Change End: Added support for Android driver patch pass to fix platform specific issues basic_block.h block_merge_pass.h block_merge_util.h @@ -60,7 +65,11 @@ set(SPIRV_TOOLS_OPT_SOURCES fold_spec_constant_op_and_composite_pass.h freeze_spec_constant_value_pass.h function.h +# UE Change Begin: Implement a fused-multiply-add pass to reduce the +# possibility of re-association. fused_multiply_add_pass.h +# UE Change End: Implement a fused-multiply-add pass to reduce the +# possibility of re-association. 
graphics_robust_access_pass.h if_conversion.h inline_exhaustive_pass.h @@ -104,8 +113,12 @@ set(SPIRV_TOOLS_OPT_SOURCES pass_manager.h private_to_local_pass.h propagator.h +# UE Change Begin: Added support for reducing const arrays to structs reduce_const_array_to_struct_pass.h +# UE Change End: Added support for reducing const arrays to structs +# UE Change Begin: Convert-Composite-To-Op-Access-Chain-Pass convert_composite_to_op_access_chain.h +# UE Change End: Convert-Composite-To-Op-Access-Chain-Pass reduce_load_size.h redundancy_elimination.h reflect.h @@ -140,10 +153,15 @@ set(SPIRV_TOOLS_OPT_SOURCES wrap_opkill.h fix_func_call_arguments.cpp +# UE Change Begin: Interface variable scalar replacement pass rewrite + adv_interface_var_sroa.cpp +# UE Change End: Interface variable scalar replacement pass rewrite aggressive_dead_code_elim_pass.cpp amd_ext_to_khr.cpp analyze_live_input_pass.cpp +# UE Change Begin: Added support for Android driver patch pass to fix platform specific issues android_driver_patch_pass.cpp +# UE Change End: Added support for Android driver patch pass to fix platform specific issues basic_block.cpp block_merge_pass.cpp block_merge_util.cpp @@ -186,7 +204,11 @@ set(SPIRV_TOOLS_OPT_SOURCES fold_spec_constant_op_and_composite_pass.cpp freeze_spec_constant_value_pass.cpp function.cpp +# UE Change Begin: Implement a fused-multiply-add pass to reduce the +# possibility of re-association. fused_multiply_add_pass.cpp +# UE Change End: Implement a fused-multiply-add pass to reduce the +# possibility of re-association. 
graphics_robust_access_pass.cpp if_conversion.cpp inline_exhaustive_pass.cpp @@ -228,8 +250,12 @@ set(SPIRV_TOOLS_OPT_SOURCES pass_manager.cpp private_to_local_pass.cpp propagator.cpp +# UE Change Begin: Added support for reducing const arrays to structs reduce_const_array_to_struct_pass.cpp +# UE Change End: Added support for reducing const arrays to structs +# UE Change Begin: Convert-Composite-To-Op-Access-Chain-Pass convert_composite_to_op_access_chain.cpp +# UE Change End: Convert-Composite-To-Op-Access-Chain-Pass reduce_load_size.cpp redundancy_elimination.cpp register_pressure.cpp diff --git a/source/opt/adv_interface_var_sroa.cpp b/source/opt/adv_interface_var_sroa.cpp new file mode 100644 index 0000000000..54b0b2913a --- /dev/null +++ b/source/opt/adv_interface_var_sroa.cpp @@ -0,0 +1,1103 @@ +// Copyright (c) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "source/opt/adv_interface_var_sroa.h" + +#include + +#include "source/opt/decoration_manager.h" +#include "source/opt/def_use_manager.h" +#include "source/opt/log.h" +#include "source/opt/type_manager.h" + +namespace spvtools { +namespace opt { +namespace { +constexpr uint32_t kOpDecorateDecorationInOperandIndex = 1; +constexpr uint32_t kOpDecorateLiteralInOperandIndex = 2; +constexpr uint32_t kOpEntryPointInOperandInterface = 3; +constexpr uint32_t kOpVariableStorageClassInOperandIndex = 0; +constexpr uint32_t kOpTypeArrayElemTypeInOperandIndex = 0; +constexpr uint32_t kOpTypeArrayLengthInOperandIndex = 1; +constexpr uint32_t kOpTypeMatrixColCountInOperandIndex = 1; +constexpr uint32_t kOpTypeMatrixColTypeInOperandIndex = 0; +constexpr uint32_t kOpTypePtrTypeInOperandIndex = 1; +constexpr uint32_t kOpConstantValueInOperandIndex = 0; + +// Get the length of the OpTypeArray |array_type|. +uint32_t GetArrayLength(analysis::DefUseManager* def_use_mgr, + Instruction* array_type) { + assert(array_type->opcode() == spv::Op::OpTypeArray); + uint32_t const_int_id = + array_type->GetSingleWordInOperand(kOpTypeArrayLengthInOperandIndex); + Instruction* array_length_inst = def_use_mgr->GetDef(const_int_id); + assert(array_length_inst->opcode() == spv::Op::OpConstant); + return array_length_inst->GetSingleWordInOperand( + kOpConstantValueInOperandIndex); +} + +// Get the element type instruction of the OpTypeArray |array_type|. +Instruction* GetArrayElementType(analysis::DefUseManager* def_use_mgr, + Instruction* array_type) { + assert(array_type->opcode() == spv::Op::OpTypeArray); + uint32_t elem_type_id = + array_type->GetSingleWordInOperand(kOpTypeArrayElemTypeInOperandIndex); + return def_use_mgr->GetDef(elem_type_id); +} + +// Get the column type instruction of the OpTypeMatrix |matrix_type|. 
+Instruction* GetMatrixColumnType(analysis::DefUseManager* def_use_mgr, + Instruction* matrix_type) { + assert(matrix_type->opcode() == spv::Op::OpTypeMatrix); + uint32_t column_type_id = + matrix_type->GetSingleWordInOperand(kOpTypeMatrixColTypeInOperandIndex); + return def_use_mgr->GetDef(column_type_id); +} + +// Returns the storage class of the instruction |var|. +spv::StorageClass GetStorageClass(Instruction* var) { + return static_cast( + var->GetSingleWordInOperand(kOpVariableStorageClassInOperandIndex)); +} + +// Creates an OpDecorate instruction whose Target is |var_id| and Decoration is +// |decoration|. Adds |literal| as an extra operand of the instruction. +void CreateDecoration(analysis::DecorationManager* decoration_mgr, + uint32_t var_id, spv::Decoration decoration, + uint32_t literal) { + std::vector operands({ + {SPV_OPERAND_TYPE_ID, {var_id}}, + {SPV_OPERAND_TYPE_DECORATION, {static_cast(decoration)}}, + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {literal}}, + }); + decoration_mgr->AddDecoration(spv::Op::OpDecorate, std::move(operands)); +} + +std::unique_ptr CreateAccessChain(IRContext* context, uint32_t id, + Instruction* base_var, + uint32_t type_id, + Operand index) { + assert(context); + assert(base_var); + + auto storage_class = GetStorageClass(base_var); + uint32_t ptr_type_id = + context->get_type_mgr()->FindPointerToType(type_id, storage_class); + + std::unique_ptr access_chain( + new Instruction(context, spv::Op::OpAccessChain, ptr_type_id, id, + {{SPV_OPERAND_TYPE_ID, {base_var->result_id()}}, index})); + + return access_chain; +} + +// Creates an OpCompositeExtract instruction to extract the part with Result +// type |type_id| from the Composite that is |input_id| and Indexes are +// |indices|. If optional extra array index |extra_array_index| is passed, it is +// injected as a very first index. 
+std::unique_ptr CreateCompositeExtract( + IRContext* context, uint32_t id, uint32_t type_id, uint32_t input_id, + const std::vector& indices, uint32_t* extra_array_index) { + assert(context); + assert(!indices.empty()); + std::unique_ptr extract( + new Instruction(context, spv::Op::OpCompositeExtract, type_id, id, + {{SPV_OPERAND_TYPE_ID, {input_id}}})); + if (extra_array_index) { + extract->AddOperand( + {SPV_OPERAND_TYPE_LITERAL_INTEGER, {*extra_array_index}}); + } + for (uint32_t i : indices) { + extract->AddOperand({SPV_OPERAND_TYPE_LITERAL_INTEGER, {i}}); + } + return extract; +} + +// Creates an OpStore instruction to store value |what_id| to pointer +// |where_id|, while copying the memory attributes from another instruction +// |original_store|. +std::unique_ptr CreateStore(IRContext* context, uint32_t where_id, + uint32_t what_id, + Instruction* original_store) { + assert(context); + assert(original_store); + + std::unique_ptr store(new Instruction( + context, spv::Op::OpStore, 0, 0, + {{SPV_OPERAND_TYPE_ID, {where_id}}, {SPV_OPERAND_TYPE_ID, {what_id}}})); + + // Copy memory access attributes which start at index 2. Index 0 is the + // pointer and index 1 is the data. + for (uint32_t i = 2; i < original_store->NumInOperands(); ++i) { + store->AddOperand(original_store->GetInOperand(i)); + } + + return store; +} + +// Creates an OpLoad instruction with id |load_id| to load value of type +// |type_id| from |ptr_id|, while copying the memory attributes from another +// instruction |original_load|. +std::unique_ptr CreateLoad(IRContext* context, uint32_t type_id, + uint32_t ptr_id, uint32_t load_id, + Instruction* original_load) { + assert(context); + assert(original_load); + std::unique_ptr load( + new Instruction(context, spv::Op::OpLoad, type_id, load_id, + {{SPV_OPERAND_TYPE_ID, {ptr_id}}})); + // Copy memory access attributes which start at index 1. Index 0 is + // the pointer to load. 
+ for (uint32_t i = 1; i < original_load->NumInOperands(); ++i) { + load->AddOperand(original_load->GetInOperand(i)); + } + + return load; +} + +} // namespace + +Pass::Status AdvancedInterfaceVariableScalarReplacement::Process() { + Pass::Status status = Status::SuccessWithoutChange; + for (Instruction& entry_point : get_module()->entry_points()) { + status = CombineStatus(status, ProcessEntryPoint(entry_point)); + } + return status; +} + +Pass::Status AdvancedInterfaceVariableScalarReplacement::ProcessEntryPoint( + Instruction& entry_point) { + std::vector interface_vars = + CollectInterfaceVariables(entry_point); + Pass::Status status = Status::SuccessWithoutChange; + std::unordered_set replaced_interface_vars; + std::vector scalar_vars; + + for (Instruction* var : interface_vars) { + uint32_t location; + if (!GetVariableLocation(var, &location)) continue; + + Instruction* var_type = GetTypeOfVariable(var); + uint32_t extra_array_length = 0; + if (HasExtraArrayness(entry_point, var)) { + extra_array_length = + GetArrayLength(context()->get_def_use_mgr(), var_type); + var_type = GetArrayElementType(context()->get_def_use_mgr(), var_type); + vars_with_extra_arrayness.insert(var); + } else { + vars_without_extra_arrayness.insert(var); + } + + InterfaceVar interface_var(var, var_type, extra_array_length); + + if (!CheckExtraArraynessConflictBetweenEntries(interface_var)) { + return Pass::Status::Failure; + } + + spv::Op opcode = var_type->opcode(); + bool should_process = false; + should_process |= opcode == spv::Op::OpTypeArray; + should_process |= process_matrices_ && opcode == spv::Op::OpTypeMatrix; + if (!should_process) { + continue; + } + + replaced_interface_vars.insert(var->result_id()); + if (!ReplaceInterfaceVariable(interface_var, location, &scalar_vars)) { + return Pass::Status::Failure; + } + + status = Pass::Status::SuccessWithChange; + } + + ReplaceInEntryPoint(&entry_point, replaced_interface_vars, scalar_vars); + + return status; +} + +bool 
AdvancedInterfaceVariableScalarReplacement::ReplaceInterfaceVariable( + InterfaceVar var, uint32_t location, + std::vector* all_scalar_vars) { + assert(all_scalar_vars); + + std::vector scalar_vars; + Replacement replacement = CreateReplacementVariables(var, &scalar_vars); + assert(!scalar_vars.empty()); + + for (auto* scalar_var : scalar_vars) { + all_scalar_vars->push_back(scalar_var); + } + + uint32_t component = 0; + bool has_component_decoration = GetVariableComponent(var.def, &component); + AddLocationAndComponentDecorations( + replacement, &location, has_component_decoration ? &component : nullptr); + KillLocationAndComponentDecorations(var.def->result_id()); + + std::vector decoration_work_list; + std::vector access_chain_work_list; + struct LoadStore { + // Original interface variable touching instruction. + Instruction* to_be_replaced; + // Node representing the replacement for the part of interface variable, + // the instruction targets. + const Replacement* target; + // This is set only if instruction uses the extra arrayed scalar var. + Instruction* optional_access_chain = nullptr; + }; + std::vector load_work_list; + std::vector store_work_list; + + // Finds out all the interface variable usages to populate the work lists. + bool failed = !get_def_use_mgr()->WhileEachUser( + var.def, [this, &replacement, &decoration_work_list, &load_work_list, + &store_work_list, &access_chain_work_list](Instruction* user) { + if (user->IsDecoration()) { + decoration_work_list.push_back(user); + return true; + } + + switch (user->opcode()) { + case spv::Op::OpEntryPoint: + // Nothing to do here, it is handled later in |ProcessEntryPoint|. 
+ return true; + case spv::Op::OpName: + decoration_work_list.push_back(user); + return true; + case spv::Op::OpLoad: + load_work_list.push_back({user, &replacement}); + return true; + case spv::Op::OpStore: + store_work_list.push_back({user, &replacement}); + return true; + case spv::Op::OpAccessChain: + case spv::Op::OpInBoundsAccessChain: + access_chain_work_list.push_back(user); + return true; + default: + context()->EmitErrorMessage( + "Variable cannot be replaced: unexpected instruction", user); + return false; + } + }); + + if (failed) { + // Error has been reported already. + return false; + } + + std::vector dead; + + for (Instruction* decoration : decoration_work_list) { + spv::Op opcode = decoration->opcode(); + // Name decorations are already created for each replacement scalar + // variable. + if (opcode != spv::Op::OpName) { + for (const auto* scalar_var : scalar_vars) { + CloneAnnotationForVariable(decoration, scalar_var->result_id()); + } + } + + // Decorations will be killed together with the variable instruction, + // there is no need to add anything to |dead|. + } + + // Access chains are processed as a stack, as there might exist chains of + // access chains, which must be eventually fully replaced with loads/stores. + // Hence, processing of one access chain, might add more work to this stack. + // IMPORTANT: Access chains are processed _before_ the loads/stores as this + // processing can create more work for the loads/stores one. + while (!access_chain_work_list.empty()) { + Instruction* access_chain = access_chain_work_list.back(); + access_chain_work_list.pop_back(); + + assert(access_chain->opcode() == spv::Op::OpAccessChain || + access_chain->opcode() == spv::Op::OpInBoundsAccessChain); + assert(access_chain->NumInOperands() > 1 && + "OpAccessChain does not have Indexes operand"); + + // We are going to replace the access chain with either direct usage of the + // replacement scalar variable, or a set of composite loads/stores. 
+ + const Replacement* target = + LookupReplacement(access_chain, &replacement, var.extra_array_length); + if (!target) { + // Error has been already logged by |LookupReplacement|. + return false; + } + + if (!target->HasChildren() && var.extra_array_length == 0) { + // Replace with a direct use of the scalar variable. + auto scalar = target->GetScalarVariable(); + assert(scalar); + context()->ReplaceAllUsesWith(access_chain->result_id(), + scalar->result_id()); + } else { + // The current access chain's target is a composite, meaning that there + // are other instructions using the pointer. We need to convert those to + // use the replacement scalar variables. + failed = !get_def_use_mgr()->WhileEachUser( + access_chain, [this, target, &access_chain_work_list, &load_work_list, + &store_work_list, access_chain](Instruction* user) { + switch (user->opcode()) { + case spv::Op::OpLoad: + load_work_list.push_back({user, target, access_chain}); + return true; + case spv::Op::OpStore: + store_work_list.push_back({user, target, access_chain}); + return true; + case spv::Op::OpAccessChain: + case spv::Op::OpInBoundsAccessChain: + access_chain_work_list.push_back(user); + return true; + default: + context()->EmitErrorMessage( + "Variable cannot be replaced: unexpected instruction", + user); + return false; + } + }); + + if (failed) { + return false; + } + } + + dead.push_back(access_chain); + } + + for (auto [load, target_replacement, opt_access_chain] : load_work_list) { + if (!ReplaceLoad(load, *target_replacement, opt_access_chain, + var.extra_array_length)) { + return false; + } + dead.push_back(load); + } + + for (auto [store, target_replacement, opt_access_chain] : store_work_list) { + if (!ReplaceStore(store, *target_replacement, opt_access_chain, + var.extra_array_length)) { + return false; + } + dead.push_back(store); + } + + dead.push_back(var.def); + + while (!dead.empty()) { + Instruction* to_kill = dead.back(); + dead.pop_back(); + context()->KillInst(to_kill); 
+ } + + return true; +} + +bool AdvancedInterfaceVariableScalarReplacement::ReplaceInEntryPoint( + Instruction* entry_point, + const std::unordered_set& interface_vars, + const std::vector& scalar_vars) { + Instruction::OperandList new_operands; + + if (scalar_vars.empty()) { + return true; + } + + // Copy all operands except all interface variables, which will be replaced. + bool found = false; + for (uint32_t i = 0; i < entry_point->NumOperands(); ++i) { + Operand& op = entry_point->GetOperand(i); + if (op.type == SPV_OPERAND_TYPE_ID && + interface_vars.find(op.words[0]) != interface_vars.end()) { + found = true; + } else { + new_operands.emplace_back(std::move(op)); + } + } + + if (!found) { + context()->EmitErrorMessage( + "Interface variables are not operands of the entry point", entry_point); + return false; + } + + // Add all the new replacement variables. + for (auto scalar : scalar_vars) { + new_operands.push_back({SPV_OPERAND_TYPE_ID, {scalar->result_id()}}); + } + + entry_point->ReplaceOperands(new_operands); + context()->UpdateDefUse(entry_point); + + return true; +} + +bool AdvancedInterfaceVariableScalarReplacement::ReplaceLoad( + Instruction* load, const Replacement& replacement, + Instruction* optional_access_chain, uint32_t extra_array_length) { + assert(load && load->opcode() == spv::Op::OpLoad); + + const auto insert_before = + [this, load](Instruction* where, + std::unique_ptr what) -> Instruction* { + auto inst = where->InsertBefore(std::move(what)); + inst->UpdateDebugInfoFrom(load); + get_def_use_mgr()->AnalyzeInstDefUse(inst); + return inst; + }; + + std::vector pending_instructions; + // We do a post-order traversal of the tree of composite replacements to emit + // properly nested loads and composite constructions to match the original + // interface variable shape. + std::vector> todo; + + uint32_t num_passes = 1; + // If we have an optional access chain, we need to load a single element of + // the extra array. 
Otherwise, we load it fully. + if (!optional_access_chain && extra_array_length != 0) { + num_passes = extra_array_length; + } + for (uint32_t pass = 0; pass < num_passes; ++pass) { + std::optional extra_array_index; + if (extra_array_length != 0) { + if (optional_access_chain) { + extra_array_index = optional_access_chain->GetInOperand(1); + } else { + uint32_t index_id = context()->get_constant_mgr()->GetUIntConstId(pass); + extra_array_index = {SPV_OPERAND_TYPE_ID, {index_id}}; + } + } + todo.push_back({&replacement, false}); + + while (!todo.empty()) { + const auto [node, inserted] = todo.back(); + assert(node); + + if (inserted) { + todo.pop_back(); + + if (node->HasChildren()) { + // Construct the composite component from already loaded scalars. + uint32_t composite_id = TakeNextId(); + if (composite_id == 0) { + return false; + } + std::unique_ptr construct( + new Instruction(context(), spv::Op::OpCompositeConstruct, + node->GetTypeId(), composite_id, {})); + + // As we are doing a post-order traversal, our children instructions + // should already be laid out and ready to be used as our operands. + const auto& children = node->GetChildren(); + size_t num_children_left = children.size(); + assert(pending_instructions.size() >= num_children_left && + "Post-order traversal is broken"); + size_t i = pending_instructions.size() - num_children_left; + while (num_children_left > 0) { + construct->AddOperand( + {SPV_OPERAND_TYPE_ID, {pending_instructions[i]->result_id()}}); + + ++i; + --num_children_left; + } + for (size_t i = 0; i < children.size(); ++i) { + pending_instructions.pop_back(); + } + + auto inst = insert_before(load, std::move(construct)); + pending_instructions.push_back(inst); + } else { + auto scalar = node->GetScalarVariable(); + assert(scalar); + + Instruction* ptr = scalar; + + if (extra_array_index.has_value()) { + // Indirection access chain to get a pointer to the extra array + // element.
+ + uint32_t indirection_id = TakeNextId(); + if (indirection_id == 0) { + return false; + } + + std::unique_ptr access_chain = + CreateAccessChain(context(), indirection_id, ptr, + node->GetTypeId(), extra_array_index.value()); + ptr = insert_before(load, std::move(access_chain)); + } + + uint32_t subload_id = TakeNextId(); + if (subload_id == 0) { + return false; + } + + std::unique_ptr subload = CreateLoad( + context(), node->GetTypeId(), ptr->result_id(), subload_id, load); + + auto inst = insert_before(load, std::move(subload)); + pending_instructions.push_back(inst); + } + } else { + todo.back().second = true; + + const auto& children = node->GetChildren(); + for (const auto& child : + make_range(children.rbegin(), children.rend())) { + todo.push_back({&child, false}); + } + } + } + } + assert(pending_instructions.size() == num_passes); + if (num_passes > 1) { + uint32_t extra_array_type_id = + GetArrayType(replacement.GetTypeId(), extra_array_length); + + // Construct the composite component from already loaded scalars. 
+ uint32_t extra_array_id = TakeNextId(); + if (extra_array_id == 0) { + return false; + } + std::unique_ptr extra_construct( + new Instruction(context(), spv::Op::OpCompositeConstruct, + extra_array_type_id, extra_array_id, {})); + for (auto& pending : pending_instructions) { + Operand op(SPV_OPERAND_TYPE_ID, {pending->result_id()}); + extra_construct->AddOperand(std::move(op)); + } + auto inst = insert_before(load, std::move(extra_construct)); + pending_instructions.push_back(inst); + } + + context()->ReplaceAllUsesWith(load->result_id(), + pending_instructions.back()->result_id()); + return true; +} + +bool AdvancedInterfaceVariableScalarReplacement::ReplaceStore( + Instruction* store, const Replacement& replacement, + Instruction* optional_access_chain, uint32_t extra_array_length) { + assert(store && store->opcode() == spv::Op::OpStore); + + uint32_t input_id = store->GetSingleWordInOperand(1); + + // This is a managed stack of indices, which will contain a chain of indices + // coming to the currently processed node. + std::vector indices_chain; + + struct Entry { + // Currently processed node. + const Replacement* node; + // Local index of the node inside of the parent. + uint32_t index; + // Current node depth in the nodes tree. + size_t depth; + }; + std::vector todo; + todo.push_back({&replacement, 0, 0}); + size_t current_depth = 0; + + // We do an in-order traversal of the tree of composite replacements to emit + // proper stores with composite extracts to get the data we need, considering + // the original interface variable shape. 
+ while (!todo.empty()) { + const auto entry = todo.back(); + const auto node = entry.node; + const auto index = entry.index; + const auto depth = entry.depth; + todo.pop_back(); + + assert(node); + + while (current_depth > depth) { + indices_chain.pop_back(); + --current_depth; + } + current_depth = depth; + if (node != &replacement) { + indices_chain.push_back(index); + } + + if (node->HasChildren()) { + const auto& children = node->GetChildren(); + uint32_t child_index = uint32_t(children.size()); + while (child_index > 0) { + --child_index; + todo.push_back( + {&children[child_index], child_index, current_depth + 1}); + } + } else { + const auto insert_before = + [this, store](Instruction* where, + std::unique_ptr what) -> Instruction* { + auto inst = where->InsertBefore(std::move(what)); + inst->UpdateDebugInfoFrom(store); + get_def_use_mgr()->AnalyzeInstDefUse(inst); + return inst; + }; + + const auto store_to_scalar = [this, store, node, &indices_chain, + &insert_before]( + uint32_t value_to_store_id, + uint32_t* extra_array_index_for_extract, + uint32_t* extra_array_index_id) -> bool { + // This one is empty if replacement root is already a scalar, + // e.g. ivar[1][2] = scalar; + // hence we do not need the OpCompositeExtract. + if (!indices_chain.empty()) { + uint32_t extract_id = TakeNextId(); + if (extract_id == 0) { + return false; + } + + // Composite extract the nested scalar value. + std::unique_ptr extract = CreateCompositeExtract( + context(), extract_id, node->GetTypeId(), value_to_store_id, + indices_chain, extra_array_index_for_extract); + + insert_before(store, std::move(extract)); + + // To be used by the OpStore below. + value_to_store_id = extract_id; + } + + auto scalar = node->GetScalarVariable(); + assert(scalar); + + Instruction* ptr = scalar; + // Indirection access chain to get a pointer to the extra array + // element.
+ if (extra_array_index_id) { + uint32_t indirection_id = TakeNextId(); + if (indirection_id == 0) { + return false; + } + + std::unique_ptr access_chain = CreateAccessChain( + context(), indirection_id, ptr, node->GetTypeId(), + {SPV_OPERAND_TYPE_ID, {*extra_array_index_id}}); + + ptr = insert_before(store, std::move(access_chain)); + } + + // Store the value to the corresponding variable. + std::unique_ptr store_to_scalar = + CreateStore(context(), ptr->result_id(), value_to_store_id, store); + + insert_before(store, std::move(store_to_scalar)); + + return true; + }; + + bool ok = true; + if (extra_array_length == 0) { + ok = store_to_scalar(input_id, nullptr, nullptr); + } else if (optional_access_chain) { + uint32_t indirect_index = + optional_access_chain->GetSingleWordInOperand(1); + ok = store_to_scalar(input_id, nullptr, &indirect_index); + } else { + for (uint32_t i = 0; i < extra_array_length; ++i) { + uint32_t extra_array_index_id = + context()->get_constant_mgr()->GetUIntConstId(i); + + ok &= store_to_scalar(input_id, &i, &extra_array_index_id); + } + } + + if (!ok) { + return false; + } + + // It might be empty if current node is both scalar and a root. + if (!indices_chain.empty()) { + indices_chain.pop_back(); + } + } + } + + return true; +} + +const AdvancedInterfaceVariableScalarReplacement::Replacement* +AdvancedInterfaceVariableScalarReplacement::LookupReplacement( + Instruction* access_chain, const Replacement* root, + uint32_t extra_array_length) { + assert(access_chain); + + analysis::ConstantManager* const_mgr = context()->get_constant_mgr(); + + // In case of extra arrayness, the first index always targets that extra + // array, hence we skip it when looking-up the rest. + uint32_t start_index = extra_array_length == 0 ? 1 : 2; + + // Finds the target replacement, which might be a scalar or nested + // composite. 
+ for (uint32_t i = start_index; i < access_chain->NumInOperands(); ++i) { + uint32_t index_id = access_chain->GetSingleWordInOperand(i); + + const analysis::Constant* index_constant = + const_mgr->FindDeclaredConstant(index_id); + if (!index_constant) { + context()->EmitErrorMessage( + "Variable cannot be replaced: index is not constant", access_chain); + return nullptr; + } + + assert(root->HasChildren()); + const auto& children = root->GetChildren(); + + // OpAccessChain treats indices as signed. + int64_t index_value = index_constant->GetSignExtendedValue(); + if (index_value < 0 || + index_value >= static_cast(children.size())) { + // Out of bounds access, this is illegal IR. + // Notice that OpAccessChain indexing is 0-based, so we should also + // reject index == size-of-array. + context()->EmitErrorMessage("Variable cannot be replaced: invalid index", + access_chain); + return nullptr; + } + + root = &children[index_value]; + } + return root; +} + +AdvancedInterfaceVariableScalarReplacement::Replacement +AdvancedInterfaceVariableScalarReplacement::CreateReplacementVariables( + InterfaceVar var, std::vector* scalar_vars) { + assert(scalar_vars); + + auto* def_use_mgr = get_def_use_mgr(); + auto storage_class = GetStorageClass(var.def); + + // Composite replacement tree we are building here. + Replacement root(var.type->result_id()); + // Names for newly added scalars. + std::vector> names_to_add; + + // A managed stack of indices, which will contain a chain of indices coming to + // the currently processed replacement node. + std::vector indices_chain; + + struct Entry { + // Currently processed node. + Replacement* node; + // Type of the interface variable part, which this node is about. + Instruction* var_type; + // Local index of the node inside of the parent. + uint32_t index; + // Current node depth in the nodes tree. 
+ size_t depth; + }; + std::vector todo; + todo.push_back({&root, var.type, 0, 0}); + size_t current_depth = 0; + + while (!todo.empty()) { + const auto [node, type, index, depth] = todo.back(); + todo.pop_back(); + + assert(node); + + while (current_depth > depth) { + indices_chain.pop_back(); + --current_depth; + } + current_depth = depth; + if (node != &root) { + indices_chain.push_back(index); + } + + spv::Op opcode = type->opcode(); + if (opcode == spv::Op::OpTypeArray || opcode == spv::Op::OpTypeMatrix) { + // Handle array and matrix case. + + uint32_t length = 0; + Instruction* child_type{}; + + switch (type->opcode()) { + case spv::Op::OpTypeArray: + length = GetArrayLength(def_use_mgr, type); + child_type = GetArrayElementType(def_use_mgr, type); + break; + case spv::Op::OpTypeMatrix: + length = + type->GetSingleWordInOperand(kOpTypeMatrixColCountInOperandIndex); + child_type = GetMatrixColumnType(def_use_mgr, type); + break; + default: + assert(false && "Unexpected type."); + break; + } + assert(child_type); + uint32_t child_type_id = child_type->result_id(); + + for (uint32_t i = 0; i < length; ++i) { + node->AppendChild(child_type_id); + } + + auto& children = node->GetChildren(); + while (length > 0) { + --length; + todo.push_back( + {&children[length], child_type, length, current_depth + 1}); + } + } else { + // Handle scalar or vector case. + + std::unique_ptr variable = CreateVariable( + type->result_id(), storage_class, var.def, var.extra_array_length); + + node->SetSingleScalarVariable(variable.get()); + scalar_vars->push_back(variable.get()); + + uint32_t var_id = variable->result_id(); + context()->AddGlobalValue(std::move(variable)); + GenerateNames(var.def->result_id(), var_id, indices_chain, &names_to_add); + + indices_chain.pop_back(); + } + } + + // We shouldn't add the new names when we are iterating over name ranges + // above. We can add all the new names now. 
+ for (auto& new_name : names_to_add) { + context()->AddDebug2Inst(std::move(new_name)); + } + + return root; +} + +std::unique_ptr AdvancedInterfaceVariableScalarReplacement::CreateVariable( + uint32_t type_id, spv::StorageClass storage_class, + const Instruction* debug_info_source, uint32_t extra_array_length) { + assert(debug_info_source); + + if (extra_array_length != 0) { + type_id = GetArrayType(type_id, extra_array_length); + } + + uint32_t ptr_type_id = + context()->get_type_mgr()->FindPointerToType(type_id, storage_class); + + uint32_t id = TakeNextId(); + if (id == 0) { + return {}; + } + + std::unique_ptr variable( + new Instruction(context(), spv::Op::OpVariable, ptr_type_id, id, + {{SPV_OPERAND_TYPE_STORAGE_CLASS, + {static_cast(storage_class)}}})); + variable->UpdateDebugInfoFrom(debug_info_source); + + return variable; +} + +void AdvancedInterfaceVariableScalarReplacement::GenerateNames( + uint32_t source_id, uint32_t destination_id, + const std::vector& indices, + std::vector>* names_to_add) { + assert(names_to_add); + auto* def_use_mgr = get_def_use_mgr(); + for (auto [_, name_inst] : context()->GetNames(source_id)) { + std::string name_str = utils::MakeString(name_inst->GetOperand(1).words); + for (uint32_t i : indices) { + name_str += "[" + utils::ToString(i) + "]"; + } + + std::unique_ptr new_name(new Instruction( + context(), spv::Op::OpName, 0, 0, + {{SPV_OPERAND_TYPE_ID, {destination_id}}, + {SPV_OPERAND_TYPE_LITERAL_STRING, utils::MakeVector(name_str)}})); + def_use_mgr->AnalyzeInstDefUse(new_name.get()); + names_to_add->push_back(std::move(new_name)); + } +} + +bool AdvancedInterfaceVariableScalarReplacement:: + CheckExtraArraynessConflictBetweenEntries(InterfaceVar var) { + if (var.extra_array_length != 0) { + return !ReportErrorIfHasNoExtraArraynessForOtherEntry(var.def); + } + return !ReportErrorIfHasExtraArraynessForOtherEntry(var.def); +} + +bool AdvancedInterfaceVariableScalarReplacement::GetVariableLocation( + Instruction* var, 
uint32_t* location) { + return !context()->get_decoration_mgr()->WhileEachDecoration( + var->result_id(), uint32_t(spv::Decoration::Location), + [location](const Instruction& inst) { + *location = + inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex); + return false; + }); +} + +bool AdvancedInterfaceVariableScalarReplacement::GetVariableComponent( + Instruction* var, uint32_t* component) { + return !context()->get_decoration_mgr()->WhileEachDecoration( + var->result_id(), uint32_t(spv::Decoration::Component), + [component](const Instruction& inst) { + *component = + inst.GetSingleWordInOperand(kOpDecorateLiteralInOperandIndex); + return false; + }); +} + +std::vector +AdvancedInterfaceVariableScalarReplacement::CollectInterfaceVariables( + Instruction& entry_point) { + std::vector interface_vars; + for (uint32_t i = kOpEntryPointInOperandInterface; + i < entry_point.NumInOperands(); ++i) { + Instruction* interface_var = context()->get_def_use_mgr()->GetDef( + entry_point.GetSingleWordInOperand(i)); + assert(interface_var->opcode() == spv::Op::OpVariable); + + spv::StorageClass storage_class = GetStorageClass(interface_var); + if (storage_class != spv::StorageClass::Input && + storage_class != spv::StorageClass::Output) { + continue; + } + + interface_vars.push_back(interface_var); + } + return interface_vars; +} + +void AdvancedInterfaceVariableScalarReplacement::KillLocationAndComponentDecorations( + uint32_t var_id) { + context()->get_decoration_mgr()->RemoveDecorationsFrom( + var_id, [](const Instruction& inst) { + auto decoration = spv::Decoration( + inst.GetSingleWordInOperand(kOpDecorateDecorationInOperandIndex)); + return decoration == spv::Decoration::Location || + decoration == spv::Decoration::Component; + }); +} + +void AdvancedInterfaceVariableScalarReplacement:: + AddLocationAndComponentDecorations(const Replacement& vars, + uint32_t* location, + uint32_t* optional_component) { + if (!vars.HasChildren()) { + uint32_t var_id = 
vars.GetScalarVariable()->result_id(); + CreateDecoration(context()->get_decoration_mgr(), var_id, + spv::Decoration::Location, *location); + if (optional_component) { + CreateDecoration(context()->get_decoration_mgr(), var_id, + spv::Decoration::Component, *optional_component); + } + ++(*location); + return; + } + for (const auto& var : vars.GetChildren()) { + AddLocationAndComponentDecorations(var, location, optional_component); + } +} + +void AdvancedInterfaceVariableScalarReplacement::CloneAnnotationForVariable( + Instruction* annotation_inst, uint32_t var_id) { + assert(annotation_inst->opcode() == spv::Op::OpDecorate || + annotation_inst->opcode() == spv::Op::OpDecorateId || + annotation_inst->opcode() == spv::Op::OpDecorateString); + std::unique_ptr new_inst(annotation_inst->Clone(context())); + new_inst->SetInOperand(0, {var_id}); + context()->AddAnnotationInst(std::move(new_inst)); +} + +bool AdvancedInterfaceVariableScalarReplacement::HasExtraArrayness( + Instruction& entry_point, Instruction* var) { + auto execution_model = + static_cast(entry_point.GetSingleWordInOperand(0)); + if (execution_model != spv::ExecutionModel::TessellationEvaluation && + execution_model != spv::ExecutionModel::TessellationControl) { + return false; + } + if (!context()->get_decoration_mgr()->HasDecoration( + var->result_id(), uint32_t(spv::Decoration::Patch))) { + if (execution_model == spv::ExecutionModel::TessellationControl) + return true; + return GetStorageClass(var) != spv::StorageClass::Output; + } + return false; +} + +uint32_t AdvancedInterfaceVariableScalarReplacement::GetPointeeTypeIdOfVar( + Instruction* var) { + assert(var->opcode() == spv::Op::OpVariable); + + uint32_t ptr_type_id = var->type_id(); + analysis::DefUseManager* def_use_mgr = context()->get_def_use_mgr(); + Instruction* ptr_type_inst = def_use_mgr->GetDef(ptr_type_id); + + assert(ptr_type_inst->opcode() == spv::Op::OpTypePointer && + "Variable must have a pointer type."); + return 
ptr_type_inst->GetSingleWordInOperand(kOpTypePtrTypeInOperandIndex); +} + +uint32_t AdvancedInterfaceVariableScalarReplacement::GetArrayType( + uint32_t elem_type_id, uint32_t array_length) { + analysis::Type* elem_type = context()->get_type_mgr()->GetType(elem_type_id); + uint32_t array_length_id = + context()->get_constant_mgr()->GetUIntConstId(array_length); + analysis::Array array_type( + elem_type, + analysis::Array::LengthInfo{array_length_id, {0, array_length}}); + return context()->get_type_mgr()->GetTypeInstruction(&array_type); +} + +Instruction* AdvancedInterfaceVariableScalarReplacement::GetTypeOfVariable( + Instruction* var) { + assert(var->opcode() == spv::Op::OpVariable); + uint32_t pointee_type_id = GetPointeeTypeIdOfVar(var); + return context()->get_def_use_mgr()->GetDef(pointee_type_id); +} + +bool AdvancedInterfaceVariableScalarReplacement:: + ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var) { + if (vars_with_extra_arrayness.find(var) == vars_with_extra_arrayness.end()) + return false; + + std::string message( + "A variable is arrayed for an entry point but it is not " + "arrayed for another entry point"); + message += + "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); + return true; +} + +bool AdvancedInterfaceVariableScalarReplacement:: + ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var) { + if (vars_without_extra_arrayness.find(var) == + vars_without_extra_arrayness.end()) + return false; + + std::string message( + "A variable is not arrayed for an entry point but it is " + "arrayed for another entry point"); + message += + "\n " + var->PrettyPrint(SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES); + context()->consumer()(SPV_MSG_ERROR, "", {0, 0, 0}, message.c_str()); + return true; +} + +} // namespace opt +} // namespace spvtools diff --git a/source/opt/adv_interface_var_sroa.h b/source/opt/adv_interface_var_sroa.h new file mode 
100644 index 0000000000..4c421636b1 --- /dev/null +++ b/source/opt/adv_interface_var_sroa.h @@ -0,0 +1,240 @@ +// Copyright (c) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SOURCE_OPT_ADV_INTERFACE_VAR_SROA_H_ +#define SOURCE_OPT_ADV_INTERFACE_VAR_SROA_H_ + +#include + +#include "source/opt/pass.h" + +namespace spvtools { +namespace opt { + +// See optimizer.hpp for documentation. +// +// Note that there is another existing pass +// InterfaceVariableScalarReplacement, which doesn't handle tricky instruction +// chains and interface variables which are arrays of scalars. The plan is to +// replace that pass with this one. +class AdvancedInterfaceVariableScalarReplacement : public Pass { + public: + AdvancedInterfaceVariableScalarReplacement(bool process_matrices) + : process_matrices_(process_matrices) {} + + const char* name() const override { + return "adv-interface-variable-scalar-replacement"; + } + Status Process() override; + + IRContext::Analysis GetPreservedAnalyses() override { + return IRContext::kAnalysisDecorations | IRContext::kAnalysisDefUse | + IRContext::kAnalysisConstants | IRContext::kAnalysisTypes; + } + + private: + // A struct describing a single interface variable. + struct InterfaceVar { + // The corresponding OpVariable. + Instruction* def; + // The corresponding OpType*. + Instruction* type; + // If |extra_array_length| is not 0, it means that this interface variable + // has no Patch decoration.
This will add extra-arrayness to the replacing + // scalar variables. + uint32_t extra_array_length; + + InterfaceVar(Instruction* def, Instruction* type, + uint32_t extra_array_length) + : def(def), type(type), extra_array_length(extra_array_length) { + assert(def); + assert(type); + } + }; + + // A struct containing components of a composite interface variable. If the + // composite consists of multiple or recursive components, |scalar_var| is + // nullptr and |children| keeps the nested components. If it has a single + // component, |children| is empty and |scalar_var| is the component. Note that + // each element of |children| has the Replacement struct as its type that can + // recursively keep the components. + struct Replacement { + explicit Replacement(uint32_t type_id) + : scalar_var(nullptr), type_id(type_id) {} + + bool HasChildren() const { return !children.empty(); } + + std::vector& GetChildren() { + return children; + } + + const std::vector& GetChildren() const { + return children; + } + + Replacement& AppendChild(uint32_t child_type_id) { + assert(!scalar_var && "Can add children only for non-scalars."); + return children.emplace_back(child_type_id); + } + + Instruction* GetScalarVariable() const { return scalar_var; } + + void SetSingleScalarVariable(Instruction* var) { scalar_var = var; } + + uint32_t GetTypeId() const { return type_id; } + + private: + std::vector children; + Instruction* scalar_var; + uint32_t type_id; + }; + + // Collects all Input/Output interface variables used by the |entry_point|. + std::vector CollectInterfaceVariables(Instruction& entry_point); + + // Returns whether |var| has the extra arrayness for the entry point + // |entry_point| or not. + bool HasExtraArrayness(Instruction& entry_point, Instruction* var); + + // Finds a Location decoration of |var| and returns it via + // |location|. Returns true if the decoration exists, false otherwise.
+ bool GetVariableLocation(Instruction* var, uint32_t* location); + + // Finds a Component decoration of |var| and returns it via + // |component|. Returns true if the decoration exists, false otherwise. + bool GetVariableComponent(Instruction* var, uint32_t* component); + + // Returns the pointee type of |var| as an instruction. + Instruction* GetTypeOfVariable(Instruction* var); + + // Replaces an interface variable |var| with scalars and returns whether it + // succeeds or not. |location| is the value of Location Decoration for |var|. + // |all_scalar_vars| will be appended with the replacement scalar vars for + // |var|. + bool ReplaceInterfaceVariable(InterfaceVar var, uint32_t location, + std::vector* all_scalar_vars); + + // Creates scalar variables to replace an interface variable |var|. + // |scalar_vars| will be filled as a list of all replacement scalar variables. + // As |Replacement| represents a tree, shaped as the original interface + // variable, this list will contain every leaf from that tree, stored in the + // depth-first order. + Replacement CreateReplacementVariables( + InterfaceVar var, std::vector* scalar_vars); + + // Recursively adds Location and Component decorations to variables in + // |vars| with |location| and |optional_component|. Increases |location| by + // one after it actually adds Location and Component decorations for a + // variable. + void AddLocationAndComponentDecorations(const Replacement& vars, + uint32_t* location, + uint32_t* optional_component); + + // Clones an annotation instruction |annotation_inst| and sets the target + // operand of the new annotation instruction as |var_id|. + void CloneAnnotationForVariable(Instruction* annotation_inst, + uint32_t var_id); + + // Replaces all the interface variables, which will be replaced, in the + // operands of the entry point |entry_point| with a set of variables from the + // |scalar_vars|.
+ bool ReplaceInEntryPoint(Instruction* entry_point, + const std::unordered_set& inteface_vars, + const std::vector& scalar_vars); + + // Replaces the load instruction |load| of the original interface variable or + // its part with a load from each replacement scalar variable from + // |replacement| followed by a composite construction. If target load is only + // transitively dependent on the replaced interface var, then the + // corresponding access chain |optional_access_chain| will be passed. + bool ReplaceLoad(Instruction* load, const Replacement& replacement, + Instruction* optional_access_chain, + uint32_t extra_array_length); + + // Replaces the store instruction |store| of the original interface variable + // or its part with a series of composite extracts and stores using the + // replacement scalar variables from |replacement|. If target store is only + // transitively dependent on the replaced interface var, then the + // corresponding access chain |optional_access_chain| will be passed. + bool ReplaceStore(Instruction* store, const Replacement& replacement, + Instruction* optional_access_chain, + uint32_t extra_array_length); + + // Looks up the replacement node according to the indices from the access + // chain |access_chain|, using the passed |root| as a base. If any index in + // the chain is non-constant or out-of-bounds, returns nullptr. If + // |extra_array_length| is not zero, the first index in the chain is skipped, + // as it is the one used for extra arrayness. + const Replacement* LookupReplacement(Instruction* access_chain, + const Replacement* root, + uint32_t extra_array_length); + + // Creates a variable with type |type_id| and storage class |storage_class|. + // Debug info for the newly created variable is copied from the source + // |debug_info_source|.
+ std::unique_ptr CreateVariable( + uint32_t type_id, spv::StorageClass storage_class, + const Instruction* debug_info_source, uint32_t extra_array_length); + + // Generates OpName instructions for the variable |destination_id|, based on + // the name of source variable |source_id| and a list of indices |indices| to + // make a suffix. + void GenerateNames(uint32_t source_id, uint32_t destination_id, + const std::vector& indices, + std::vector>* names_to_add); + + // Returns the result id of the pointee type of variable |var|. + uint32_t GetPointeeTypeIdOfVar(Instruction* var); + + // Returns the result id of OpTypeArray instruction whose Element Type + // operand is |elem_type_id| and Length operand is |array_length|. + uint32_t GetArrayType(uint32_t elem_type_id, uint32_t array_length); + + // Kills all OpDecorate instructions for Location and Component of the + // variable whose id is |var_id|. + void KillLocationAndComponentDecorations(uint32_t var_id); + + // If |var| has the extra arrayness for an entry point, reports an error and + // returns true. Otherwise, returns false. + bool ReportErrorIfHasExtraArraynessForOtherEntry(Instruction* var); + + // If |var| does not have the extra arrayness for an entry point, reports an + // error and returns true. Otherwise, returns false. + bool ReportErrorIfHasNoExtraArraynessForOtherEntry(Instruction* var); + + // If |var| has the extra arrayness for an entry point but it does not have + // one for another entry point, reports an error and returns false. Otherwise, + // returns true. + bool CheckExtraArraynessConflictBetweenEntries(InterfaceVar var); + + // Conducts the scalar replacement for the interface variables used by the + // |entry_point|. + Pass::Status ProcessEntryPoint(Instruction& entry_point); + + // A set of interface variables with the extra arrayness for any of the entry + // points.
+ std::unordered_set vars_with_extra_arrayness; + + // A set of interface variables without the extra arrayness for any of the + // entry points. + std::unordered_set vars_without_extra_arrayness; + + // Whether we need to replace matrix interface variables with scalars or not. + bool process_matrices_; +}; + +} // namespace opt +} // namespace spvtools + +#endif // SOURCE_OPT_ADV_INTERFACE_VAR_SROA_H_ diff --git a/source/opt/optimizer.cpp b/source/opt/optimizer.cpp index becef59551..f7a64a7c7d 100644 --- a/source/opt/optimizer.cpp +++ b/source/opt/optimizer.cpp @@ -655,6 +655,20 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag, } } else if (pass_name == "trim-capabilities") { RegisterPass(CreateTrimCapabilitiesPass()); + // UE Change Begin: Interface variable scalar replacement pass rewrite + } else if (pass_name == "adv-interface-variable-scalar-replacement") { + bool process_matrices = true; + if (pass_args == "skip-matrices") { + process_matrices = false; + } else if (pass_args.size() != 0) { + Errorf(consumer(), nullptr, {}, + "Invalid argument for --adv-interface-variable-scalar-replacement: %s " + "(must be 'skip-matrices' or absent)", + pass_args.c_str()); + return false; + } + RegisterPass(CreateAdvancedInterfaceVariableScalarReplacementPass(process_matrices)); + // UE Change End: Interface variable scalar replacement pass rewrite } else { Errorf(consumer(), nullptr, {}, "Unknown flag '--%s'.
Use --help for a list of valid flags", @@ -1230,6 +1244,14 @@ Optimizer::PassToken CreateConvertCompositeToOpAccessChainPass() { } // UE Change End: Convert-Composite-To-Op-Access-Chain-Pass +// UE Change Begin: Interface variable scalar replacement pass rewrite +Optimizer::PassToken CreateAdvancedInterfaceVariableScalarReplacementPass( + bool process_matrices) { + return MakeUnique( + MakeUnique(process_matrices)); +} +// UE Change End: Interface variable scalar replacement pass rewrite + } // namespace spvtools extern "C" { diff --git a/source/opt/passes.h b/source/opt/passes.h index 092d13e4aa..02d1c20d46 100644 --- a/source/opt/passes.h +++ b/source/opt/passes.h @@ -17,6 +17,9 @@ // A single header to include all passes. +// UE Change Begin: Interface variable scalar replacement pass rewrite +#include "source/opt/adv_interface_var_sroa.h" +// UE Change End: Interface variable scalar replacement pass rewrite #include "source/opt/aggressive_dead_code_elim_pass.h" #include "source/opt/amd_ext_to_khr.h" #include "source/opt/analyze_live_input_pass.h" diff --git a/test/opt/CMakeLists.txt b/test/opt/CMakeLists.txt index 92d266bba3..c83d438bb8 100644 --- a/test/opt/CMakeLists.txt +++ b/test/opt/CMakeLists.txt @@ -16,7 +16,10 @@ add_subdirectory(dominator_tree) add_subdirectory(loop_optimizations) add_spvtools_unittest(TARGET opt - SRCS aggressive_dead_code_elim_test.cpp +# UE Change Begin: Interface variable scalar replacement pass rewrite + SRCS adv_interface_var_sroa_test.cpp +# UE Change End: Interface variable scalar replacement pass rewrite + aggressive_dead_code_elim_test.cpp amd_ext_to_khr.cpp analyze_live_input_test.cpp assembly_builder_test.cpp diff --git a/test/opt/adv_interface_var_sroa_test.cpp b/test/opt/adv_interface_var_sroa_test.cpp new file mode 100644 index 0000000000..97b4a807ad --- /dev/null +++ b/test/opt/adv_interface_var_sroa_test.cpp @@ -0,0 +1,548 @@ +// Copyright (c) 2022 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "test/opt/pass_fixture.h" +#include "test/opt/pass_utils.h" + +namespace spvtools { +namespace opt { +namespace { + +using AdvancedInterfaceVariableScalarReplacementTest = PassTest<::testing::Test>; + +TEST_F(AdvancedInterfaceVariableScalarReplacementTest, + ReplaceInterfaceVarsWithScalars) { + const std::string spirv = R"( + OpCapability Shader + OpCapability Tessellation + OpMemoryModel Logical GLSL450 + OpEntryPoint TessellationControl %func "shader" %x %y %z %w %u %v %gl_InvocationID + +; CHECK: OpName [[x:%\w+]] "x" +; CHECK-NOT: OpName {{%\w+}} "x" +; CHECK: OpName [[y:%\w+]] "y" +; CHECK-NOT: OpName {{%\w+}} "y" +; CHECK: OpName [[k:%\w+]] "k" +; CHECK-NOT: OpName {{%\w+}} "k" +; CHECK: OpName [[s:%\w+]] "s" +; CHECK-NOT: OpName {{%\w+}} "s" +; CHECK: OpName [[q:%\w+]] "q" +; CHECK-NOT: OpName {{%\w+}} "q" +; CHECK: OpName [[gl_InvocationID:%\w+]] "gl_InvocationID" +; CHECK-NOT: OpName {{%\w+}} "gl_InvocationID" +; CHECK: OpName [[z0:%\w+]] "z[0]" +; CHECK: OpName [[z1:%\w+]] "z[1]" +; CHECK: OpName [[w0:%\w+]] "w[0]" +; CHECK: OpName [[w1:%\w+]] "w[1]" +; CHECK: OpName [[u0:%\w+]] "u[0]" +; CHECK: OpName [[u1:%\w+]] "u[1]" +; CHECK: OpName [[v0:%\w+]] "v[0][0]" +; CHECK: OpName [[v1:%\w+]] "v[0][1]" +; CHECK: OpName [[v2:%\w+]] "v[1][0]" +; CHECK: OpName [[v3:%\w+]] "v[1][1]" +; CHECK: OpName [[v4:%\w+]] "v[2][0]" +; CHECK: OpName [[v5:%\w+]] "v[2][1]" + OpName %x "x" + OpName %y "y" + OpName %z "z" + OpName %w "w" + OpName %u "u" + 
OpName %v "v" + OpName %k "k" + OpName %s "s" + OpName %q "q" + OpName %gl_InvocationID "gl_InvocationID" + +; CHECK-DAG: OpDecorate [[x]] Location 2 +; CHECK-DAG: OpDecorate [[y]] Location 0 +; CHECK-DAG: OpDecorate [[gl_InvocationID]] BuiltIn InvocationId +; CHECK-DAG: OpDecorate [[z0]] Location 0 +; CHECK-DAG: OpDecorate [[z0]] Component 0 +; CHECK-DAG: OpDecorate [[z1]] Location 1 +; CHECK-DAG: OpDecorate [[z1]] Component 0 +; CHECK-DAG: OpDecorate [[z0]] Patch +; CHECK-DAG: OpDecorate [[z1]] Patch +; CHECK-DAG: OpDecorate [[w0]] Location 2 +; CHECK-DAG: OpDecorate [[w0]] Component 0 +; CHECK-DAG: OpDecorate [[w1]] Location 3 +; CHECK-DAG: OpDecorate [[w1]] Component 0 +; CHECK-DAG: OpDecorate [[w0]] Patch +; CHECK-DAG: OpDecorate [[w1]] Patch +; CHECK-DAG: OpDecorate [[u0]] Location 3 +; CHECK-DAG: OpDecorate [[u0]] Component 2 +; CHECK-DAG: OpDecorate [[u1]] Location 4 +; CHECK-DAG: OpDecorate [[u1]] Component 2 +; CHECK-DAG: OpDecorate [[v0]] Location 3 +; CHECK-DAG: OpDecorate [[v0]] Component 3 +; CHECK-DAG: OpDecorate [[v1]] Location 4 +; CHECK-DAG: OpDecorate [[v1]] Component 3 +; CHECK-DAG: OpDecorate [[v2]] Location 5 +; CHECK-DAG: OpDecorate [[v2]] Component 3 +; CHECK-DAG: OpDecorate [[v3]] Location 6 +; CHECK-DAG: OpDecorate [[v3]] Component 3 +; CHECK-DAG: OpDecorate [[v4]] Location 7 +; CHECK-DAG: OpDecorate [[v4]] Component 3 +; CHECK-DAG: OpDecorate [[v5]] Location 8 +; CHECK-DAG: OpDecorate [[v5]] Component 3 + OpDecorate %z Patch + OpDecorate %w Patch + OpDecorate %z Location 0 + OpDecorate %x Location 2 + OpDecorate %v Location 3 + OpDecorate %v Component 3 + OpDecorate %y Location 0 + OpDecorate %w Location 2 + OpDecorate %u Location 3 + OpDecorate %u Component 2 + OpDecorate %gl_InvocationID BuiltIn InvocationId + + %int = OpTypeInt 32 1 + %uint = OpTypeInt 32 0 + %uint_1 = OpConstant %uint 1 + %uint_2 = OpConstant %uint 2 + %uint_3 = OpConstant %uint 3 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_2 = OpTypeArray %uint %uint_2 
+%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2 +%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2 +%_ptr_Input_int = OpTypePointer Input %int +%_ptr_Input_uint = OpTypePointer Input %uint +%_ptr_Output_uint = OpTypePointer Output %uint +%_arr_arr_uint_uint_2_3 = OpTypeArray %_arr_uint_uint_2 %uint_3 +%_ptr_Input__arr_arr_uint_uint_2_3 = OpTypePointer Input %_arr_arr_uint_uint_2_3 +%_arr_arr_arr_uint_uint_2_3_4 = OpTypeArray %_arr_arr_uint_uint_2_3 %uint_4 +%_ptr_Output__arr_arr_arr_uint_uint_2_3_4 = OpTypePointer Output %_arr_arr_arr_uint_uint_2_3_4 +%_ptr_Output__arr_arr_uint_uint_2_3 = OpTypePointer Output %_arr_arr_uint_uint_2_3 +%_ptr_Function__arr__arr__arr_uint_uint_2_uint_3_uint_4 = OpTypePointer Function %_arr_arr_arr_uint_uint_2_3_4 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Function__arr_uint_uint_2 = OpTypePointer Function %_arr_uint_uint_2 + + %gl_InvocationID = OpVariable %_ptr_Input_int Input + %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output + %x = OpVariable %_ptr_Output__arr_uint_uint_2 Output + %y = OpVariable %_ptr_Input__arr_uint_uint_2 Input + %w = OpVariable %_ptr_Input__arr_uint_uint_2 Input + %u = OpVariable %_ptr_Input__arr_arr_uint_uint_2_3 Input + %v = OpVariable %_ptr_Output__arr_arr_arr_uint_uint_2_3_4 Output + +; CHECK-DAG: [[x]] = OpVariable %_ptr_Output__arr_uint_uint_2 Output +; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input +; CHECK-DAG: [[gl_InvocationID]] = OpVariable %_ptr_Input_int Input +; CHECK-DAG: [[z0]] = OpVariable %_ptr_Output_uint Output +; CHECK-DAG: [[z1]] = OpVariable %_ptr_Output_uint Output +; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input +; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input +; CHECK-DAG: [[u0]] = OpVariable %_ptr_Input__arr_uint_uint_3 Input +; CHECK-DAG: [[u1]] = OpVariable %_ptr_Input__arr_uint_uint_3 Input +; CHECK-DAG: [[v0]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output +; CHECK-DAG: [[v1]] = 
OpVariable %_ptr_Output__arr_uint_uint_4 Output +; CHECK-DAG: [[v2]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output +; CHECK-DAG: [[v3]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output +; CHECK-DAG: [[v4]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output +; CHECK-DAG: [[v5]] = OpVariable %_ptr_Output__arr_uint_uint_4 Output + + %void = OpTypeVoid + %void_f = OpTypeFunction %void + %func = OpFunction %void None %void_f + %label = OpLabel + + %k = OpVariable %_ptr_Function__arr__arr__arr_uint_uint_2_uint_3_uint_4 Function + %s = OpVariable %_ptr_Function_uint Function + %q = OpVariable %_ptr_Function__arr_uint_uint_2 Function +; CHECK-DAG: [[k]] = OpVariable %_ptr_Function__arr__arr__arr_uint_uint_2_uint_3_uint_4 Function +; CHECK-DAG: [[s]] = OpVariable %_ptr_Function_uint Function +; CHECK-DAG: [[q]] = OpVariable %_ptr_Function__arr_uint_uint_2 Function + +; CHECK: [[w0_value:%\w+]] = OpLoad %uint [[w0]] +; CHECK: [[w1_value:%\w+]] = OpLoad %uint [[w1]] +; CHECK: [[w_value:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[w0_value]] [[w1_value]] +; CHECK: [[w0:%\w+]] = OpCompositeExtract %uint [[w_value]] 0 +; CHECK: OpStore [[z0]] [[w0]] +; CHECK: [[w1:%\w+]] = OpCompositeExtract %uint [[w_value]] 1 +; CHECK: OpStore [[z1]] [[w1]] + %w_value = OpLoad %_arr_uint_uint_2 %w + OpStore %z %w_value + +; CHECK: [[u00_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_0 +; CHECK: [[u00:%\w+]] = OpLoad %uint [[u00_ptr]] +; CHECK: [[u10_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_0 +; CHECK: [[u10:%\w+]] = OpLoad %uint [[u10_ptr]] +; CHECK-DAG: [[u0_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u00]] [[u10]] +; CHECK: [[u01_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_1 +; CHECK: [[u01:%\w+]] = OpLoad %uint [[u01_ptr]] +; CHECK: [[u11_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_1 +; CHECK: [[u11:%\w+]] = OpLoad %uint [[u11_ptr]] +; CHECK-DAG: [[u1_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u01]] [[u11]] +; 
CHECK: [[u02_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u0]] %uint_2 +; CHECK: [[u02:%\w+]] = OpLoad %uint [[u02_ptr]] +; CHECK: [[u12_ptr:%\w+]] = OpAccessChain %_ptr_Input_uint [[u1]] %uint_2 +; CHECK: [[u12:%\w+]] = OpLoad %uint [[u12_ptr]] +; CHECK-DAG: [[u2_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[u02]] [[u12]] + +; CHECK: [[u_val:%\w+]] = OpCompositeConstruct %_arr__arr_uint_uint_2_uint_3 [[u0_val]] [[u1_val]] [[u2_val]] + +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 0 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v0]] %uint_1 +; CHECK: OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 0 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v1]] %uint_1 +; CHECK: OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 1 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v2]] %uint_1 +; CHECK: OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 1 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v3]] %uint_1 +; CHECK: OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 2 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] %uint_1 +; CHECK: OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[u_val]] 2 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_1 +; CHECK: OpStore [[ptr]] [[val]] + %v_ptr = OpAccessChain %_ptr_Output__arr_arr_uint_uint_2_3 %v %uint_1 + %u_val = OpLoad %_arr_arr_uint_uint_2_3 %u + OpStore %v_ptr %u_val + +; CHECK: [[k_val:%\w+]] = OpLoad %_arr__arr__arr_uint_uint_2_uint_3_uint_4 %k +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 0 0 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v0]] %uint_0 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 1 0 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v0]] %uint_1 
+; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 2 0 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v0]] %uint_2 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 3 0 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v0]] %uint_3 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 0 0 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v1]] %uint_0 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 1 0 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v1]] %uint_1 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 2 0 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v1]] %uint_2 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 3 0 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v1]] %uint_3 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 0 1 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v2]] %uint_0 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 1 1 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v2]] %uint_1 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 2 1 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v2]] %uint_2 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 3 1 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v2]] %uint_3 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 0 1 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v3]] %uint_0 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 1 1 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain 
%_ptr_Output_uint [[v3]] %uint_1 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 2 1 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v3]] %uint_2 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 3 1 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v3]] %uint_3 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 0 2 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] %uint_0 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 1 2 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] %uint_1 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 2 2 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] %uint_2 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 3 2 0 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] %uint_3 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 0 2 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_0 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 1 2 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_1 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 2 2 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_2 +; CHECK OpStore [[ptr]] [[val]] +; CHECK: [[val:%\w+]] = OpCompositeExtract %uint [[k_val]] 3 2 1 +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_3 +; CHECK OpStore [[ptr]] [[val]] + %k_val = OpLoad %_arr_arr_arr_uint_uint_2_3_4 %k + OpStore %v %k_val + +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_3 +; CHECK: [[val:%\w+]] = OpLoad %uint [[ptr]] +; CHECK: OpStore %s [[val]] + 
%v213_ptr = OpAccessChain %_ptr_Output_uint %v %uint_3 %uint_2 %uint_1 + %v213_val = OpLoad %uint %v213_ptr + OpStore %s %v213_val + +; CHECK: [[v320_ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] %uint_3 +; CHECK: [[v320_val:%\w+]] = OpLoad %uint [[v320_ptr]] +; CHECK: [[v321_ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] %uint_3 +; CHECK: [[v321_val:%\w+]] = OpLoad %uint [[v321_ptr]] +; CHECK: [[v32_val:%\w+]] = OpCompositeConstruct %_arr_uint_uint_2 [[v320_val]] [[v321_val]] +; CHECK: OpStore %q [[v32_val]] + %v32_ptr = OpAccessChain %_ptr_Output__arr_uint_uint_2 %v %uint_3 %uint_2 + %v32_val = OpLoad %_arr_uint_uint_2 %v32_ptr + OpStore %q %v32_val + +; CHECK: [[id:%\w+]] = OpLoad %int %gl_InvocationID +; CHECK: [[s_val:%\w+]] = OpLoad %uint %s +; CHECK: [[ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] [[id]] +; CHECK: OpStore [[ptr]] [[s_val]] + %id = OpLoad %int %gl_InvocationID + %s_val = OpLoad %uint %s + %vi21_ptr = OpAccessChain %_ptr_Output_uint %v %id %uint_2 %uint_1 + OpStore %vi21_ptr %s_val + +; CHECK: [[q_val:%\w+]] = OpLoad %_arr_uint_uint_2 %q +; CHECK: [[q0:%\w+]] = OpCompositeExtract %uint [[q_val]] 0 +; CHECK: [[vi20_ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v4]] [[id]] +; CHECK: OpStore [[vi20_ptr]] [[q0]] +; CHECK: [[q1:%\w+]] = OpCompositeExtract %uint [[q_val]] 1 +; CHECK: [[vi21_ptr:%\w+]] = OpAccessChain %_ptr_Output_uint [[v5]] [[id]] +; CHECK: OpStore [[vi21_ptr]] [[q1]] + %q_val = OpLoad %_arr_uint_uint_2 %q + %vi2_ptr = OpAccessChain %_ptr_Output__arr_uint_uint_2 %v %id %uint_2 + OpStore %vi2_ptr %q_val + + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spirv, true, true); +} + +TEST_F(AdvancedInterfaceVariableScalarReplacementTest, + CheckPatchDecorationPreservation) { + // Make sure scalars for the variables with the extra arrayness have the extra + // arrayness after running the pass while others do not have it. + // Only "y" does not have the extra arrayness in the following SPIR-V. 
+ const std::string spirv = R"( + OpCapability Shader + OpCapability Tessellation + OpMemoryModel Logical GLSL450 + OpEntryPoint TessellationEvaluation %func "shader" %x %y %z %w + OpDecorate %z Patch + OpDecorate %w Patch + OpDecorate %z Location 0 + OpDecorate %x Location 2 + OpDecorate %y Location 0 + OpDecorate %w Location 1 + OpName %x "x" + OpName %y "y" + OpName %z "z" + OpName %w "w" + + ; CHECK: OpName [[y:%\w+]] "y" + ; CHECK-NOT: OpName {{%\w+}} "y" + ; CHECK-DAG: OpName [[z0:%\w+]] "z[0]" + ; CHECK-DAG: OpName [[z1:%\w+]] "z[1]" + ; CHECK-DAG: OpName [[w0:%\w+]] "w[0]" + ; CHECK-DAG: OpName [[w1:%\w+]] "w[1]" + ; CHECK-DAG: OpName [[x0:%\w+]] "x[0]" + ; CHECK-DAG: OpName [[x1:%\w+]] "x[1]" + + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_uint_uint_2 = OpTypeArray %uint %uint_2 +%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2 +%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2 + %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output + %x = OpVariable %_ptr_Output__arr_uint_uint_2 Output + %y = OpVariable %_ptr_Input__arr_uint_uint_2 Input + %w = OpVariable %_ptr_Input__arr_uint_uint_2 Input + + ; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input + ; CHECK-DAG: [[z0]] = OpVariable %_ptr_Output_uint Output + ; CHECK-DAG: [[z1]] = OpVariable %_ptr_Output_uint Output + ; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input + ; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input + ; CHECK-DAG: [[x0]] = OpVariable %_ptr_Output_uint Output + ; CHECK-DAG: [[x1]] = OpVariable %_ptr_Output_uint Output + + %void = OpTypeVoid + %void_f = OpTypeFunction %void + %func = OpFunction %void None %void_f + %label = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spirv, true, true); +} + +TEST_F(AdvancedInterfaceVariableScalarReplacementTest, + CheckEntryPointInterfaceOperands) { + const std::string spirv = R"( + OpCapability Shader + OpCapability Tessellation + OpMemoryModel Logical 
GLSL450 + OpEntryPoint TessellationEvaluation %tess "tess" %x %y + OpEntryPoint Vertex %vert "vert" %w + OpDecorate %z Location 0 + OpDecorate %x Location 2 + OpDecorate %y Location 0 + OpDecorate %w Location 1 + OpName %x "x" + OpName %y "y" + OpName %z "z" + OpName %w "w" + + ; CHECK: OpName [[y:%\w+]] "y" + ; CHECK-DAG: OpName [[z:%\w+]] "z" + ; CHECK-NOT: OpName {{%\w+}} "z" + ; CHECK-NOT: OpName {{%\w+}} "y" + ; CHECK-DAG: OpName [[x0:%\w+]] "x[0]" + ; CHECK-DAG: OpName [[x1:%\w+]] "x[1]" + ; CHECK-DAG: OpName [[w0:%\w+]] "w[0]" + ; CHECK-DAG: OpName [[w1:%\w+]] "w[1]" + + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_uint_uint_2 = OpTypeArray %uint %uint_2 +%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2 +%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2 + %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output + %x = OpVariable %_ptr_Output__arr_uint_uint_2 Output + %y = OpVariable %_ptr_Input__arr_uint_uint_2 Input + %w = OpVariable %_ptr_Input__arr_uint_uint_2 Input + + ; CHECK-DAG: [[y]] = OpVariable %_ptr_Input__arr_uint_uint_2 Input + ; CHECK-DAG: [[z]] = OpVariable %_ptr_Output__arr_uint_uint_2 Output + ; CHECK-DAG: [[w0]] = OpVariable %_ptr_Input_uint Input + ; CHECK-DAG: [[w1]] = OpVariable %_ptr_Input_uint Input + ; CHECK-DAG: [[x0]] = OpVariable %_ptr_Output_uint Output + ; CHECK-DAG: [[x1]] = OpVariable %_ptr_Output_uint Output + + %void = OpTypeVoid + %void_f = OpTypeFunction %void + %tess = OpFunction %void None %void_f + %bb0 = OpLabel + OpReturn + OpFunctionEnd + %vert = OpFunction %void None %void_f + %bb1 = OpLabel + OpReturn + OpFunctionEnd + )"; + + SinglePassRunAndMatch(spirv, true, true); +} + +class InterfaceVarSROAErrorTest : public PassTest<::testing::Test> { + public: + InterfaceVarSROAErrorTest() + : consumer_([this](spv_message_level_t level, const char*, + const spv_position_t& position, const char* message) { + if (!error_message_.empty()) error_message_ += "\n"; + switch 
(level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + error_message_ += "ERROR"; + break; + case SPV_MSG_WARNING: + error_message_ += "WARNING"; + break; + case SPV_MSG_INFO: + error_message_ += "INFO"; + break; + case SPV_MSG_DEBUG: + error_message_ += "DEBUG"; + break; + } + error_message_ += + ": " + std::to_string(position.index) + ": " + message; + }) {} + + Pass::Status RunPass(const std::string& text) { + std::unique_ptr context_ = + spvtools::BuildModule(SPV_ENV_UNIVERSAL_1_2, consumer_, text); + if (!context_.get()) return Pass::Status::Failure; + + PassManager manager; + manager.SetMessageConsumer(consumer_); + manager.AddPass(true); + + return manager.Run(context_.get()); + } + + std::string GetErrorMessage() const { return error_message_; } + + void TearDown() override { error_message_.clear(); } + + private: + spvtools::MessageConsumer consumer_; + std::string error_message_; +}; + +TEST_F(InterfaceVarSROAErrorTest, CheckConflictOfExtraArraynessBetweenEntries) { + const std::string spirv = R"( + OpCapability Shader + OpCapability Tessellation + OpMemoryModel Logical GLSL450 + OpEntryPoint TessellationControl %tess "tess" %x %y %z + OpEntryPoint Vertex %vert "vert" %z %w + OpDecorate %z Location 0 + OpDecorate %x Location 2 + OpDecorate %y Location 0 + OpDecorate %w Location 1 + OpName %x "x" + OpName %y "y" + OpName %z "z" + OpName %w "w" + %uint = OpTypeInt 32 0 + %uint_2 = OpConstant %uint 2 +%_arr_uint_uint_2 = OpTypeArray %uint %uint_2 +%_ptr_Output__arr_uint_uint_2 = OpTypePointer Output %_arr_uint_uint_2 +%_ptr_Input__arr_uint_uint_2 = OpTypePointer Input %_arr_uint_uint_2 + %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output + %x = OpVariable %_ptr_Output__arr_uint_uint_2 Output + %y = OpVariable %_ptr_Input__arr_uint_uint_2 Input + %w = OpVariable %_ptr_Input__arr_uint_uint_2 Input + %void = OpTypeVoid + %void_f = OpTypeFunction %void + %tess = OpFunction %void None %void_f + %bb0 = OpLabel + OpReturn + 
OpFunctionEnd + %vert = OpFunction %void None %void_f + %bb1 = OpLabel + OpReturn + OpFunctionEnd + )"; + + EXPECT_EQ(RunPass(spirv), Pass::Status::Failure); + const char expected_error[] = + "ERROR: 0: A variable is arrayed for an entry point but it is not " + "arrayed for another entry point\n" + " %z = OpVariable %_ptr_Output__arr_uint_uint_2 Output"; + EXPECT_STREQ(GetErrorMessage().c_str(), expected_error); +} + +} // namespace +} // namespace opt +} // namespace spvtools