Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ajreynol committed Feb 27, 2024
2 parents 047d8c6 + d0a67cc commit 5e5b05d
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 33 deletions.
27 changes: 27 additions & 0 deletions src/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,33 @@ std::string Expr::getSymbol() const

ExprValue* Expr::getValue() const { return d_value; }

std::pair<std::vector<Expr>, Expr> Expr::getFunctionType() const
{
Expr et = *this;
std::vector<Expr> args;
while (et.getKind()==Kind::FUNCTION_TYPE)
{
size_t nchild = et.getNumChildren();
for (size_t i=0; i<nchild-1; i++)
{
args.push_back(et[i]);
}
et = et[nchild-1];
// strip off requires
while (et.getKind()==Kind::EVAL_REQUIRES)
{
et = et[2];
}
}
return std::pair<std::vector<Expr>, Expr>(args, et);
}

size_t Expr::getFunctionArity() const
{
std::pair<std::vector<Expr>, Expr> ftype = getFunctionType();
return ftype.first.size();
}

/**
* SMT-LIB 2 quoting for symbols
*/
Expand Down
7 changes: 6 additions & 1 deletion src/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ class Expr
std::string getSymbol() const;
/** Get underlying value */
ExprValue* getValue() const;

/**
* Get function type, which is a pair of argument types and the range type.
*/
std::pair<std::vector<Expr>, Expr> getFunctionType() const;
/** Get arity, where this is a function type. Used for overloading. */
size_t getFunctionArity() const;
private:
/** The underlying value */
ExprValue* d_value;
Expand Down
51 changes: 21 additions & 30 deletions src/state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -824,40 +824,40 @@ Expr State::mkExpr(Kind k, const std::vector<Expr>& children)
{
Trace("overload") << "process alf.as " << children[0] << " " << children[1] << std::endl;
AppInfo* ai = getAppInfo(vchildren[0]);
Expr ret;
Expr ret = children[0];
std::pair<std::vector<Expr>, Expr> ftype = children[1].getFunctionType();
if (ai!=nullptr && !ai->d_overloads.empty())
{
size_t arity = 0;
Expr cur = children[1];
while (cur.getKind()==Kind::FUNCTION_TYPE)
{
size_t nchild = cur.getNumChildren();
arity += nchild-1;
cur = cur[nchild-1];
}
Trace("overload") << "...overloaded with arity " << arity << std::endl;
size_t arity = ftype.first.size();
Trace("overload") << "...overloaded, check arity " << arity << std::endl;
// look up the overload
std::map<size_t, Expr>::iterator ito = ai->d_overloads.find(arity);
if (ito!=ai->d_overloads.end())
{
ret = ito->second;
}
// otherwise try the default (first) symbol parsed, which is children[0]
}
else
{
Trace("overload") << "...not overloaded" << std::endl;
ret = children[0];
}
if (!ret.isNull())
Trace("overload") << "Apply " << ret << " of type " << d_tc.getType(ret) << " to children of types:" << std::endl;
std::vector<Expr> cchildren;
cchildren.push_back(ret);
for (const Expr& t : ftype.first)
{
Expr tret = d_tc.getType(ret);
Trace("overload") << "Compare " << tret << " " << children[1] << std::endl;
// must be matchable
Ctx ctx;
if (d_tc.match(tret.getValue(), vchildren[1], ctx))
{
return ret;
}
Trace("overload") << "- " << t << std::endl;
cchildren.push_back(getBoundVar("as.v", t));
}
Expr cret = mkExpr(Kind::APPLY, cchildren);
Expr tcret = d_tc.getType(cret);
Trace("overload") << "Range expected/computed: " << ftype.second << " " << tcret<< std::endl;
// if succeeded, we return the disambiguated term, otherwise the alf.as does not evaluate
// and we construct the (bogus) term below.
if (ftype.second==tcret)
{
return ret;
}
}
}
Expand Down Expand Up @@ -1021,16 +1021,7 @@ bool State::bind(const std::string& name, const Expr& e)
AppInfo& ai = d_appData[its->second.getValue()];
Expr ee = e;
Expr et = d_tc.getType(ee);
size_t arity = 0;
while (et.getKind()==Kind::FUNCTION_TYPE)
{
arity++;
et = et[et.getNumChildren()-1];
while (et.getKind()==Kind::EVAL_REQUIRES)
{
et = et[2];
}
}
size_t arity = et.getFunctionArity();
Trace("overload") << "Overload " << e << " for " << its->second << " with arity " << arity << std::endl;
if (ai.d_overloads.find(arity)!=ai.d_overloads.end())
{
Expand Down
2 changes: 1 addition & 1 deletion src/type_checker.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ class TypeChecker
* operator does not evaluate.
*/
Expr evaluateLiteralOp(Kind k, const std::vector<ExprValue*>& args);
private:
/**
* Match expression a with b. If this returns true, then ctx is a substitution
* that when applied to b gives a. The substitution
*/
bool match(ExprValue* a, ExprValue* b, Ctx& ctx);
private:
/** Same as above, but takes a cache of pairs we have already visited */
bool match(ExprValue* a,
ExprValue* b,
Expand Down
2 changes: 1 addition & 1 deletion tests/overloading-as.smt3
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
(declare-const ho_pred (-> (-> Int Int) Bool))


(assume @p0 (ho_pred (alf.as - (-> Bool Int))))
(assume @p0 (ho_pred (alf.as - (-> Int Int))))

0 comments on commit 5e5b05d

Please sign in to comment.