Skip to content

Commit

Permalink
Merge pull request #72 from TinkoffCreditSystems/classy-optics
Browse files Browse the repository at this point in the history
Classy Optics
  • Loading branch information
Odomontois authored Dec 18, 2019
2 parents 41d6e02 + a4704d9 commit ec00b41
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 47 deletions.
2 changes: 1 addition & 1 deletion optics/core/src/main/scala/tofu/optics/Folded.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import tofu.optics.data.Constant

/** S has some or none occurrences of A
* and can collect them */
trait PFolded[-S, +T, +A, -B] {
trait PFolded[-S, +T, +A, -B] extends PBase[S, T, A, B]{
def foldMap[X: Monoid](s: S)(f: A => X): X

def getAll(s: S): List[A] = foldMap(s)(List(_))
Expand Down
5 changes: 5 additions & 0 deletions optics/core/src/main/scala/tofu/optics/PBase.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package tofu.optics

trait PBase[-S, +T, +A, -B] {
def label[label]: this.type with Label[label] = this.asInstanceOf[this.type with Label[label]]
}
4 changes: 2 additions & 2 deletions optics/core/src/main/scala/tofu/optics/Upcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import cats.{Functor, Id}
import tofu.optics.classes.PChoice
import tofu.optics.data.{Identity, Tagged}

trait PUpcast[-S, +T, +A, -B] {
trait PUpcast[-S, +T, +A, -B] extends PBase[S, T, A, B] {
def upcast(b: B): T
}

Expand All @@ -18,7 +18,7 @@ object PUpcast extends OpticCompanion[PUpcast] {

class Context extends PSubset.Context {
override type P[-x, +y] = Tagged[x, y]
type F[+x] = x
type F[+x] = x
def pure = Pure[Id]
def profunctor = PChoice[Tagged]
def functor = Functor[Identity]
Expand Down
2 changes: 1 addition & 1 deletion optics/core/src/main/scala/tofu/optics/Update.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import tofu.optics.data.Identity

/** aka Setter
* can update all occurrences of A in S */
trait PUpdate[-S, +T, +A, -B] {
trait PUpdate[-S, +T, +A, -B] extends PBase[S, T, A, B] {
def update(s: S, fb: A => B): T

def put(s: S, b: B) = update(s, _ => b)
Expand Down
2 changes: 2 additions & 0 deletions optics/core/src/main/scala/tofu/optics/optics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,5 @@ object Optic {
}




6 changes: 6 additions & 0 deletions optics/core/src/main/scala/tofu/optics/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,10 @@ package object optics {
type Extract[A, B] = PExtract[A, A, B, B]
type Folded[A, B] = PFolded[A, A, B, B]
type Update[A, B] = PUpdate[A, A, B, B]

/** label provider for instance discrimination
* like Contains[A, B] with Label["first"] */
type Label[label] = Any {
type Label = label
}
}
138 changes: 96 additions & 42 deletions optics/macro/src/main/scala/tofu/optics/macros/Optics.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package tofu.optics.macros

import tofu.optics.PContains

import scala.reflect.macros.blackbox

class Optics(val prefix: String = "") extends scala.annotation.StaticAnnotation {
Expand All @@ -10,84 +12,136 @@ class POptics(val prefix: String = "") extends scala.annotation.StaticAnnotation
def macroTransform(annottees: Any*): Any = macro OpticsImpl.popticsAnnotationMacro
}

class ClassyOptics(val prefix: String = "") extends scala.annotation.StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro OpticsImpl.classyOpticsAnnotationMacro
}

class ClassyPOptics(val prefix: String = "") extends scala.annotation.StaticAnnotation {
def macroTransform(annottees: Any*): Any = macro OpticsImpl.classyPopticsAnnotationMacro
}

private[macros] class OpticsImpl(val c: blackbox.Context) {

def opticsAnnotationMacro(annottees: c.Expr[Any]*): c.Expr[Any] = annotationMacro(annottees, poly = false)
def opticsAnnotationMacro(annottees: c.Expr[Any]*): c.Expr[Any] =
annotationMacro(annottees, poly = false, classy = false)
def classyOpticsAnnotationMacro(annottees: c.Expr[Any]*): c.Expr[Any] =
annotationMacro(annottees, poly = false, classy = true)

def popticsAnnotationMacro(annottees: c.Expr[Any]*): c.Expr[Any] = annotationMacro(annottees, poly = true)
def popticsAnnotationMacro(annottees: c.Expr[Any]*): c.Expr[Any] =
annotationMacro(annottees, poly = true, classy = false)
def classyPopticsAnnotationMacro(annottees: c.Expr[Any]*): c.Expr[Any] =
annotationMacro(annottees, poly = true, classy = true)

def annotationMacro(annottees: Seq[c.Expr[Any]], poly: Boolean): c.Expr[Any] = {
def annotationMacro(annottees: Seq[c.Expr[Any]], poly: Boolean, classy: Boolean): c.Expr[Any] = {
import c.universe._

val LensesTpe = TypeName(if (poly) "POptics" else "Optics")
lazy val PContainsC = typeOf[PContains[_, _, _, _]].typeConstructor.typeSymbol

val LensesTpe = TypeName((poly, classy) match {
case (false, false) => "Optics"
case (false, true) => "ClassyOptics"
case (true, false) => "POptics"
case (true, true) => "ClassyPOptics"
})

val prefix = c.macroApplication match {
case Apply(Select(Apply(Select(New(Ident(LensesTpe)), t), args), _), _) if t == termNames.CONSTRUCTOR => args match {
case Literal(Constant(s: String)) :: Nil => s
case _ => ""
}
case Apply(Select(Apply(Select(New(Ident(LensesTpe)), t), args), _), _) if t == termNames.CONSTRUCTOR =>
args match {
case Literal(Constant(s: String)) :: Nil => s
case _ => ""
}
case _ => ""
}

def monolenses(tpname: TypeName, params: List[ValDef]): List[Tree] = params.map { param =>
def labelType(s: String): Type = internal.constantType(Constant(s))

def labelClass(p: ValDef, res: Tree)(s: Tree, t: Tree, a: Tree, b: Tree): (Tree, Tree) = {
val fieldName = p.name.toString
def labelT = labelType(fieldName)
def classyT: Tree = tq"$PContainsC[$s, $t, $a, $b] with _root_.tofu.optics.Label[$labelT]"
(q"$res.label[$labelT]", classyT)
}

def monolenses(tpname: TypeName, params: List[ValDef], classy: Boolean): List[Tree] = params.map { param =>
val lensName = TermName(prefix + param.name.decodedName)
q"""val $lensName =
_root_.tofu.optics.macros.internal.Macro.mkContains[$tpname, $tpname, ${param.tpt}, ${param.tpt}](${param.name.toString})"""

val res =
q"""_root_.tofu.optics.macros.internal.Macro.mkContains[$tpname, $tpname, ${param.tpt}, ${param.tpt}](${param.name.toString})"""

lazy val (resClassy, classyT) = labelClass(param, res)(tq"$tpname", tq"$tpname", param.tpt, param.tpt)

if (classy) q"implicit val $lensName : $classyT = $resClassy"
else q"val $lensName = $res"
}

def optics(tpname: TypeName, tparams: List[TypeDef], params: List[ValDef]): List[Tree] = {
def optics(tpname: TypeName, tparams: List[TypeDef], params: List[ValDef], classy: Boolean): List[Tree] = {
if (tparams.isEmpty) {
monolenses(tpname, params)
monolenses(tpname, params, classy)
} else {
params.map { param =>
val lensName = TermName(prefix + param.name.decodedName)
val q"x: $s" = q"x: $tpname[..${tparams.map(_.name)}]"
val q"x: $a" = q"x: ${param.tpt}"
q"""def $lensName[..$tparams] =
_root_.tofu.optics.macros.internal.Macro.mkContains[$s, $s, $a, $a](${param.name.toString})"""

val res = q"_root_.tofu.optics.macros.internal.Macro.mkContains[$s, $s, $a, $a](${param.name.toString})"

lazy val (resClassy, classyT) = labelClass(param, res)(s, s, a, a)
if (classy) q"implicit def $lensName[..$tparams] : $classyT = $resClassy"
q"""def $lensName[..$tparams] = $res"""
}
}
}

def poptics(tpname: TypeName, tparams: List[TypeDef], params: List[ValDef]): List[Tree] = {
def poptics(tpname: TypeName, tparams: List[TypeDef], params: List[ValDef], classy: Boolean): List[Tree] = {
if (tparams.isEmpty) {
monolenses(tpname, params)
monolenses(tpname, params, classy)
} else {
// number of fields in which each tparam is used
val tparamsUsages: Map[TypeName, Int] = params.foldLeft(tparams.map { _.name -> 0 }.toMap){ (acc, param) =>
val typeNames = param.collect{ case Ident(tn: TypeName) => tn }.toSet
typeNames.foldLeft(acc){ (map, key) => map.get(key).fold(map){ value => map.updated(key, value + 1) }}
val tparamsUsages: Map[TypeName, Int] = params.foldLeft(tparams.map { _.name -> 0 }.toMap) { (acc, param) =>
val typeNames = param.collect { case Ident(tn: TypeName) => tn }.toSet
typeNames.foldLeft(acc) { (map, key) =>
map.get(key).fold(map) { value =>
map.updated(key, value + 1)
}
}
}

val groupedTpnames: Map[Int, Set[TypeName]] =
tparamsUsages.toList.groupBy(_._2).map{ case (n, tps) => (n, tps.map(_._1).toSet) }
val phantomTpnames = groupedTpnames.getOrElse(0, Set.empty)
tparamsUsages.toList.groupBy(_._2).map { case (n, tps) => (n, tps.map(_._1).toSet) }
val phantomTpnames = groupedTpnames.getOrElse(0, Set.empty)
val singleFieldTpnames = groupedTpnames.getOrElse(1, Set.empty)

params.map { param =>
val lensName = TermName(prefix + param.name.decodedName)
val tpnames = param.collect{ case Ident(tn: TypeName) => tn }.toSet
val lensName = TermName(prefix + param.name.decodedName)
val tpnames = param.collect { case Ident(tn: TypeName) => tn }.toSet
val tpnamesToChange = tpnames.intersect(singleFieldTpnames) ++ phantomTpnames
val tpnamesMap = tpnamesToChange.foldLeft((tparams.map(_.name).toSet ++ tpnames).map(x => (x, x)).toMap){ (acc, tpname) =>
acc.updated(tpname, c.freshName(tpname))
val tpnamesMap = tpnamesToChange.foldLeft((tparams.map(_.name).toSet ++ tpnames).map(x => (x, x)).toMap) {
(acc, tpname) =>
acc.updated(tpname, c.freshName(tpname))
}
val defParams = tparams ++ tparams.filter(x => tpnamesToChange.contains(x.name)).map{
case TypeDef(mods, name, tps, rhs) => TypeDef(mods, tpnamesMap(name), tps, rhs)
}.toSet
val defParams = tparams ++ tparams
.filter(x => tpnamesToChange.contains(x.name))
.map {
case TypeDef(mods, name, tps, rhs) => TypeDef(mods, tpnamesMap(name), tps, rhs)
}
.toSet

object tptTransformer extends Transformer {
override def transform(tree: Tree): Tree = tree match {
case Ident(tn: TypeName) => Ident(tpnamesMap(tn))
case x => super.transform(x)
case x => super.transform(x)
}
}

val q"x: $s" = q"x: $tpname[..${tparams.map(_.name)}]"
val q"x: $t" = q"x: $tpname[..${tparams.map(x => tpnamesMap(x.name))}]"
val q"x: $a" = q"x: ${param.tpt}"
val q"x: $b" = q"x: ${tptTransformer.transform(param.tpt)}"
val q"x: $s" = q"x: $tpname[..${tparams.map(_.name)}]"
val q"x: $t" = q"x: $tpname[..${tparams.map(x => tpnamesMap(x.name))}]"
val q"x: $a" = q"x: ${param.tpt}"
val q"x: $b" = q"x: ${tptTransformer.transform(param.tpt)}"
val res = q"_root_.tofu.optics.macros.internal.Macro.mkContains[$s, $t, $a, $b](${param.name.toString})"
lazy val (resClassy, classyT) = labelClass(param, res)(s, t, a, b)

q"""def $lensName[..$defParams] =
_root_.tofu.optics.macros.internal.Macro.mkContains[$s, $t, $a, $b](${param.name.toString})"""
if (classy) q"implicit def $lensName[..$defParams] : $classyT = $resClassy"
else q"def $lensName[..$defParams] = $res"
}
}
}
Expand All @@ -96,21 +150,21 @@ private[macros] class OpticsImpl(val c: blackbox.Context) {

val result = annottees map (_.tree) match {
case (classDef @ q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }")
:: Nil if mods.hasFlag(Flag.CASE) =>
:: Nil if mods.hasFlag(Flag.CASE) =>
val name = tpname.toTermName
q"""
$classDef
object $name {
..${lensDefs(tpname, tparams, paramss.head)}
..${lensDefs(tpname, tparams, paramss.head, classy)}
}
"""
case (classDef @ q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }")
:: q"$objMods object $objName extends { ..$objEarlyDefs } with ..$objParents { $objSelf => ..$objDefs }"
:: Nil if mods.hasFlag(Flag.CASE) =>
:: q"$objMods object $objName extends { ..$objEarlyDefs } with ..$objParents { $objSelf => ..$objDefs }"
:: Nil if mods.hasFlag(Flag.CASE) =>
q"""
$classDef
$objMods object $objName extends { ..$objEarlyDefs} with ..$objParents { $objSelf =>
..${lensDefs(tpname, tparams, paramss.head)}
..${lensDefs(tpname, tparams, paramss.head, classy)}
..$objDefs
}
"""
Expand All @@ -119,4 +173,4 @@ private[macros] class OpticsImpl(val c: blackbox.Context) {

c.Expr[Any](result)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package tofu.optics.macros
import tofu.optics.{Contains, Label, PContains, Update}

object TestClassyContains {
// compile test for searching classy optics
implicitly[Contains[FooBar2, Int] with Label["i"]]
implicitly[Contains[FooBar2, String] with Label["j"]]
implicitly[Contains[FooBar4[Double], Int] with Label["i"]]
implicitly[Contains[FooBar4[Double], String] with Label["j"]]
implicitly[Contains[FooBar4[Double], String]]
implicitly[Update[FooBar4[Double], String]]
implicitly[PContains[FooBar4[Double], FooBar4[Long], Double, Long] with Label["x"]]
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package tofu.optics.macros

import org.scalatest.{FunSuite, Matchers}
import tofu.optics.{Contains, Label, PContains}

class GenContainsSpec extends FunSuite with Matchers {

test("Nested GenContains") {
val sut = GenContains[Foo](_.b.i)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,8 @@ case class Bar(i: Int)

case class Foo(b: Bar)

@Optics("contains_") case class FooBar(i: Int, j: String)
@Optics("contains_") case class FooBar(i: Int, j: String)
@ClassyOptics("contains_") case class FooBar2(i: Int, j: String)

@POptics("contains_") case class FooBar3[X](i: Int, j: String, x: X)
@ClassyPOptics("contains_") case class FooBar4[X](i: Int, j: String, x: X)

0 comments on commit ec00b41

Please sign in to comment.