Skip to content

Commit

Permalink
Shorten transitive hidden sets and replace cycles by aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
odersky committed Feb 7, 2025
1 parent 91c819a commit 502cb7d
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 17 deletions.
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/cc/CaptureRef.scala
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,8 @@ trait CaptureRef extends TypeProxy, ValueType:
* fail a comparison.
*/
def maxSubsumes(y: CaptureRef, canAddHidden: Boolean)(using ctx: Context, vs: VarState = VarState.Separate): Boolean =
this.match
(this eq y)
|| this.match
case Fresh.Cap(hidden) =>
vs.ifNotSeen(this)(hidden.elems.exists(_.subsumes(y)))
|| !y.stripReadOnly.isCap && canAddHidden && vs.addHidden(hidden, y)
Expand Down
84 changes: 77 additions & 7 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ sealed abstract class CaptureSet extends Showable:
*/
protected def addThisElem(elem: CaptureRef)(using Context, VarState): CompareResult

protected def addHiddenElem(elem: CaptureRef)(using ctx: Context, vs: VarState): CompareResult =
protected def addIfHiddenOrFail(elem: CaptureRef)(using ctx: Context, vs: VarState): CompareResult =
if elems.exists(_.maxSubsumes(elem, canAddHidden = true))
then CompareResult.OK
else CompareResult.Fail(this :: Nil)
Expand Down Expand Up @@ -438,7 +438,7 @@ object CaptureSet:
def isAlwaysEmpty = elems.isEmpty

def addThisElem(elem: CaptureRef)(using Context, VarState): CompareResult =
addHiddenElem(elem)
addIfHiddenOrFail(elem)

def addDependent(cs: CaptureSet)(using Context, VarState) = CompareResult.OK

Expand Down Expand Up @@ -487,7 +487,10 @@ object CaptureSet:
private var isSolved: Boolean = false

/** The elements currently known to be in the set */
var elems: Refs = initialElems
protected var myElems: Refs = initialElems

def elems: Refs = myElems
def elems_=(refs: Refs): Unit = myElems = refs

/** The sets currently known to be dependent sets (i.e. new additions to this set
* are propagated to these dependent sets.)
Expand Down Expand Up @@ -535,7 +538,7 @@ object CaptureSet:

final def addThisElem(elem: CaptureRef)(using Context, VarState): CompareResult =
if isConst || !recordElemsState() then // Fail if variable is solved or given VarState is frozen
addHiddenElem(elem)
addIfHiddenOrFail(elem)
else if Existential.isBadExistential(elem) then // Fail if `elem` is an out-of-scope existential
CompareResult.Fail(this :: Nil)
else if !levelOK(elem) then
Expand Down Expand Up @@ -925,10 +928,75 @@ object CaptureSet:
def elemIntersection(cs1: CaptureSet, cs2: CaptureSet)(using Context): Refs =
cs1.elems.filter(cs2.mightAccountFor) ++ cs2.elems.filter(cs1.mightAccountFor)

/** A capture set variable used to record the references hidden by a Fresh.Cap instance */
/** A capture set variable used to record the references hidden by a Fresh.Cap instance,
* The elems and deps members are repurposed as follows:
* elems: Set of hidden references
* deps : Set of hidden sets for which the Fresh.Cap instance owning this set
* is a hidden element.
* Hidden sets may become aliases of other hidden sets, which means that
* reads and writes of elems go to the alias.
* If H is an alias of R.hidden for some Fresh.Cap R then:
* H.elems == {R}
* H.deps = {R.hidden}
* This encoding was chosen because it relies only on the elems and deps fields
* which are already subject through snapshotting and rollbacks in VarState.
* It's advantageous if we don't need to deal with other pieces of state there.
*/
class HiddenSet(initialHidden: Refs = emptyRefs)(using @constructorOnly ictx: Context)
extends Var(initialElems = initialHidden):

private def aliasRef: AnnotatedType | Null =
if myElems.size == 1 then
myElems.nth(0) match
case al @ Fresh.Cap(hidden) if deps.contains(hidden) => al
case _ => null
else null

private def aliasSet: HiddenSet =
if myElems.size == 1 then
myElems.nth(0) match
case al @ Fresh.Cap(hidden) if deps.contains(hidden) => hidden
case _ => this
else this

override def elems: Refs =
val al = aliasSet
if al eq this then super.elems else al.elems

override def elems_=(refs: Refs) =
val al = aliasSet
if al eq this then super.elems_=(refs) else al.elems_=(refs)

/** Add element to hidden set. Also add it to all supersets (as indicated by
* deps of this set). Follow aliases on both hidden set and added element
* before adding. If the added element is also a Fresh.Cap instance with
* hidden set H which is a superset of this set, then make this set an
* alias of H.
*/
def add(elem: CaptureRef)(using ctx: Context, vs: VarState): Unit =
val alias = aliasSet
if alias ne this then alias.add(elem)
else
def addToElems() =
elems += elem
deps.foreach: dep =>
assert(dep != this)
vs.addHidden(dep.asInstanceOf[HiddenSet], elem)
elem match
case Fresh.Cap(hidden) =>
if this ne hidden then
if hidden.aliasRef != null then
add(hidden.aliasRef.nn)
else if deps.contains(hidden) then // make this an alias of elem
println(i"ALIAS $this to $hidden")
elems = SimpleIdentitySet(elem)
deps = SimpleIdentitySet(hidden)
else
addToElems()
hidden.deps += this
case _ =>
addToElems()

/** Apply function `f` to `elems` while setting `elems` to empty for the
* duration. This is used to escape infinite recursions if two Fresh.Caps
* refer to each other in their hidden sets.
Expand Down Expand Up @@ -1075,9 +1143,11 @@ object CaptureSet:
*/
def addHidden(hidden: HiddenSet, elem: CaptureRef)(using Context): Boolean =
elemsMap.get(hidden) match
case None => elemsMap(hidden) = hidden.elems
case None =>
elemsMap(hidden) = hidden.elems
depsMap(hidden) = hidden.deps
case _ =>
hidden.elems += elem
hidden.add(elem)(using ctx, this)
true

/** Roll back global state to what was recorded in this VarState */
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/cc/Fresh.scala
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ object Fresh:
def apply(owner: Symbol)(using Context): CaptureRef =
apply(ownerToHidden(owner, reach = false))

def unapply(tp: AnnotatedType)(using Context): Option[CaptureSet.HiddenSet] = tp.annot match
def unapply(tp: AnnotatedType): Option[CaptureSet.HiddenSet] = tp.annot match
case Annot(hidden) => Some(hidden)
case _ => None
end Cap
Expand Down
22 changes: 14 additions & 8 deletions compiler/src/dotty/tools/dotc/util/SimpleIdentitySet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ abstract class SimpleIdentitySet[+Elem <: AnyRef] {
acc
def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A
def toList: List[Elem]
def iterator: Iterator[Elem]
def nth(n: Int): Elem

final def isEmpty: Boolean = size == 0

final def iterator: Iterator[Elem] = Iterator.tabulate(size)(nth)

def forall[E >: Elem <: AnyRef](p: E => Boolean): Boolean = !exists(!p(_))

def filter(p: Elem => Boolean): SimpleIdentitySet[Elem] =
Expand Down Expand Up @@ -74,7 +76,7 @@ object SimpleIdentitySet {
override def map[B <: AnyRef](f: Nothing => B): SimpleIdentitySet[B] = empty
def /: [A, E <: AnyRef](z: A)(f: (A, E) => A): A = z
def toList = Nil
def iterator = Iterator.empty
def nth(n: Int): Nothing = throw new IndexOutOfBoundsException(n)
}

private class Set1[+Elem <: AnyRef](x0: AnyRef) extends SimpleIdentitySet[Elem] {
Expand All @@ -92,7 +94,9 @@ object SimpleIdentitySet {
def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A =
f(z, x0.asInstanceOf[E])
def toList = x0.asInstanceOf[Elem] :: Nil
def iterator = Iterator.single(x0.asInstanceOf[Elem])
def nth(n: Int) =
if n == 0 then x0.asInstanceOf[Elem]
else throw new IndexOutOfBoundsException(n)
}

private class Set2[+Elem <: AnyRef](x0: AnyRef, x1: AnyRef) extends SimpleIdentitySet[Elem] {
Expand All @@ -114,10 +118,10 @@ object SimpleIdentitySet {
def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A =
f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E])
def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: Nil
def iterator = Iterator.tabulate(2) {
def nth(n: Int) = n match
case 0 => x0.asInstanceOf[Elem]
case 1 => x1.asInstanceOf[Elem]
}
case _ => throw new IndexOutOfBoundsException(n)
}

private class Set3[+Elem <: AnyRef](x0: AnyRef, x1: AnyRef, x2: AnyRef) extends SimpleIdentitySet[Elem] {
Expand Down Expand Up @@ -154,11 +158,11 @@ object SimpleIdentitySet {
def /: [A, E >: Elem <: AnyRef](z: A)(f: (A, E) => A): A =
f(f(f(z, x0.asInstanceOf[E]), x1.asInstanceOf[E]), x2.asInstanceOf[E])
def toList = x0.asInstanceOf[Elem] :: x1.asInstanceOf[Elem] :: x2.asInstanceOf[Elem] :: Nil
def iterator = Iterator.tabulate(3) {
def nth(n: Int) = n match
case 0 => x0.asInstanceOf[Elem]
case 1 => x1.asInstanceOf[Elem]
case 2 => x2.asInstanceOf[Elem]
}
case _ => throw new IndexOutOfBoundsException(n)
}

private class SetN[+Elem <: AnyRef](val xs: Array[AnyRef]) extends SimpleIdentitySet[Elem] {
Expand Down Expand Up @@ -205,7 +209,9 @@ object SimpleIdentitySet {
foreach(buf += _)
buf.toList
}
def iterator = xs.iterator.asInstanceOf[Iterator[Elem]]
def nth(n: Int) =
if 0 <= n && n < size then xs(n).asInstanceOf[Elem]
else throw new IndexOutOfBoundsException
override def ++ [E >: Elem <: AnyRef](that: SimpleIdentitySet[E]): SimpleIdentitySet[E] =
that match {
case that: SetN[?] =>
Expand Down

0 comments on commit 502cb7d

Please sign in to comment.