Skip to content

Commit

Permalink
Merge pull request #32 from KacperFKorban/run-lift
Browse files Browse the repository at this point in the history
Run function lifting and named form refs
  • Loading branch information
KacperFKorban authored Mar 23, 2024
2 parents ecbc18b + 839ceed commit 14c8caf
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 35 deletions.
5 changes: 4 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ lazy val guinep = projectMatrix
.in(file("guinep"))
.settings(commonSettings)
.settings(
name := "GUInep"
name := "GUInep",
libraryDependencies ++= Seq(
"com.softwaremill.quicklens" %%% "quicklens" % "1.9.7"
)
)
.jvmPlatform(scalaVersions = List(scala3))

Expand Down
104 changes: 89 additions & 15 deletions guinep/src/main/scala/macros.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package guinep

import guinep.model.*
import scala.quoted.*
import scala.collection.mutable
import com.softwaremill.quicklens.*

private[guinep] object macros {
inline def funInfos(inline fs: Any): Seq[Fun] =
Expand Down Expand Up @@ -33,7 +35,13 @@ private[guinep] object macros {
extension (t: Term)
private def select(s: Term): Term = Select(t, s.symbol)
private def select(s: String): Term =
t.select(t.tpe.typeSymbol.methodMember(s).head)
t.select(
t.tpe
.typeSymbol
.methodMember(s)
.headOption.
getOrElse(report.errorAndAbort(s"PANIC: No member $s in term ${t.show} with type ${t.tpe.show}"))
)

extension (s: Symbol)
private def prettyName: String =
Expand Down Expand Up @@ -93,7 +101,28 @@ private[guinep] object macros {
val isEnumCaseNonClassDef = typeSymbol.flags.is(Flags.Enum) && typeSymbol.flags.is(Flags.Case) && !typeSymbol.isClassDef
isModule || isEnumCaseNonClassDef

private def functionFormElementFromTree(paramName: String, paramType: TypeRepr): FormElement = paramType match {
private case class FormConstrContext(constructedTpes: mutable.Map[String, Option[FormElement]], referencedTpes: mutable.Set[String])
private def formConstrCtx(using FormConstrContext) = summon[FormConstrContext]

extension (tpe: TypeRepr)
private def namedRef: String = tpe match
case ntpe: NamedType => ntpe.typeSymbol.fullName
case AppliedType(tpe, args) => s"${tpe.namedRef}[${args.map(_.namedRef).mkString(", ")}]"
case AnnotatedType(tpe, _) => tpe.namedRef
case _ => tpe.show

private def functionFormElementFromTreeWithCaching(paramName: String, paramTpe: TypeRepr)(using FormConstrContext): FormElement =
formConstrCtx.constructedTpes.get(paramTpe.namedRef) match
case Some(_) =>
formConstrCtx.referencedTpes.add(paramTpe.namedRef)
FormElement.NamedRef(paramName, paramTpe.namedRef)
case _ =>
formConstrCtx.constructedTpes.update(paramTpe.namedRef, None)
val formElement = functionFormElementFromTree(paramName, paramTpe)
formConstrCtx.constructedTpes.update(paramTpe.namedRef, Some(formElement.modify(_.name).setTo("value")))
formElement

private def functionFormElementFromTree(paramName: String, paramType: TypeRepr)(using FormConstrContext): FormElement = paramType match {
case ntpe: NamedType if ntpe.name == "String" => FormElement.TextInput(paramName)
case ntpe: NamedType if ntpe.name == "Int" => FormElement.NumberInput(paramName)
case ntpe: NamedType if ntpe.name == "Boolean" => FormElement.CheckboxInput(paramName)
Expand All @@ -104,7 +133,7 @@ private[guinep] object macros {
FormElement.FieldSet(
paramName,
fields.map { valdef =>
functionFormElementFromTree(
functionFormElementFromTreeWithCaching(
valdef.name,
valdef.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs).stripAnnots
)
Expand All @@ -113,17 +142,29 @@ private[guinep] object macros {
case ntpe if isSumTpe(ntpe) =>
val classSymbol = ntpe.typeSymbol
val childrenAppliedTpes = classSymbol.children.map(child => appliedChild(child, classSymbol, ntpe.typeArgs)).map(_.stripAnnots)
val childrenFormElements = childrenAppliedTpes.map(t => functionFormElementFromTree("value", t))
val childrenFormElements = childrenAppliedTpes.map(t => functionFormElementFromTreeWithCaching("value", t))
val options = classSymbol.children.map(_.prettyName).zip(childrenFormElements)
FormElement.Dropdown(paramName, options)
case _ =>
unsupportedFunctionParamType(paramType)
}

private def functionFormElementsImpl(f: Expr[Any]): Expr[Seq[FormElement]] =
Expr.ofSeq(
functionParams(f).map { case ValDef(name, tpt, _) => functionFormElementFromTree(name, tpt.tpe) } .map(Expr(_))
)
private def formImpl(f: Expr[Any]): Expr[Form] =
given FormConstrContext = FormConstrContext(mutable.Map.empty, mutable.Set.empty)
val inputs = functionParams(f)
.map {
case ValDef(name, tpt, _) =>
functionFormElementFromTreeWithCaching(name, tpt.tpe)
}
val usedFormDecls =
formConstrCtx.constructedTpes
.toList.filter( (ref, formElement) => formConstrCtx.referencedTpes.contains(ref) )
.collect {
case (ref, Some(formElement)) => ref -> formElement
}
.toMap
val form = Form(inputs, usedFormDecls)
Expr(form)

private def appliedChild(childSym: Symbol, parentSym: Symbol, parentArgs: List[TypeRepr]): TypeRepr = childSym.tree match {
case classDef @ ClassDef(_, _, parents, _, _) =>
Expand All @@ -149,7 +190,35 @@ private[guinep] object macros {
childSym.typeRef
}

private def constructArg(paramTpe: TypeRepr, param: Term): Term = {
private case class ConstrEntry(definition: Option[Statement], ref: Term)
private case class ConstrContext(constrMap: mutable.Map[String, ConstrEntry])
private def constrCtx(using ConstrContext) = summon[ConstrContext]

private def constructArgWithCaching(paramTpe: TypeRepr, param: Term)(using ConstrContext): Term =
constrCtx.constrMap.get(paramTpe.namedRef) match
case Some(ConstrEntry(_, ref)) =>
ref.appliedTo(param)
case None =>
val ConstrEntry(_, ref) = constructFunction(paramTpe)
ref.appliedTo(param)

private def constructFunction(paramTpe: TypeRepr)(using ConstrContext): ConstrEntry =
val defdefSymbol =
Symbol.newMethod(
Symbol.spliceOwner,
s"constrFunFor${paramTpe.namedRef}",
MethodType(List("inputs"))(_ => List(TypeRepr.of[Any]), _ => paramTpe)
)
constrCtx.constrMap.update(paramTpe.namedRef, ConstrEntry(None, Ref(defdefSymbol)))
val defdef = DefDef(defdefSymbol, {
case List(List(param: Term)) =>
Some(constructArg(paramTpe, param))
})
val constrEntry = ConstrEntry(Some(defdef), Ref(defdefSymbol))
val newMap = constrCtx.constrMap.update(paramTpe.namedRef, constrEntry)
constrEntry

private def constructArg(paramTpe: TypeRepr, param: Term)(using ConstrContext): Term = {
paramTpe match {
case ntpe: NamedType if ntpe.name == "String" => param.select("asInstanceOf").appliedToType(ntpe)
case ntpe: NamedType if ntpe.name == "Int" => param.select("asInstanceOf").appliedToType(ntpe)
Expand All @@ -166,7 +235,7 @@ private[guinep] object macros {
val args = fields.collect { case field: ValDef =>
val fieldName = field.name
val fieldValue = paramValue.select("apply").appliedTo(Literal(StringConstant(fieldName)))
constructArg(
constructArgWithCaching(
field.tpt.tpe.substituteTypes(typeDefParams, ntpe.typeArgs),
fieldValue
)
Expand All @@ -186,7 +255,7 @@ private[guinep] object macros {
val childName = Literal(StringConstant(child.prettyName))
If(
paramName.select("equals").appliedTo(childName),
constructArg(childAppliedTpe, paramValue),
constructArgWithCaching(childAppliedTpe, paramValue),
acc
)
}
Expand All @@ -199,14 +268,15 @@ private[guinep] object macros {
private def functionRunImpl(f: Expr[Any]): Expr[List[Any] => String] = f.asTerm match {
case l@Lambda(params, _) =>
/* (params: List[Any]) => l.apply(constructArg(params(0)), constructArg(params(1)), ...) */
Lambda(
given ConstrContext = ConstrContext(mutable.Map.empty)
val resLambda = Lambda(
Symbol.spliceOwner,
MethodType(List("inputs"))(_ => List(TypeRepr.of[List[Any]]), _ => TypeRepr.of[String]),
{ case (sym, List(params: Term)) =>
val args = functionParams(f).zipWithIndex.map { case (valdef, i) =>
val paramTpe = valdef.tpt.tpe
val param = params.select("apply").appliedTo(Literal(IntConstant(i)))
constructArg(paramTpe, param)
constructArgWithCaching(paramTpe, param)
}.toList
val aply = l.select("apply")
val res =
Expand All @@ -216,6 +286,10 @@ private[guinep] object macros {
aply.appliedToArgs(args)
res.select("toString").appliedToNone
}
)
Block(
constrCtx.constrMap.toList.flatMap(_._2.definition),
resLambda
).asExprOf[List[Any] => String]
case i@Ident(_) =>
Lambda(
Expand Down Expand Up @@ -249,9 +323,9 @@ private[guinep] object macros {

def funInfoImpl(f: Expr[Any]): Expr[Fun] = {
val name = functionNameImpl(f)
val params = functionFormElementsImpl(f)
val form = formImpl(f)
val run = functionRunImpl(f)
'{ Fun($name, $params, $run) }
'{ Fun($name, $form, $run) }
}
}
}
53 changes: 50 additions & 3 deletions guinep/src/main/scala/model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,26 @@ package guinep
import scala.quoted.*

private[guinep] object model {
case class Fun(name: String, inputs: Seq[FormElement], run: List[Any] => String)
case class Fun(name: String, form: Form, run: List[Any] => String)

case class Form(inputs: Seq[FormElement], namedFormElements: Map[String, FormElement]) {
def formElementsJSONRepr =
val elems = this.inputs.map(_.toJSONRepr).mkString(",")
s"[$elems]"
def namedFormElementsJSONRepr: String =
val entries = this.namedFormElements.toList.map { (name, formElement) =>
s""""$name": ${formElement.toJSONRepr}"""
}
.mkString(",")
s"{$entries}"
}
object Form:
given ToExpr[Form] with
def apply(form: Form)(using Quotes): Expr[Form] = form match
case Form(inputs, namedFormElements) =>
'{ Form(${Expr(inputs)}, ${Expr(namedFormElements)}) }

enum FormElement(val name: String):
case FieldSet(override val name: String, elements: List[FormElement]) extends FormElement(name)
case TextInput(override val name: String) extends FormElement(name)
case NumberInput(override val name: String) extends FormElement(name)
case CheckboxInput(override val name: String) extends FormElement(name)
Expand All @@ -15,6 +31,20 @@ private[guinep] object model {
case DateInput(override val name: String) extends FormElement(name)
case EmailInput(override val name: String) extends FormElement(name)
case PasswordInput(override val name: String) extends FormElement(name)
case FieldSet(override val name: String, elements: List[FormElement]) extends FormElement(name)
case NamedRef(override val name: String, ref: String) extends FormElement(name)

def constrOrd: Int = this match
case TextInput(_) => 0
case NumberInput(_) => 1
case CheckboxInput(_) => 2
case Dropdown(_, _) => 3
case TextArea(_, _, _) => 4
case DateInput(_) => 5
case EmailInput(_) => 6
case PasswordInput(_) => 7
case FieldSet(_, _) => 8
case NamedRef(_, _) => 9

def toJSONRepr: String = this match
case FormElement.FieldSet(name, elements) =>
Expand All @@ -26,7 +56,8 @@ private[guinep] object model {
case FormElement.CheckboxInput(name) =>
s"""{ "name": '$name', "type": 'checkbox' }"""
case FormElement.Dropdown(name, options) =>
s"""{ "name": '$name', "type": 'dropdown', "options": [${options.map { case (k, v) => s"""{"name": "$k", "value": ${v.toJSONRepr}}""" }.mkString(",")}] }"""
// TODO(kπ) this sortBy isn't 100% sure to be working (the only requirement is for the first constructor to not be recursive; this is a graph problem, sorta)
s"""{ "name": '$name', "type": 'dropdown', "options": [${options.sortBy(_._2).map { case (k, v) => s"""{"name": "$k", "value": ${v.toJSONRepr}}""" }.mkString(",")}] }"""
case FormElement.TextArea(name, rows, cols) =>
s"""{ "name": '$name', "type": 'textarea', "rows": ${rows.getOrElse("")}, "cols": ${cols.getOrElse("")} }"""
case FormElement.DateInput(name) =>
Expand All @@ -35,6 +66,8 @@ private[guinep] object model {
s"""{ "name": '$name', "type": 'email' }"""
case FormElement.PasswordInput(name) =>
s"""{ "name": '$name', "type": 'password' }"""
case FormElement.NamedRef(name, ref) =>
s"""{ "name": '$name', "ref": '$ref', "type": 'namedref' }"""

object FormElement:
given ToExpr[FormElement] with
Expand All @@ -57,4 +90,18 @@ private[guinep] object model {
'{ FormElement.EmailInput(${Expr(name)}) }
case FormElement.PasswordInput(name) =>
'{ FormElement.PasswordInput(${Expr(name)}) }
case FormElement.NamedRef(name, ref) =>
'{ FormElement.NamedRef(${Expr(name)}, ${Expr(ref)}) }

given Ordering[FormElement] = new Ordering[FormElement] {
def compare(x: FormElement, y: FormElement): Int =
if x.constrOrd < y.constrOrd then -1
else if x.constrOrd > y.constrOrd then 1
else (x, y) match
case (FormElement.FieldSet(_, elems1), FormElement.FieldSet(_, elems2)) =>
elems1.size - elems2.size
case (FormElement.Dropdown(_, opts1), FormElement.Dropdown(_, opts2)) =>
opts1.size - opts2.size
case _ => 0
}
}
25 changes: 23 additions & 2 deletions testcases/src/main/scala/main.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,28 @@ def printsWeirdGADT(g: WeirdGADT[String]): String = g match
case SomeValue(value) => s"SomeValue($value)"
case SomeOtherValue(value, value2) => s"SomeOtherValue($value, $value2)"

// This loops forever
def concatAll(elems: List[String]): String =
elems.mkString

enum IntTree:
case Leaf
case Node(left: IntTree, value: Int, right: IntTree)

def isInTree(elem: Int, tree: IntTree): Boolean = tree match
case IntTree.Leaf => false
case IntTree.Node(left, value, right) =>
value == elem || isInTree(elem, left) || isInTree(elem, right)

// Can't be handled right now
extension (elem: Int)
def isInTreeExt(tree: IntTree): Boolean = tree match
case IntTree.Leaf => false
case IntTree.Node(left, value, right) =>
value == elem || elem.isInTreeExt(left) || elem.isInTreeExt(right)

def addManyParamLists(a: Int)(b: Int): Int =
a + b

@main
def run: Unit =
guinep.web(
Expand All @@ -86,6 +104,9 @@ def run: Unit =
nameWithPossiblePrefix1,
roll20,
roll6(),
concatAll,
isInTree,
// isInTreeExt
// addManyParamLists
// printsWeirdGADT
// concatAll
)
Loading

0 comments on commit 14c8caf

Please sign in to comment.