Skip to content

Commit

Permalink
spirv-link: allow linking functions with different pointer arguments (#…
Browse files Browse the repository at this point in the history
…5534)

* linker: run dedup earlier

Otherwise `linkings_to_do` might end up with stale IDs.

* linker: allow linking functions with different pointer arguments

Since llvm-17 there are no typed pointers and hte SPIRV-LLVM-Translator
doesn't know the function signature of imported functions.

I'm investigating different ways of solving this problem and adding an
option to work around it inside `spirv-link` is one of those.

The code is almost complete, just I'm having troubles constructing the
bitcast to cast the pointer parameters to the final type.

Closes: KhronosGroup/SPIRV-LLVM-Translator#2153

* test/linker: add tests to test the AllowPtrTypeMismatch feature
  • Loading branch information
karolherbst authored Jul 24, 2024
1 parent ca37349 commit e99a5c0
Show file tree
Hide file tree
Showing 4 changed files with 498 additions and 49 deletions.
7 changes: 6 additions & 1 deletion include/spirv-tools/linker.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#define INCLUDE_SPIRV_TOOLS_LINKER_HPP_

#include <cstdint>

#include <memory>
#include <vector>

Expand Down Expand Up @@ -63,11 +62,17 @@ class SPIRV_TOOLS_EXPORT LinkerOptions {
use_highest_version_ = use_highest_vers;
}

bool GetAllowPtrTypeMismatch() const { return allow_ptr_type_mismatch_; }
void SetAllowPtrTypeMismatch(bool allow_ptr_type_mismatch) {
allow_ptr_type_mismatch_ = allow_ptr_type_mismatch;
}

private:
bool create_library_{false};
bool verify_ids_{false};
bool allow_partial_linkage_{false};
bool use_highest_version_{false};
bool allow_ptr_type_mismatch_{false};
};

// Links one or more SPIR-V modules into a new SPIR-V module. That is, combine
Expand Down
123 changes: 107 additions & 16 deletions source/link/linker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "source/opt/build_module.h"
#include "source/opt/compact_ids_pass.h"
#include "source/opt/decoration_manager.h"
#include "source/opt/ir_builder.h"
#include "source/opt/ir_loader.h"
#include "source/opt/pass_manager.h"
#include "source/opt/remove_duplicates_pass.h"
Expand All @@ -46,12 +47,14 @@ namespace spvtools {
namespace {

using opt::Instruction;
using opt::InstructionBuilder;
using opt::IRContext;
using opt::Module;
using opt::PassManager;
using opt::RemoveDuplicatesPass;
using opt::analysis::DecorationManager;
using opt::analysis::DefUseManager;
using opt::analysis::Function;
using opt::analysis::Type;
using opt::analysis::TypeManager;

Expand Down Expand Up @@ -126,6 +129,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer,
// checked.
spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
const LinkageTable& linkings_to_do,
bool allow_ptr_type_mismatch,
opt::IRContext* context);

// Remove linkage specific instructions, such as prototypes of imported
Expand Down Expand Up @@ -502,6 +506,7 @@ spv_result_t GetImportExportPairs(const MessageConsumer& consumer,

spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
const LinkageTable& linkings_to_do,
bool allow_ptr_type_mismatch,
opt::IRContext* context) {
spv_position_t position = {};

Expand All @@ -513,14 +518,42 @@ spv_result_t CheckImportExportCompatibility(const MessageConsumer& consumer,
type_manager.GetType(linking_entry.imported_symbol.type_id);
Type* exported_symbol_type =
type_manager.GetType(linking_entry.exported_symbol.type_id);
if (!(*imported_symbol_type == *exported_symbol_type))
if (!(*imported_symbol_type == *exported_symbol_type)) {
Function* imported_symbol_type_func = imported_symbol_type->AsFunction();
Function* exported_symbol_type_func = exported_symbol_type->AsFunction();

if (imported_symbol_type_func && exported_symbol_type_func) {
const auto& imported_params = imported_symbol_type_func->param_types();
const auto& exported_params = exported_symbol_type_func->param_types();
// allow_ptr_type_mismatch allows linking functions where the pointer
// type of arguments doesn't match. Everything else still needs to be
// equal. This is to workaround LLVM-17+ not having typed pointers and
// generated SPIR-Vs not knowing the actual pointer types in some cases.
if (allow_ptr_type_mismatch &&
imported_params.size() == exported_params.size()) {
bool correct = true;
for (size_t i = 0; i < imported_params.size(); i++) {
const auto& imported_param = imported_params[i];
const auto& exported_param = exported_params[i];

if (!imported_param->IsSame(exported_param) &&
(imported_param->kind() != Type::kPointer ||
exported_param->kind() != Type::kPointer)) {
correct = false;
break;
}
}
if (correct) continue;
}
}
return DiagnosticStream(position, consumer, "", SPV_ERROR_INVALID_BINARY)
<< "Type mismatch on symbol \""
<< linking_entry.imported_symbol.name
<< "\" between imported variable/function %"
<< linking_entry.imported_symbol.id
<< " and exported variable/function %"
<< linking_entry.exported_symbol.id << ".";
}
}

// Ensure the import and export decorations are similar
Expand Down Expand Up @@ -696,6 +729,57 @@ spv_result_t VerifyLimits(const MessageConsumer& consumer,
return SPV_SUCCESS;
}

spv_result_t FixFunctionCallTypes(opt::IRContext& context,
const LinkageTable& linkings) {
auto mod = context.module();
const auto type_manager = context.get_type_mgr();
const auto def_use_mgr = context.get_def_use_mgr();

for (auto& func : *mod) {
func.ForEachInst([&](Instruction* inst) {
if (inst->opcode() != spv::Op::OpFunctionCall) return;
opt::Operand& target = inst->GetInOperand(0);

// only fix calls to imported functions
auto linking = std::find_if(
linkings.begin(), linkings.end(), [&](const auto& entry) {
return entry.exported_symbol.id == target.AsId();
});
if (linking == linkings.end()) return;

auto builder = InstructionBuilder(&context, inst);
for (uint32_t i = 1; i < inst->NumInOperands(); ++i) {
auto exported_func_param =
def_use_mgr->GetDef(linking->exported_symbol.parameter_ids[i - 1]);
const Type* target_type =
type_manager->GetType(exported_func_param->type_id());
if (target_type->kind() != Type::kPointer) continue;

opt::Operand& arg = inst->GetInOperand(i);
const Type* param_type =
type_manager->GetType(def_use_mgr->GetDef(arg.AsId())->type_id());

// No need to cast if it already matches
if (*param_type == *target_type) continue;

auto new_id = context.TakeNextId();

// cast to the expected pointer type
builder.AddInstruction(MakeUnique<opt::Instruction>(
&context, spv::Op::OpBitcast, exported_func_param->type_id(),
new_id,
opt::Instruction::OperandList(
{{SPV_OPERAND_TYPE_ID, {arg.AsId()}}})));

inst->SetInOperand(i, {new_id});
}
});
}
context.InvalidateAnalyses(opt::IRContext::kAnalysisDefUse |
opt::IRContext::kAnalysisInstrToBlockMapping);
return SPV_SUCCESS;
}

} // namespace

spv_result_t Link(const Context& context,
Expand Down Expand Up @@ -773,26 +857,27 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
if (res != SPV_SUCCESS) return res;
}

// Phase 4: Find the import/export pairs
// Phase 4: Remove duplicates
PassManager manager;
manager.SetMessageConsumer(consumer);
manager.AddPass<RemoveDuplicatesPass>();
opt::Pass::Status pass_res = manager.Run(&linked_context);
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;

// Phase 5: Find the import/export pairs
LinkageTable linkings_to_do;
res = GetImportExportPairs(consumer, linked_context,
*linked_context.get_def_use_mgr(),
*linked_context.get_decoration_mgr(),
options.GetAllowPartialLinkage(), &linkings_to_do);
if (res != SPV_SUCCESS) return res;

// Phase 5: Ensure the import and export have the same types and decorations.
res =
CheckImportExportCompatibility(consumer, linkings_to_do, &linked_context);
// Phase 6: Ensure the import and export have the same types and decorations.
res = CheckImportExportCompatibility(consumer, linkings_to_do,
options.GetAllowPtrTypeMismatch(),
&linked_context);
if (res != SPV_SUCCESS) return res;

// Phase 6: Remove duplicates
PassManager manager;
manager.SetMessageConsumer(consumer);
manager.AddPass<RemoveDuplicatesPass>();
opt::Pass::Status pass_res = manager.Run(&linked_context);
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;

// Phase 7: Remove all names and decorations of import variables/functions
for (const auto& linking_entry : linkings_to_do) {
linked_context.KillNamesAndDecorates(linking_entry.imported_symbol.id);
Expand All @@ -815,21 +900,27 @@ spv_result_t Link(const Context& context, const uint32_t* const* binaries,
&linked_context);
if (res != SPV_SUCCESS) return res;

// Phase 10: Compact the IDs used in the module
// Phase 10: Optionally fix function call types
if (options.GetAllowPtrTypeMismatch()) {
res = FixFunctionCallTypes(linked_context, linkings_to_do);
if (res != SPV_SUCCESS) return res;
}

// Phase 11: Compact the IDs used in the module
manager.AddPass<opt::CompactIdsPass>();
pass_res = manager.Run(&linked_context);
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;

// Phase 11: Recompute EntryPoint variables
// Phase 12: Recompute EntryPoint variables
manager.AddPass<opt::RemoveUnusedInterfaceVariablesPass>();
pass_res = manager.Run(&linked_context);
if (pass_res == opt::Pass::Status::Failure) return SPV_ERROR_INVALID_DATA;

// Phase 12: Warn if SPIR-V limits were exceeded
// Phase 13: Warn if SPIR-V limits were exceeded
res = VerifyLimits(consumer, linked_context);
if (res != SPV_SUCCESS) return res;

// Phase 13: Output the module
// Phase 14: Output the module
linked_context.module()->ToBinary(linked_binary, true);

return SPV_SUCCESS;
Expand Down
Loading

0 comments on commit e99a5c0

Please sign in to comment.