Skip to content

Commit

Permalink
[OpenMP] Implement support for the scan directive in target regions
Browse files Browse the repository at this point in the history
The `scan` directive is already supported on the host side. This patch
implements the directive inside a target region, more specifically, the
`target teams distribute parallel for` construct. The clang/CodeGen
takes care of splitting the 'for' loop-body into the 'BeforeScan' and
'AfterScan' blocks based upon the location of the `scan` directive. The
code generator generates two kernels - one for each of the
aforementioned blocks to ensure thread synchronization across the whole
grid. The code generator for the first kernel embeds a call to the
DeviceRTL entry point that implements the Cross-Team Parallel
Prefix-Sum/Scan algorithm. The Cross-Team scan codegen machinery is activated
by passing the option `-fopenmp-target-xteam-scan`. This patch also
adds IR tests and execution tests for the implementation.

This is the No-Loop Xteam Scan kernel codegen, which assumes that the loop trip count is the same as (num_threads * num_teams). Work is in progress on a Segmented Scan that will accommodate all loop trip count sizes, independent of num_teams and num_threads.

For more details, visit: https://confluence.amd.com/x/mu6HJg

Change-Id: Id01381a5f54fd919370155391bc98d126002ddf4
  • Loading branch information
animeshk-amd committed Dec 20, 2024
1 parent b8f9d29 commit 5ef5244
Show file tree
Hide file tree
Showing 21 changed files with 4,122 additions and 131 deletions.
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -11636,6 +11636,8 @@ def err_omp_inscan_reduction_expected : Error<
"expected 'reduction' clause with the 'inscan' modifier">;
def note_omp_previous_inscan_reduction : Note<
"'reduction' clause with 'inscan' modifier is used here">;
def err_omp_multivar_xteam_scan_unsupported : Error<
"multiple list items are not yet supported with the 'inclusive' or the 'exclusive' clauses that appear with the 'scan' directive">;
def err_omp_expected_predefined_allocator : Error<
"expected one of the predefined allocators for the variables with the static "
"storage: 'omp_default_mem_alloc', 'omp_large_cap_mem_alloc', "
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/LangOptions.def
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ LANGOPT(OpenMPTargetNoLoop , 1, 1, "Use no-loop code generation technique.")
LANGOPT(OpenMPTargetXteamReduction , 1, 1, "Use cross-team code generation technique.")
LANGOPT(OpenMPTargetFastReduction , 1, 0, "Use fast reduction code generation technique.")
LANGOPT(OpenMPTargetMultiDevice , 1, 0, "Offload the iteration space of a single target region across multiple GPU devices.")
LANGOPT(OpenMPTargetXteamScan , 1, 0, "Use the cross-team scan code generation technique.")
LANGOPT(OpenMPOptimisticCollapse , 1, 0, "Use at most 32 bits to represent the collapsed loop nest counter.")
LANGOPT(OpenMPThreadSubscription , 1, 0, "Assume work-shared loops do not have more iterations than participating threads.")
LANGOPT(OpenMPTeamSubscription , 1, 0, "Assume distributed loops do not have more iterations than participating teams.")
Expand Down
8 changes: 8 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -3734,6 +3734,14 @@ def fno_openmp_target_fast_reduction : Flag<["-"], "fno-openmp-target-fast-reduc
Flags<[NoArgumentUnused, HelpHidden]>, Visibility<[ClangOption, CC1Option, FlangOption]>,
HelpText<"Do not use the fast reduction code generation technique">,
MarshallingInfoFlag<LangOpts<"OpenMPTargetFastReduction">>;
def fopenmp_target_xteam_scan : Flag<["-"], "fopenmp-target-xteam-scan">, Group<f_Group>,
Flags<[NoArgumentUnused, HelpHidden]>, Visibility<[ClangOption, CC1Option]>,
HelpText<"Use the cross-team scan code generation technique.">,
MarshallingInfoFlag<LangOpts<"OpenMPTargetXteamScan">>;
def fno_openmp_target_xteam_scan : Flag<["-"], "fno-openmp-target-xteam-scan">, Group<f_Group>,
Flags<[NoArgumentUnused, HelpHidden]>, Visibility<[ClangOption, CC1Option]>,
HelpText<"Do not use the cross-team scan code generation technique.">,
MarshallingInfoFlag<LangOpts<"OpenMPTargetXteamScan">>;
def fopenmp_target_multi_device : Flag<["-"], "fopenmp-target-multi-device">, Group<f_Group>,
Flags<[NoArgumentUnused, HelpHidden]>, Visibility<[ClangOption, CC1Option, FlangOption]>,
HelpText<"Enable code generation to emit support for multi device target region execution">,
Expand Down
256 changes: 156 additions & 100 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9464,7 +9464,13 @@ static void emitTargetCallKernelLaunch(
CodeGenModule::XteamRedVarMap &XteamRVM = CGF.CGM.getXteamRedVarMap(FStmt);
auto &XteamOrdVars = CGF.CGM.getXteamOrderedRedVar(FStmt);

assert((CapturedVars.size() == CapturedCount + 2 * XteamRVM.size()) &&
// The Xteam Reduction kernels require two helper variables - `team_vals`
// array and `teams_done_ptr`.
// The Xteam Scan Reduction kernels require a third helper variable -
// `scan_storage` array.
int ExpectedNumArgs = CGF.CGM.isXteamScanKernel() ? 3 : 2;
assert((CapturedVars.size() ==
CapturedCount + ExpectedNumArgs * XteamRVM.size()) &&
"Unexpected number of captured vars");

// Needed for processing the xteam reduction var pairs:
Expand Down Expand Up @@ -9526,103 +9532,147 @@ static void emitTargetCallKernelLaunch(
// reduction variables.
size_t ArgPos = 0;
size_t RedVarCount = 0;
for (; CapturedCount + ArgPos < CapturedVars.size();) {
// Process the pair of captured variables:
llvm::Value *DTeamValsInst = nullptr;

assert(CapturedCount + ArgPos < CapturedVars.size() &&
"Xteam reduction argument position out of bounds");
assert(RedVarCount < XteamOrdVars.size() &&
"Reduction variable count out of bounds");
const VarDecl *UserRedVar = XteamOrdVars[RedVarCount];
assert(XteamRVM.find(UserRedVar) != XteamRVM.end() &&
"Reduction variable not found in metadata");
auto RedVarQualType =
XteamRVM.find(UserRedVar)->second.RedVarExpr->getType();
llvm::Type *RedVarType = CGF.ConvertTypeForMem(RedVarQualType);

const ASTContext &Context = CGM.getContext();
if (IsXteamRedFast) {
// Placeholder for d_team_vals initialized to nullptr
DTeamValsInst =
CGF.Builder.CreateAlloca(RedVarType, nullptr, "d_team_vals");
Address DTeamValsAddr(DTeamValsInst, RedVarType,
Context.getTypeAlignInChars(RedVarQualType));
llvm::Value *NullPtrDTeamVals =
llvm::ConstantPointerNull::get(RedVarType->getPointerTo());
CGF.Builder.CreateStore(NullPtrDTeamVals, DTeamValsAddr);
} else {
// dteam_vals = omp_target_alloc(sizeof(red-type) * num_teams, devid)
llvm::Value *RedVarTySz = llvm::ConstantInt::get(
CGF.Int64Ty,
CGF.CGM.getDataLayout().getTypeSizeInBits(RedVarType) / 8);
assert((XteamRedNumTeamsFromClauseVal != nullptr ||
XteamRedNumTeamsFromOccupancy != nullptr) &&
"Number of teams cannot be null");
llvm::Value *DTeamValsSz = CGF.Builder.CreateMul(
RedVarTySz,
XteamRedNumTeamsFromClauseVal ? XteamRedNumTeamsFromClauseVal
: XteamRedNumTeamsFromOccupancy,
"d_team_vals_sz");
llvm::Value *TgtAllocArgs[] = {DTeamValsSz, DevIdVal};
DTeamValsInst = CGF.EmitRuntimeCall(
OMPBuilder.getOrCreateRuntimeFunction(CGF.CGM.getModule(),
OMPRTL_omp_target_alloc),
TgtAllocArgs, "d_team_vals");
}
ReductionVars.push_back(DTeamValsInst);
addXTeamReductionComponentHelper(CGF, CombinedInfo, DTeamValsInst);

// Advance to the next reduction variable in the pair:
++ArgPos;

llvm::Value *DTeamsDonePtrInst = nullptr;
if (IsXteamRedFast) {
// Placeholder for d_teams_done_ptr initialized to nullptr
DTeamsDonePtrInst =
CGF.Builder.CreateAlloca(CGF.Int32Ty, nullptr, "d_teams_done_ptr");
Address DTeamsDoneAddr(
DTeamsDonePtrInst, CGF.Int32Ty,
Context.getTypeAlignInChars(Context.UnsignedIntTy));
llvm::Value *NullPtrDTeamsDone =
llvm::ConstantPointerNull::get(CGF.Int32Ty->getPointerTo());
CGF.Builder.CreateStore(NullPtrDTeamsDone, DTeamsDoneAddr);
} else {
// uint32 teams_done = 0
Address TeamsDoneAddr(
CapturedVars[CapturedCount + ArgPos], CGF.Int32Ty,
CGF.getContext().getTypeAlignInChars(CGF.getContext().IntTy));
CGF.Builder.CreateStore(Int32Zero, TeamsDoneAddr);

// d_teams_done_ptr = omp_target_alloc(4, devid)
llvm::Value *IntTySz = llvm::ConstantInt::get(CGF.Int64Ty, 4);
llvm::Value *DTeamsDonePtrArgs[] = {IntTySz, DevIdVal};
DTeamsDonePtrInst = CGF.EmitRuntimeCall(
OMPBuilder.getOrCreateRuntimeFunction(CGF.CGM.getModule(),
OMPRTL_omp_target_alloc),
DTeamsDonePtrArgs, "d_teams_done_ptr");

// omp_target_memcpy(d_teams_done_ptr, &teams_done, 4 /*sizeof(uint32_t)
// */, 0 /* offset */, 0 /* offset */, devid, initial_devid)
llvm::Value *DTeamsDoneMemcpyArgs[] = {
DTeamsDonePtrInst,
TeamsDoneAddr.emitRawPointer(CGF),
/*sizeof(uint32_t)=*/llvm::ConstantInt::get(CGF.Int64Ty, 4),
/*dst_offset=*/llvm::ConstantInt::get(CGF.Int64Ty, 0),
/*src_offset=*/llvm::ConstantInt::get(CGF.Int64Ty, 0),
DevIdVal,
InitialDevInst};
CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
CGF.CGM.getModule(), OMPRTL_omp_target_memcpy),
DTeamsDoneMemcpyArgs);
}
ReductionVars.push_back(DTeamsDonePtrInst);
addXTeamReductionComponentHelper(CGF, CombinedInfo, DTeamsDonePtrInst);

// Advance to the next reduction variable in the pair:
++ArgPos;
if (CGF.CGM.isXteamScanKernel() && !CGF.CGM.isXteamScanPhaseOne) {
// For the Phase 2 of the Xteam Scan codegen, fresh memory allocation for
// reduction helper data structures is not needed. The helpers generated
// during the Phase 1 will be re-used here.
assert(CGF.CGM.ReductionVars.size() == 3 &&
"Xteam Scan reduction code-generates three helper variables");
addXTeamReductionComponentHelper(
CGF, CombinedInfo, CGF.CGM.ReductionVars[0]); // team_vals
addXTeamReductionComponentHelper(
CGF, CombinedInfo, CGF.CGM.ReductionVars[1]); // teams_done_ptr
addXTeamReductionComponentHelper(
CGF, CombinedInfo, CGF.CGM.ReductionVars[2]); // scan_storage
} else {
for (; CapturedCount + ArgPos < CapturedVars.size();) {
// Process the pair of captured variables:
llvm::Value *DTeamValsInst = nullptr;
llvm::Value *DScanStorageInst = nullptr;

assert(CapturedCount + ArgPos < CapturedVars.size() &&
"Xteam reduction argument position out of bounds");
assert(RedVarCount < XteamOrdVars.size() &&
"Reduction variable count out of bounds");
const VarDecl *UserRedVar = XteamOrdVars[RedVarCount];
assert(XteamRVM.find(UserRedVar) != XteamRVM.end() &&
"Reduction variable not found in metadata");
auto RedVarQualType =
XteamRVM.find(UserRedVar)->second.RedVarExpr->getType();
llvm::Type *RedVarType = CGF.ConvertTypeForMem(RedVarQualType);

const ASTContext &Context = CGM.getContext();
if (IsXteamRedFast) {
// Placeholder for d_team_vals initialized to nullptr
DTeamValsInst =
CGF.Builder.CreateAlloca(RedVarType, nullptr, "d_team_vals");
Address DTeamValsAddr(DTeamValsInst, RedVarType,
Context.getTypeAlignInChars(RedVarQualType));
llvm::Value *NullPtrDTeamVals =
llvm::ConstantPointerNull::get(RedVarType->getPointerTo());
CGF.Builder.CreateStore(NullPtrDTeamVals, DTeamValsAddr);
} else {
// dteam_vals = omp_target_alloc(sizeof(red-type) * num_teams, devid)
llvm::Value *RedVarTySz = llvm::ConstantInt::get(
CGF.Int64Ty,
CGF.CGM.getDataLayout().getTypeSizeInBits(RedVarType) / 8);
assert((XteamRedNumTeamsFromClauseVal != nullptr ||
XteamRedNumTeamsFromOccupancy != nullptr) &&
"Number of teams cannot be null");
llvm::Value *DTeamValsSz = CGF.Builder.CreateMul(
RedVarTySz,
XteamRedNumTeamsFromClauseVal ? XteamRedNumTeamsFromClauseVal
: XteamRedNumTeamsFromOccupancy,
"d_team_vals_sz");
llvm::Value *TgtAllocArgs[] = {DTeamValsSz, DevIdVal};
DTeamValsInst = CGF.EmitRuntimeCall(
OMPBuilder.getOrCreateRuntimeFunction(CGF.CGM.getModule(),
OMPRTL_omp_target_alloc),
TgtAllocArgs, "d_team_vals");

if (CGF.CGM.isXteamScanKernel()) {
// d_scan_storage = omp_target_alloc(sizeof(red-type) * (2*num_teams*num_threads + 1), devid)
llvm::Value *TotalNumThreads = CGF.Builder.CreateMul(
XteamRedNumTeamsFromClauseVal ? XteamRedNumTeamsFromClauseVal
: XteamRedNumTeamsFromOccupancy,
CGF.Builder.CreateIntCast(
OMPRuntime->emitNumThreadsForTargetDirective(CGF, D),
CGF.Int64Ty, false),
"total_num_threads");
llvm::Value *StorageSize = CGF.Builder.CreateAdd(
CGF.Builder.CreateMul(TotalNumThreads,
llvm::ConstantInt::get(CGF.Int64Ty, 2)),
llvm::ConstantInt::get(CGF.Int64Ty, 1), "storage_size");
llvm::Value *DScanStorageSz = CGF.Builder.CreateMul(
RedVarTySz, StorageSize, "d_scan_storage_sz");
llvm::Value *TgtAllocArgsScan[] = {DScanStorageSz, DevIdVal};
DScanStorageInst = CGF.EmitRuntimeCall(
OMPBuilder.getOrCreateRuntimeFunction(CGF.CGM.getModule(),
OMPRTL_omp_target_alloc),
TgtAllocArgsScan, "d_scan_storage");
}
}
CGF.CGM.ReductionVars.push_back(DTeamValsInst);
addXTeamReductionComponentHelper(CGF, CombinedInfo, DTeamValsInst);

// Advance to the next reduction variable in the pair:
++ArgPos;

llvm::Value *DTeamsDonePtrInst = nullptr;
if (IsXteamRedFast) {
// Placeholder for d_teams_done_ptr initialized to nullptr
DTeamsDonePtrInst = CGF.Builder.CreateAlloca(CGF.Int32Ty, nullptr,
"d_teams_done_ptr");
Address DTeamsDoneAddr(
DTeamsDonePtrInst, CGF.Int32Ty,
Context.getTypeAlignInChars(Context.UnsignedIntTy));
llvm::Value *NullPtrDTeamsDone =
llvm::ConstantPointerNull::get(CGF.Int32Ty->getPointerTo());
CGF.Builder.CreateStore(NullPtrDTeamsDone, DTeamsDoneAddr);
} else {
// uint32 teams_done = 0
Address TeamsDoneAddr(
CapturedVars[CapturedCount + ArgPos], CGF.Int32Ty,
CGF.getContext().getTypeAlignInChars(CGF.getContext().IntTy));
CGF.Builder.CreateStore(Int32Zero, TeamsDoneAddr);

// d_teams_done_ptr = omp_target_alloc(4, devid)
llvm::Value *IntTySz = llvm::ConstantInt::get(CGF.Int64Ty, 4);
llvm::Value *DTeamsDonePtrArgs[] = {IntTySz, DevIdVal};
DTeamsDonePtrInst = CGF.EmitRuntimeCall(
OMPBuilder.getOrCreateRuntimeFunction(CGF.CGM.getModule(),
OMPRTL_omp_target_alloc),
DTeamsDonePtrArgs, "d_teams_done_ptr");

// omp_target_memcpy(d_teams_done_ptr, &teams_done, 4 /*sizeof(uint32_t)
// */, 0 /* offset */, 0 /* offset */, devid, initial_devid)
llvm::Value *DTeamsDoneMemcpyArgs[] = {
DTeamsDonePtrInst,
TeamsDoneAddr.emitRawPointer(CGF),
/*sizeof(uint32_t)=*/llvm::ConstantInt::get(CGF.Int64Ty, 4),
/*dst_offset=*/llvm::ConstantInt::get(CGF.Int64Ty, 0),
/*src_offset=*/llvm::ConstantInt::get(CGF.Int64Ty, 0),
DevIdVal,
InitialDevInst};
CGF.EmitRuntimeCall(
OMPBuilder.getOrCreateRuntimeFunction(CGF.CGM.getModule(),
OMPRTL_omp_target_memcpy),
DTeamsDoneMemcpyArgs);
}
CGF.CGM.ReductionVars.push_back(DTeamsDonePtrInst);
addXTeamReductionComponentHelper(CGF, CombinedInfo, DTeamsDonePtrInst);

if (CGF.CGM.isXteamScanKernel()) {
// Advance to the next reduction variable in the pair:
++ArgPos;
CGF.CGM.ReductionVars.push_back(DScanStorageInst);
addXTeamReductionComponentHelper(CGF, CombinedInfo, DScanStorageInst);
}
// Advance to the next reduction variable in the pair:
++ArgPos;

++RedVarCount;
++RedVarCount;
}
}
}

Expand Down Expand Up @@ -9726,14 +9776,16 @@ static void emitTargetCallKernelLaunch(
OMPRuntime->emitInlinedDirective(CGF, D.getDirectiveKind(), ThenGen);

if (HasXTeamReduction) {
if (!CGF.CGM.isXteamRedFast(FStmt)) {
if (!CGF.CGM.isXteamRedFast(FStmt) &&
!(CGF.CGM.isXteamScanKernel() && CGF.CGM.isXteamScanPhaseOne)) {
// Deallocate XTeam reduction variables:
for (uint32_t I = 0; I < ReductionVars.size(); ++I) {
llvm::Value *FreeArgs[] = {ReductionVars[I], DevIdVal};
for (uint32_t I = 0; I < CGF.CGM.ReductionVars.size(); ++I) {
llvm::Value *FreeArgs[] = {CGF.CGM.ReductionVars[I], DevIdVal};
CGF.EmitRuntimeCall(OMPBuilder.getOrCreateRuntimeFunction(
CGF.CGM.getModule(), OMPRTL_omp_target_free),
FreeArgs);
}
CGF.CGM.ReductionVars.clear();
}
}
}
Expand Down Expand Up @@ -9891,6 +9943,10 @@ void CGOpenMPRuntime::scanForTargetRegionsFunctions(const Stmt *S,
CodeGenFunction::EmitOMPTargetTeamsDistributeParallelForDeviceFunction(
CGM, ParentName,
cast<OMPTargetTeamsDistributeParallelForDirective>(E));
if (CGM.isXteamScanKernel() && !CGM.isXteamScanPhaseOne)
CodeGenFunction::EmitOMPTargetTeamsDistributeParallelForDeviceFunction(
CGM, ParentName,
cast<OMPTargetTeamsDistributeParallelForDirective>(E));
break;
case OMPD_target_teams_distribute_parallel_for_simd:
CodeGenFunction::
Expand Down
Loading

0 comments on commit 5ef5244

Please sign in to comment.