Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Visitor to use CallInst #115

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions example/ExampleMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,12 @@ template <bool rpot> const Visitor<VisitorContainer> &getExampleVisitor() {
b.add<ReturnInst>([](raw_ostream &out, ReturnInst &ret) {
out << "visiting ReturnInst: " << ret << '\n';
});
b.add<CallInst>([](raw_ostream &out, CallInst &ret) {
out << "visiting CallInst: " << ret << '\n';
});
b.add<CallBrInst>([](raw_ostream &out, CallBrInst &ret) {
out << "visiting CallBrInst: " << ret << '\n';
});
b.addIntrinsic(
Intrinsic::umax, [](raw_ostream &out, IntrinsicInst &umax) {
out << "visiting umax intrinsic: " << umax << '\n';
Expand Down
31 changes: 11 additions & 20 deletions include/llvm-dialects/Dialect/OpMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -595,34 +595,26 @@ template <typename ValueT, bool isConst> class OpMapIteratorBase final {
if (std::get<BaseIteratorT>(m_iterator) == map->m_intrinsics.end())
invalidate();
}
} else {
createFromDialectOp(desc.getMnemonic());
} else if (!createFromDialectOp(desc.getMnemonic())) {
invalidate();
}
}

OpMapIteratorBase(OpMapT *map, const llvm::Function &func) : m_map{map} {
createFromFunc(func);
if (!createFromFunc(func))
invalidate();
}

// Do a lookup for a given instruction. Mark the iterator as invalid
// if the instruction is a call-like core instruction.
// Do a lookup for a given instruction.
OpMapIteratorBase(OpMapT *map, const llvm::Instruction &inst) : m_map{map} {
if (auto *CI = llvm::dyn_cast<llvm::CallInst>(&inst)) {
TylerNowicki marked this conversation as resolved.
Show resolved Hide resolved
const llvm::Function *callee = CI->getCalledFunction();
if (callee) {
createFromFunc(*callee);
if (callee && createFromFunc(*callee))
return;
}
}

const unsigned op = inst.getOpcode();

// Construct an invalid iterator.
if (op == llvm::Instruction::Call || op == llvm::Instruction::CallBr) {
invalidate();
return;
}

BaseIteratorT it = m_map->m_coreOpcodes.find(op);
if (it != m_map->m_coreOpcodes.end()) {
m_desc = OpDescription::fromCoreOp(op);
Expand Down Expand Up @@ -699,20 +691,20 @@ template <typename ValueT, bool isConst> class OpMapIteratorBase final {
private:
void invalidate() { m_isInvalid = true; }

void createFromFunc(const llvm::Function &func) {
bool createFromFunc(const llvm::Function &func) {
if (func.isIntrinsic()) {
m_iterator = m_map->m_intrinsics.find(func.getIntrinsicID());

if (std::get<BaseIteratorT>(m_iterator) != m_map->m_intrinsics.end()) {
m_desc = OpDescription::fromIntrinsic(func.getIntrinsicID());
return;
return true;
}
}

createFromDialectOp(func.getName());
return createFromDialectOp(func.getName());
}

void createFromDialectOp(llvm::StringRef funcName) {
bool createFromDialectOp(llvm::StringRef funcName) {
size_t idx = 0;
bool found = false;
for (auto &dialectOpKV : m_map->m_dialectOps) {
Expand All @@ -729,8 +721,7 @@ template <typename ValueT, bool isConst> class OpMapIteratorBase final {
++idx;
}

if (!found)
invalidate();
return found;
}

// Re-construct base OpDescription from the stored iterator.
Expand Down
16 changes: 16 additions & 0 deletions test/example/visitor-basic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
; DEFAULT-NEXT: %q =
; DEFAULT-NEXT: visiting umin (set): %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q)
; DEFAULT-NEXT: visiting StringAttrOp: Hello world!
; DEFAULT-NEXT: visiting CallInst: %0 = call i32 @op.func(i32 %v1, i32 %q)
; DEFAULT-NEXT: visiting CallBrInst: callbr void @callee()
; DEFAULT-NEXT: to label %continueBB [label %label1BB, label %label2BB]
; DEFAULT-NEXT: visiting Ret (set): ret void
; DEFAULT-NEXT: visiting ReturnInst: ret void
; DEFAULT-NEXT: inner.counter = 1
Expand All @@ -40,9 +43,21 @@ entry:
call void (...) @xd.ir.write.vararg(i8 %t, i32 %v2, i32 %q)
%vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q)
call void @xd.ir.string.attr.op(ptr @0)
call i32 @op.func(i32 %v1, i32 %q)
callbr void @callee() to label %continueBB [label %label1BB, label %label2BB]
ret void

continueBB:
br label %entry

label1BB:
br label %entry

label2BB:
br label %entry
}

declare void @callee()
declare i32 @xd.ir.read__i32()
declare i1 @xd.ir.set.read__i1()
declare i32 @xd.ir.set.read__i32()
Expand All @@ -53,3 +68,4 @@ declare i8 @xd.ir.itrunc__i8(...)
declare void @xd.ir.string.attr.op(ptr)
declare i32 @llvm.umax.i32(i32, i32)
declare i32 @llvm.umin.i32(i32, i32)
declare void @op.func(i32, i32)
69 changes: 66 additions & 3 deletions test/unit/interface/OpMapIRTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,13 @@ TEST_F(OpMapIRTestFixture, IntrinsicOpMatchesInstructionTest) {
EXPECT_EQ(map[AssumeDesc], "assume");

const auto &SideEffect = *B.CreateCall(
Intrinsic::getDeclaration(Mod.get(), Intrinsic::sideeffect));
Intrinsic::getOrInsertDeclaration(Mod.get(), Intrinsic::sideeffect));

const std::array<Value *, 1> AssumeArgs = {
ConstantInt::getBool(Type::getInt1Ty(Context), true)};
const auto &Assume = *B.CreateCall(
Intrinsic::getDeclaration(Mod.get(), Intrinsic::assume), AssumeArgs);
Intrinsic::getOrInsertDeclaration(Mod.get(), Intrinsic::assume),
AssumeArgs);

EXPECT_FALSE(map.lookup(SideEffect) == map.lookup(Assume));
EXPECT_EQ(map.lookup(SideEffect), "sideeffect");
Expand Down Expand Up @@ -171,7 +172,7 @@ TEST_F(OpMapIRTestFixture, MixedOpMatchesInstructionTest) {
EXPECT_EQ(map[SideEffectDesc], "sideeffect");

const auto &SideEffect = *B.CreateCall(
Intrinsic::getDeclaration(Mod.get(), Intrinsic::sideeffect));
Intrinsic::getOrInsertDeclaration(Mod.get(), Intrinsic::sideeffect));

EXPECT_EQ(map.lookup(SideEffect), "sideeffect");

Expand Down Expand Up @@ -252,3 +253,65 @@ TEST_F(OpMapIRTestFixture, DialectOpOverloadTests) {
EXPECT_EQ(map.lookup(Op1), "DialectOp4");
EXPECT_EQ(map.lookup(Op2), "DialectOp4");
}

TEST_F(OpMapIRTestFixture, CallCoreOpMatchesInstructionTest) {
OpMap<StringRef> map;
llvm_dialects::Builder B{Context};

// Define types
PointerType *PtrTy = B.getPtrTy();
IntegerType *I32Ty = Type::getInt32Ty(Context);

// Declare: ptr @ProcOpaqueHandle(i32, ptr)
FunctionType *ProcOpaqueHandleFuncTy =
FunctionType::get(PtrTy, {I32Ty, PtrTy}, false);
FunctionCallee ProcOpaqueHandleFunc =
Mod->getOrInsertFunction("ProcOpaqueHandle", ProcOpaqueHandleFuncTy);

B.SetInsertPoint(getEntryBlock());

// Declare %OpaqueTy = type opaque
StructType *OpaqueTy = StructType::create(Context, "OpaqueTy");

// Create a dummy global variable of type %OpaqueTy*
GlobalVariable *GV = new GlobalVariable(
*Mod, OpaqueTy, false, GlobalValue::PrivateLinkage, nullptr, "handle");
GV->setInitializer(ConstantAggregateZero::get(OpaqueTy));
Value *Op2 = GV;

// Create a constant value (e.g., 123)
Value *Op1 = B.getInt32(123);

// Build a call instruction
Value *Args[] = {Op1, Op2};
const CallInst &Call = *B.CreateCall(ProcOpaqueHandleFunc, Args);

// Create basic blocks for the function
auto *FC = getEntryBlock()->getParent();
BasicBlock *Label1BB = BasicBlock::Create(Context, "label1", FC);
BasicBlock *Label2BB = BasicBlock::Create(Context, "label2", FC);
BasicBlock *ContinueBB = BasicBlock::Create(Context, "continue", FC);

// Simulate a function that can branch to multiple labels
// For demonstration purposes, we'll create a placeholder function that represents this behavior
FunctionType *BranchFuncTy = FunctionType::get(Type::getVoidTy(Context), false);
FunctionCallee BranchFunc = Mod->getOrInsertFunction("Branch", BranchFuncTy);

// Create the CallBr instruction
const CallBrInst &CallBr = *B.CreateCallBr(BranchFunc, ContinueBB, {Label1BB, Label2BB});

// Load and test OpMap with Call and CallBr

// Add Instruction::Call to OpMap
const OpDescription CallDesc = OpDescription::fromCoreOp(Instruction::Call);
TylerNowicki marked this conversation as resolved.
Show resolved Hide resolved
map[CallDesc] = "Call";

// Add Instruction::CallBr to OpMap
const OpDescription CallBrDesc = OpDescription::fromCoreOp(Instruction::CallBr);
map[CallBrDesc] = "CallBr";

// Look up the Call and CallBr in the map and verify it finds the entries for
// Instruction::Call and Instruction::CallBr
EXPECT_EQ(map.lookup(Call), "Call");
EXPECT_EQ(map.lookup(CallBr), "CallBr");
}