Skip to content

Commit

Permalink
Add test of iteration over tuples
Browse files Browse the repository at this point in the history
Signed-off-by: Danila Fedorin <daniel.fedorin@hpe.com>
  • Loading branch information
DanilaFe committed Jan 31, 2025
1 parent 49e4b27 commit e7fe392
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 32 deletions.
66 changes: 34 additions & 32 deletions frontend/include/chpl/resolution/ResolvedVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,41 +31,43 @@ using namespace uast;
template <typename ResolvedVisitorImpl>
static bool resolvedVisitorEnterFor(ResolvedVisitorImpl& v,
const uast::For* loop) {
if (loop->isParam()) {
// Enter the param for-loop with the current resolution results, so that
// the user-defined visitor can choose to do something custom if it
// wants.
bool goInto = v.userVisitor().enter(loop, v);

if (goInto) {
const ResolvedExpression& rr = v.byAst(loop);
const ResolvedParamLoop* resolvedLoop = rr.paramLoop();

if (resolvedLoop == nullptr) return false;

const AstNode* iterand = loop->iterand();
iterand->traverse(v);

// TODO: Should there be some kind of function the UserVisitor can
// implement to observe a new iteration of the loop body?
for (const auto& loopBody : resolvedLoop->loopBodies()) {
ResolvedVisitorImpl loopVis(v.rc(), loop,
v.userVisitor(), loopBody);

for (const AstNode* child : loop->children()) {
// Written to visit "all but the iterand" in case we add more
// fields/children to the For class later.
if (child != iterand) {
child->traverse(loopVis);
}
}
bool goInto = v.userVisitor().enter(loop, v);
if (!goInto) return false;

// don't return 'true' if it's a param loop, we'll enter it below if able.
if (loop->isParam()) goInto = false;

// some loops have 'paramLoop' info (param loops, loops over heterogeneous tuples)
// but most don't, so check hasAst and bail if it's not there.
if (!v.hasAst(loop)) {
return goInto;
}

const ResolvedExpression& rr = v.byAst(loop);
const ResolvedParamLoop* resolvedLoop = rr.paramLoop();

// no param resolution results, act like a normal loop
if (resolvedLoop == nullptr) return goInto;

const AstNode* iterand = loop->iterand();
iterand->traverse(v);

// TODO: Should there be some kind of function the UserVisitor can
// implement to observe a new iteration of the loop body?
for (const auto& loopBody : resolvedLoop->loopBodies()) {
ResolvedVisitorImpl loopVis(v.rc(), loop,
v.userVisitor(), loopBody);

for (const AstNode* child : loop->children()) {
// Written to visit "all but the iterand" in case we add more
// fields/children to the For class later.
if (child != iterand) {
child->traverse(loopVis);
}
}

return false;
} else {
return v.userVisitor().enter(loop, v);
}

return false;
}

template <typename ResolvedVisitorImpl>
Expand Down
65 changes: 65 additions & 0 deletions frontend/test/resolution/testLoopIndexVars.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,70 @@ static void testNestedParamFor() {
assert(resolvedVals == pc.resolvedVals);
}

struct TupleCollector {
using RV = ResolvedVisitor<TupleCollector>;

std::multimap<std::string, const Type*> resolvedVals;

TupleCollector() { }

bool enter(const uast::VarLikeDecl* decl, RV& rv) {
if (decl->storageKind() == Qualifier::PARAM ||
decl->storageKind() == Qualifier::INDEX) {
const ResolvedExpression& rr = rv.byAst(decl);
if (!rr.type().isUnknownOrErroneous()) {
resolvedVals.emplace(decl->name().str(), rr.type().type());
}
}
return true;
}

bool enter(const uast::AstNode* ast, RV& rv) {
return true;
}
void exit(const uast::AstNode* ast, RV& rv) {}
};

static void testHeteroTuples() {
printf("testHeteroTuples\n");
auto context = buildStdContext();
ErrorGuard guard(context);
ResolutionContext rcval(context);
auto rc = &rcval;

auto loopText = R"""(
var tmp = (1.0, 1, true);
for x in tmp {
for y in tmp {
}
}
)""";
const Module* m = parseModule(context, loopText);

const ResolutionResultByPostorderID& rr = resolveModule(context, m->id());
TupleCollector pc;
ResolvedVisitor<TupleCollector> rv(rc, m, pc, rr);
m->traverse(rv);

std::multimap<std::string, const Type*> expected;
expected.emplace("x", RealType::get(context, 0));
expected.emplace("y", RealType::get(context, 0));
expected.emplace("y", IntType::get(context, 0));
expected.emplace("y", BoolType::get(context));
expected.emplace("x", IntType::get(context, 0));
expected.emplace("y", RealType::get(context, 0));
expected.emplace("y", IntType::get(context, 0));
expected.emplace("y", BoolType::get(context));
expected.emplace("x", BoolType::get(context));
expected.emplace("y", RealType::get(context, 0));
expected.emplace("y", IntType::get(context, 0));
expected.emplace("y", BoolType::get(context));

assert(expected == pc.resolvedVals);
}

static void testIndexScope0() {
printf("testIndexScope0\n");
auto context = buildStdContext();
Expand Down Expand Up @@ -1469,6 +1533,7 @@ int main() {
testCForLoop();
testParamFor();
testNestedParamFor();
testHeteroTuples();
testIndexScope0();
testIndexScope1();

Expand Down

0 comments on commit e7fe392

Please sign in to comment.