Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dyno: resolve foo() to foo(standalone) if serial version of the iterator is not available. #26023

Merged
merged 12 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 36 additions & 4 deletions frontend/include/chpl/resolution/resolution-types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,8 @@ enum CandidateFailureReason {
FAIL_CANNOT_PASS,
/* Not a valid formal-actual mapping for this candidate. */
FAIL_FORMAL_ACTUAL_MISMATCH,
/* Special case of formal/actual mismatch when we tried to call a parallel iterator without a tag. */
FAIL_FORMAL_ACTUAL_MISMATCH_ITERATOR_API,
/* The wrong number of varargs were given to the function. */
FAIL_VARARG_MISMATCH,
/* The where clause returned 'false'. */
Expand Down Expand Up @@ -1461,6 +1463,11 @@ class FormalActualMap {
int failingActualIdx_ = -1;
int failingFormalIdx_ = -1;

// A standalone iterator will have an extra formal for the iterKind.
// If we call foo() and fail to get a formal-actual mapping, but only
// because we're missing the iterKind, we should still come back to it.
bool missingIteratorActuals_ = false;

public:

using FormalActualIterable = Iterable<std::vector<FormalActual>>;
Expand All @@ -1477,7 +1484,8 @@ class FormalActualMap {
actualIdxToFormalIdx_ == other.actualIdxToFormalIdx_ &&
mappingIsValid_ == other.mappingIsValid_ &&
failingActualIdx_ == other.failingActualIdx_ &&
failingFormalIdx_ == other.failingFormalIdx_;
failingFormalIdx_ == other.failingFormalIdx_ &&
missingIteratorActuals_ == other.missingIteratorActuals_;
}

bool operator!=(const FormalActualMap& other) const {
Expand All @@ -1492,16 +1500,27 @@ class FormalActualMap {
(void) mappingIsValid_; // nothing to mark
(void) failingActualIdx_; // nothing to mark
(void) failingFormalIdx_; // nothing to mark
(void) missingIteratorActuals_; // nothing to mark
}

size_t hash() const {
return chpl::hash(byFormalIdx_, actualIdxToFormalIdx_, mappingIsValid_,
failingActualIdx_, failingFormalIdx_);
failingActualIdx_, failingFormalIdx_,
missingIteratorActuals_);
}

/** check if mapping is valid */
bool isValid() const { return mappingIsValid_; }

/** Check why the mapping was invalid. */
CandidateFailureReason reason() const {
CHPL_ASSERT(!mappingIsValid_);
if (missingIteratorActuals_) {
return FAIL_FORMAL_ACTUAL_MISMATCH_ITERATOR_API;
}
return FAIL_FORMAL_ACTUAL_MISMATCH;
}

/** get the FormalActuals in the order of the formal arguments */
FormalActualIterable byFormals() const {
return FormalActualIterable(byFormalIdx_);
Expand Down Expand Up @@ -1849,6 +1868,7 @@ class CallResolutionResult {
// whether the resolution result was handled using some compiler-level logic,
// which does not correspond to a TypedSignature or AST.
bool speciallyHandled_ = false;
bool rejectedPossibleIteratorCandidates_ = false;

public:
CallResolutionResult() {}
Expand All @@ -1861,13 +1881,15 @@ class CallResolutionResult {
}

CallResolutionResult(MostSpecificCandidates mostSpecific,
bool rejectedPossibleIteratorCandidates,
types::QualifiedType exprType,
PoiInfo poiInfo,
bool speciallyHandled = false)
: mostSpecific_(std::move(mostSpecific)),
exprType_(std::move(exprType)),
poiInfo_(std::move(poiInfo)),
speciallyHandled_(speciallyHandled)
speciallyHandled_(speciallyHandled),
rejectedPossibleIteratorCandidates_(rejectedPossibleIteratorCandidates)
{
}

Expand All @@ -1883,11 +1905,19 @@ class CallResolutionResult {
/** whether the resolution result was handled using some compiler-level logic */
bool speciallyHandled() const { return speciallyHandled_; }

/** whether we rejected candidates because they expected a tag or followThis.
This might indicate that we need to re-resolve with tag to find parallel
iterators. */
bool rejectedPossibleIteratorCandidates() const {
return rejectedPossibleIteratorCandidates_;
}

bool operator==(const CallResolutionResult& other) const {
return mostSpecific_ == other.mostSpecific_ &&
exprType_ == other.exprType_ &&
PoiInfo::updateEquals(poiInfo_, other.poiInfo_) &&
speciallyHandled_ == other.speciallyHandled_;
speciallyHandled_ == other.speciallyHandled_ &&
rejectedPossibleIteratorCandidates_ == other.rejectedPossibleIteratorCandidates_;
}
bool operator!=(const CallResolutionResult& other) const {
return !(*this == other);
Expand All @@ -1897,6 +1927,8 @@ class CallResolutionResult {
exprType_.swap(other.exprType_);
poiInfo_.swap(other.poiInfo_);
std::swap(speciallyHandled_, other.speciallyHandled_);
std::swap(rejectedPossibleIteratorCandidates_,
other.rejectedPossibleIteratorCandidates_);
}

void stringify(std::ostream& ss, chpl::StringifyKind stringKind) const;
Expand Down
150 changes: 131 additions & 19 deletions frontend/lib/resolution/Resolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3951,6 +3951,99 @@ shouldSkipCallResolution(Resolver* rv, const uast::Call* call,
return skip;
}

static const bool& warnForMissingIterKindEnum(Context* context,
const AstNode* astForErr) {
QUERY_BEGIN(warnForMissingIterKindEnum, context, astForErr);
context->warning(astForErr, "resolving parallel iterators is not supported "
"without module code");
return QUERY_END(true);
}

static const QualifiedType&
getIterKindConstantOrWarn(Context* context,
const AstNode* astForErr,
UniqueString iterKindStr) {
QUERY_BEGIN(getIterKindConstantOrWarn, context, astForErr, iterKindStr);

auto iterKindActual = getIterKindConstantOrUnknownQuery(context, iterKindStr);
bool needSerial = iterKindStr.isEmpty();

// Exit early if we need a parallel iterator and don't have the enum.
if (!needSerial && iterKindActual.isUnknown()) {
warnForMissingIterKindEnum(context, astForErr);
}

return QUERY_END(iterKindActual);
}

static optional<CallResolutionResult>
rerunCallInfoWithIteratorTag(ResolutionContext* rc,
const uast::Call* call,
ResolvedExpression& r,
const CallInfo& ci,
const CallScopeInfo& inScopes,
QualifiedType receiverType,
UniqueString iterKindStr) {
auto iterKindActual = getIterKindConstantOrWarn(rc->context(), call, iterKindStr);
if (iterKindActual.isUnknown()) return empty;

std::vector<CallInfoActual> actuals;
for (const auto& actual : ci.actuals())
actuals.push_back(actual);
actuals.emplace_back(iterKindActual, USTR("tag"));

auto newCi = CallInfo(ci.name(), ci.calledType(), ci.isMethodCall(),
ci.hasQuestionArg(), ci.isParenless(), actuals);

auto newC = resolveCallInMethod(rc, call, newCi, inScopes, receiverType,
/* rejected */ nullptr);

// Also note the call as an associated action
if (!newC.mostSpecific().isEmpty()) {
for (auto sig : newC.mostSpecific()) {
if (!sig) continue;
r.addAssociatedAction(AssociatedAction::ITERATE, sig.fn(), call->id());
}

return newC;
}

return empty;
}

// Invokes resolveCallInMethod, and if that fails due to iterator candidates
// (e.g., we called foo() but only foo(standalone) is in scope), re-attempts
// the resolution with the other candidates.
//
// Note that the order here should match resolveIterDetailsInPriorityOrder.
static CallResolutionResult
resolveCallInMethodReattemptIfNeeded(ResolutionContext* rc,
const uast::Call* call,
ResolvedExpression& r,
const CallInfo& ci,
const CallScopeInfo& inScopes,
QualifiedType receiverType,
std::vector<ApplicabilityResult>* rejected) {
auto c = resolveCallInMethod(rc, call, ci, inScopes,
receiverType,
rejected);

// Other overloads are present and may be usable to fill in for 'foo()'.
if (c.mostSpecific().isEmpty() && c.rejectedPossibleIteratorCandidates()) {
if (auto standalone =
rerunCallInfoWithIteratorTag(rc, call, r, ci, inScopes, receiverType,
USTR("standalone"))) {
return *standalone;
}
if (auto parallel =
rerunCallInfoWithIteratorTag(rc, call, r, ci, inScopes, receiverType,
USTR("leader"))) {
return *parallel;
}
}
return c;
}

void Resolver::handleCallExpr(const uast::Call* call) {
if (scopeResolveOnly) {
return;
Expand Down Expand Up @@ -3991,6 +4084,7 @@ void Resolver::handleCallExpr(const uast::Call* call) {
ci);

if (!skip) {
ResolvedExpression& r = byPostorder.byAst(call);
QualifiedType receiverType = methodReceiverType();

// If the user has mistakenly instantiated a field of the type before
Expand All @@ -4015,12 +4109,11 @@ void Resolver::handleCallExpr(const uast::Call* call) {
}

std::vector<ApplicabilityResult>* rejected = nullptr;
auto c = resolveCallInMethod(rc, call, ci, inScopes,
receiverType,
rejected);
auto c = resolveCallInMethodReattemptIfNeeded(rc, call, r, ci, inScopes,
receiverType,
rejected);

// save the most specific candidates in the resolution result for the id
ResolvedExpression& r = byPostorder.byAst(call);
handleResolvedCallPrintCandidates(r, call, ci, inScopes, receiverType, c);

// handle type inference for variables split-inited by 'out' formals
Expand Down Expand Up @@ -4589,6 +4682,23 @@ static QualifiedType resolveTheseMethod(Resolver& rv,
return c.exprType();
}

static bool isExplicitlyTaggedIteratorCall(Context* context,
ResolvedExpression& re,
const TypedFnSignature* fn) {
if (!fn || !fn->isParallelIterator(context)) return false;

// We could've ended up resolving a leader automatically from a serial
// call (if the serial overload doesn't exist). To check that this was
// an explicit tag, we need to not have any ITERATE associted actions.
auto count = std::count_if(re.associatedActions().begin(),
re.associatedActions().end(),
[](const AssociatedAction& aa) {
return aa.action() == AssociatedAction::ITERATE;
});

return count == 0;
}

static QualifiedType
resolveIterTypeWithTag(Resolver& rv,
IterDetails::Pieces& outIterPieces,
Expand All @@ -4601,13 +4711,14 @@ resolveIterTypeWithTag(Resolver& rv,
QualifiedType unknown(QualifiedType::UNKNOWN, UnknownType::get(context));
QualifiedType error(QualifiedType::UNKNOWN, ErroneousType::get(context));

auto iterKindFormal = getIterKindConstantOrUnknownQuery(context, iterKindStr);
auto iterKindActual = getIterKindConstantOrWarn(context, astForErr, iterKindStr);
bool needSerial = iterKindStr.isEmpty();
bool needStandalone = iterKindStr == USTR("standalone");
bool needLeader = iterKindStr == USTR("leader");
bool needFollower = iterKindStr == USTR("follower");

// Exit early if we need a parallel iterator and don't have the enum.
if (!needSerial && iterKindFormal.isUnknown()) {
context->warning(astForErr, "resolving parallel iterators is not supported "
"without module code");
if (!needSerial && iterKindActual.isUnknown()) {
return error;
}

Expand All @@ -4619,18 +4730,19 @@ resolveIterTypeWithTag(Resolver& rv,
// are automatically provided by the compiler. Report an error.
auto& MSC = iterandRE.mostSpecific();
auto fn = MSC.only() ? MSC.only().fn() : nullptr;
if (fn && fn->isParallelIterator(context)) {
context->error(astForErr,
"explicitly invoking parallel iterators is not allowed -- "
"they are invoked implicitly by the compiler.");
return error;
}

bool wasIterandTypeResolved = !iterandType.isUnknownOrErroneous();
bool wasMatchingIterResolved =
// Call to a serial iterator overload, and we are looking for a serial iterator.
(fn && fn->isSerialIterator(context) && needSerial) ||
// Loop expressions (which we just resolved) and we are looking for a serial iterator.
// For iterator forwarding, we can write serial 'for' loops over tagged iterator calls
bool treatAsSerial = fn &&
(fn->isSerialIterator(context) || isExplicitlyTaggedIteratorCall(context, iterandRE, fn));
// Call to a serial iterator overload, and we are looking for a serial iterator.
bool wasMatchingIterResolved = fn &&
((needSerial && treatAsSerial) ||
(needStandalone && fn->isParallelStandaloneIterator(context)) ||
(needLeader && fn->isParallelLeaderIterator(context)) ||
(needFollower && fn->isParallelFollowerIterator(context)));
// Loop expressions (which we just resolved) and we are looking for a serial iterator.
wasMatchingIterResolved |=
(iterandType.type() && iterandType.type()->isLoopExprIteratorType() && needSerial);

// The iterand was a call to a serial iterator, and we need a serial iterator.
Expand All @@ -4650,7 +4762,7 @@ resolveIterTypeWithTag(Resolver& rv,
// or an iterator. The latter have compiler-generated 'these' methods
// which implement the dispatch logic like rewriting an iterator from `iter foo()`
// to `iter foo(tag)`. So just resolve the 'these' method.
auto qt = resolveTheseMethod(rv, iterand, iterandType, iterKindFormal, followThisFormal);
auto qt = resolveTheseMethod(rv, iterand, iterandType, iterKindActual, followThisFormal);
if (!qt.isUnknownOrErroneous() && qt.type()->isIteratorType()) {
// These produced a valid iterator. We already configured the call
// with the desired tag, so that's sufficient.
Expand Down
3 changes: 2 additions & 1 deletion frontend/lib/resolution/intents.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ static QualifiedType::Kind defaultIntentForType(const Type* t,
return QualifiedType::CONST_IN;

if (t->isStringType() || t->isBytesType() ||
t->isRecordType() || t->isUnionType() || t->isTupleType()) {
t->isRecordType() || t->isUnionType() || t->isTupleType() ||
t->isIteratorType()) {
if (isThis) {
if (isInit)
return QualifiedType::REF;
Expand Down
8 changes: 6 additions & 2 deletions frontend/lib/resolution/prims.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,9 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc,
} else {
CHPL_ASSERT(false && "unsupported param folding");
}
return CallResolutionResult(candidates, type, poi, /* specially handled */ true);
return CallResolutionResult(candidates,
/* rejectedPossibleIteratorCandidates */ false,
type, poi, /* specially handled */ true);
}

// otherwise, handle each primitive individually
Expand Down Expand Up @@ -1816,7 +1818,9 @@ CallResolutionResult resolvePrimCall(ResolutionContext* rc,
type = QualifiedType(QualifiedType::UNKNOWN, ErroneousType::get(context));
}

return CallResolutionResult(candidates, type, poi, /* specially handled */ true);
return CallResolutionResult(candidates,
/* rejectedPossibleIteratorCandidates */ false,
type, poi, /* specially handled */ true);
}


Expand Down
Loading