diff --git a/example/ExampleMain.cpp b/example/ExampleMain.cpp index d6371b9..18b4c5f 100644 --- a/example/ExampleMain.cpp +++ b/example/ExampleMain.cpp @@ -264,6 +264,12 @@ template const Visitor &getExampleVisitor() { b.add([](raw_ostream &out, ReturnInst &ret) { out << "visiting ReturnInst: " << ret << '\n'; }); + b.add([](raw_ostream &out, CallInst &ret) { + out << "visiting CallInst: " << ret << '\n'; + }); + b.add([](raw_ostream &out, CallBrInst &ret) { + out << "visiting CallBrInst: " << ret << '\n'; + }); b.addIntrinsic( Intrinsic::umax, [](raw_ostream &out, IntrinsicInst &umax) { out << "visiting umax intrinsic: " << umax << '\n'; diff --git a/include/llvm-dialects/Dialect/OpMap.h b/include/llvm-dialects/Dialect/OpMap.h index 1d18bfd..5873fea 100644 --- a/include/llvm-dialects/Dialect/OpMap.h +++ b/include/llvm-dialects/Dialect/OpMap.h @@ -595,34 +595,26 @@ template class OpMapIteratorBase final { if (std::get(m_iterator) == map->m_intrinsics.end()) invalidate(); } - } else { - createFromDialectOp(desc.getMnemonic()); + } else if (!createFromDialectOp(desc.getMnemonic())) { + invalidate(); } } OpMapIteratorBase(OpMapT *map, const llvm::Function &func) : m_map{map} { - createFromFunc(func); + if (!createFromFunc(func)) + invalidate(); } - // Do a lookup for a given instruction. Mark the iterator as invalid - // if the instruction is a call-like core instruction. + // Do a lookup for a given instruction. OpMapIteratorBase(OpMapT *map, const llvm::Instruction &inst) : m_map{map} { if (auto *CI = llvm::dyn_cast(&inst)) { const llvm::Function *callee = CI->getCalledFunction(); - if (callee) { - createFromFunc(*callee); + if (callee && createFromFunc(*callee)) return; - } } const unsigned op = inst.getOpcode(); - // Construct an invalid iterator. - if (op == llvm::Instruction::Call || op == llvm::Instruction::CallBr) { - invalidate(); - return; - } - BaseIteratorT it = m_map->m_coreOpcodes.find(op); if (it != m_map->m_coreOpcodes.end()) { m_desc = OpDescription::fromCoreOp(op); @@ -699,20 +691,20 @@ template class OpMapIteratorBase final { private: void invalidate() { m_isInvalid = true; } - void createFromFunc(const llvm::Function &func) { + bool createFromFunc(const llvm::Function &func) { if (func.isIntrinsic()) { m_iterator = m_map->m_intrinsics.find(func.getIntrinsicID()); if (std::get(m_iterator) != m_map->m_intrinsics.end()) { m_desc = OpDescription::fromIntrinsic(func.getIntrinsicID()); - return; + return true; } } - createFromDialectOp(func.getName()); + return createFromDialectOp(func.getName()); } - void createFromDialectOp(llvm::StringRef funcName) { + bool createFromDialectOp(llvm::StringRef funcName) { size_t idx = 0; bool found = false; for (auto &dialectOpKV : m_map->m_dialectOps) { @@ -729,8 +721,7 @@ template class OpMapIteratorBase final { ++idx; } - if (!found) - invalidate(); + return found; } // Re-construct base OpDescription from the stored iterator. diff --git a/test/example/visitor-basic.ll b/test/example/visitor-basic.ll index 919a429..49a336c 100644 --- a/test/example/visitor-basic.ll +++ b/test/example/visitor-basic.ll @@ -17,6 +17,9 @@ ; DEFAULT-NEXT: %q = ; DEFAULT-NEXT: visiting umin (set): %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q) ; DEFAULT-NEXT: visiting StringAttrOp: Hello world! +; DEFAULT-NEXT: visiting CallInst: %0 = call i32 @op.func(i32 %v1, i32 %q) +; DEFAULT-NEXT: visiting CallBrInst: callbr void @callee() +; DEFAULT-NEXT: to label %continueBB [label %label1BB, label %label2BB] ; DEFAULT-NEXT: visiting Ret (set): ret void ; DEFAULT-NEXT: visiting ReturnInst: ret void ; DEFAULT-NEXT: inner.counter = 1 @@ -40,9 +43,21 @@ entry: call void (...) @xd.ir.write.vararg(i8 %t, i32 %v2, i32 %q) %vm = call i32 @llvm.umin.i32(i32 %v1, i32 %q) call void @xd.ir.string.attr.op(ptr @0) + call i32 @op.func(i32 %v1, i32 %q) + callbr void @callee() to label %continueBB [label %label1BB, label %label2BB] ret void + +continueBB: + br label %entry + +label1BB: + br label %entry + +label2BB: + br label %entry } +declare void @callee() declare i32 @xd.ir.read__i32() declare i1 @xd.ir.set.read__i1() declare i32 @xd.ir.set.read__i32() @@ -53,3 +68,4 @@ declare i8 @xd.ir.itrunc__i8(...) declare void @xd.ir.string.attr.op(ptr) declare i32 @llvm.umax.i32(i32, i32) declare i32 @llvm.umin.i32(i32, i32) +declare void @op.func(i32, i32) diff --git a/test/unit/interface/OpMapIRTests.cpp b/test/unit/interface/OpMapIRTests.cpp index a0aa6c8..4895394 100644 --- a/test/unit/interface/OpMapIRTests.cpp +++ b/test/unit/interface/OpMapIRTests.cpp @@ -118,12 +118,13 @@ TEST_F(OpMapIRTestFixture, IntrinsicOpMatchesInstructionTest) { EXPECT_EQ(map[AssumeDesc], "assume"); const auto &SideEffect = *B.CreateCall( - Intrinsic::getDeclaration(Mod.get(), Intrinsic::sideeffect)); + Intrinsic::getOrInsertDeclaration(Mod.get(), Intrinsic::sideeffect)); const std::array AssumeArgs = { ConstantInt::getBool(Type::getInt1Ty(Context), true)}; const auto &Assume = *B.CreateCall( - Intrinsic::getDeclaration(Mod.get(), Intrinsic::assume), AssumeArgs); + Intrinsic::getOrInsertDeclaration(Mod.get(), Intrinsic::assume), + AssumeArgs); EXPECT_FALSE(map.lookup(SideEffect) == map.lookup(Assume)); EXPECT_EQ(map.lookup(SideEffect), "sideeffect"); @@ -171,7 +172,7 @@ TEST_F(OpMapIRTestFixture, MixedOpMatchesInstructionTest) { EXPECT_EQ(map[SideEffectDesc], "sideeffect"); const auto &SideEffect = *B.CreateCall( - Intrinsic::getDeclaration(Mod.get(), Intrinsic::sideeffect)); + Intrinsic::getOrInsertDeclaration(Mod.get(), Intrinsic::sideeffect)); EXPECT_EQ(map.lookup(SideEffect), "sideeffect"); @@ -252,3 +253,65 @@ TEST_F(OpMapIRTestFixture, DialectOpOverloadTests) { EXPECT_EQ(map.lookup(Op1), "DialectOp4"); EXPECT_EQ(map.lookup(Op2), "DialectOp4"); } + +TEST_F(OpMapIRTestFixture, CallCoreOpMatchesInstructionTest) { + OpMap map; + llvm_dialects::Builder B{Context}; + + // Define types + PointerType *PtrTy = B.getPtrTy(); + IntegerType *I32Ty = Type::getInt32Ty(Context); + + // Declare: ptr @ProcOpaqueHandle(i32, ptr) + FunctionType *ProcOpaqueHandleFuncTy = + FunctionType::get(PtrTy, {I32Ty, PtrTy}, false); + FunctionCallee ProcOpaqueHandleFunc = + Mod->getOrInsertFunction("ProcOpaqueHandle", ProcOpaqueHandleFuncTy); + + B.SetInsertPoint(getEntryBlock()); + + // Declare %OpaqueTy = type opaque + StructType *OpaqueTy = StructType::create(Context, "OpaqueTy"); + + // Create a dummy global variable of type %OpaqueTy* + GlobalVariable *GV = new GlobalVariable( + *Mod, OpaqueTy, false, GlobalValue::PrivateLinkage, nullptr, "handle"); + GV->setInitializer(ConstantAggregateZero::get(OpaqueTy)); + Value *Op2 = GV; + + // Create a constant value (e.g., 123) + Value *Op1 = B.getInt32(123); + + // Build a call instruction + Value *Args[] = {Op1, Op2}; + const CallInst &Call = *B.CreateCall(ProcOpaqueHandleFunc, Args); + + // Create basic blocks for the function + auto *FC = getEntryBlock()->getParent(); + BasicBlock *Label1BB = BasicBlock::Create(Context, "label1", FC); + BasicBlock *Label2BB = BasicBlock::Create(Context, "label2", FC); + BasicBlock *ContinueBB = BasicBlock::Create(Context, "continue", FC); + + // Simulate a function that can branch to multiple labels + // For demonstration purposes, we'll create a placeholder function that represents this behavior + FunctionType *BranchFuncTy = FunctionType::get(Type::getVoidTy(Context), false); + FunctionCallee BranchFunc = Mod->getOrInsertFunction("Branch", BranchFuncTy); + + // Create the CallBr instruction + const CallBrInst &CallBr = *B.CreateCallBr(BranchFunc, ContinueBB, {Label1BB, Label2BB}); + + // Load and test OpMap with Call and CallBr + + // Add Instruction::Call to OpMap + const OpDescription CallDesc = OpDescription::fromCoreOp(Instruction::Call); + map[CallDesc] = "Call"; + + // Add Instruction::CallBr to OpMap + const OpDescription CallBrDesc = OpDescription::fromCoreOp(Instruction::CallBr); + map[CallBrDesc] = "CallBr"; + + // Look up the Call and CallBr in the map and verify it finds the entries for + // Instruction::Call and Instruction::CallBr + EXPECT_EQ(map.lookup(Call), "Call"); + EXPECT_EQ(map.lookup(CallBr), "CallBr"); +}