Skip to content

Commit e92a50f

Browse files
committed
cleanup
1 parent f5da463 commit e92a50f

File tree

1 file changed

+60
-77
lines changed

1 file changed

+60
-77
lines changed

hkmc2/shared/src/main/scala/hkmc2/codegen/Deforestation.scala

Lines changed: 60 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,9 @@ type StratVar
1414
type StratVarId = Uid[StratVar]
1515
type ClsOrModSymbol = ClassLikeSymbol
1616

17-
sealed abstract class Strat
17+
sealed abstract class ProdStrat
1818

19-
sealed abstract class ProdStrat extends Strat
20-
21-
sealed abstract class ConsStrat extends Strat
19+
sealed abstract class ConsStrat
2220

2321
class StratVarState(val uid: StratVarId, val name: Str = ""):
2422
lazy val asProdStrat = ProdVar(this)
@@ -325,8 +323,6 @@ class Deforest(using TL, Raise, Elaborator.State):
325323
else
326324
S(p) -> s"0 fusion opportunity" -> 0
327325

328-
// these are never considered as free vars (because of their symbol type)
329-
// consider TopLevelSym, BlockMemberSymbols and BuiltInSyms as globally defined...
330326
object globallyDefinedVars:
331327
val store = mutable.Set.from[Symbol](State.globalThisSymbol ::State.runtimeSymbol :: Nil)
332328

@@ -391,7 +387,7 @@ class Deforest(using TL, Raise, Elaborator.State):
391387
def constrain(p: ProdStrat, c: ConsStrat) = constraints ::= p -> c
392388

393389
def processBlock(b: Block)(using
394-
inArm: LinkedHashMap[ProdVar, ClsOrModSymbol] = LinkedHashMap.empty[ProdVar, ClsOrModSymbol],
390+
inArm: Map[ProdVar, ClsOrModSymbol] = Map.empty[ProdVar, ClsOrModSymbol],
395391
matching: LinkedHashMap[ResultId, ClsOrModSymbol] = LinkedHashMap.empty[ResultId, ClsOrModSymbol]
396392
): ProdStrat = b match
397393
case m@Match(scrut, arms, dflt, rest) =>
@@ -443,7 +439,7 @@ class Deforest(using TL, Raise, Elaborator.State):
443439
freshVar("throw")._1
444440

445441
def constrFun(params: Ls[Param], body: Block)(using
446-
inArm: LinkedHashMap[ProdVar, ClsOrModSymbol],
442+
inArm: Map[ProdVar, ClsOrModSymbol],
447443
matching: LinkedHashMap[ResultId, ClsOrModSymbol]
448444
) =
449445
val paramSyms = params.map:
@@ -455,7 +451,7 @@ class Deforest(using TL, Raise, Elaborator.State):
455451
ProdFun(paramStrats.map(s => s.asConsStrat), res._1)
456452

457453
def processResult(r: Result)(using
458-
inArm: LinkedHashMap[ProdVar, ClsOrModSymbol],
454+
inArm: Map[ProdVar, ClsOrModSymbol],
459455
matching: LinkedHashMap[ResultId, ClsOrModSymbol]
460456
): ProdStrat =
461457
def handleCallLike(f: Path, args: Ls[Path], c: Result) =
@@ -636,7 +632,6 @@ class Deforest(using TL, Raise, Elaborator.State):
636632
// ======== after resolving constraints ======
637633

638634
lazy val resolveClashes =
639-
640635
val ctorToDtor = ctorDests.ctorDests
641636
val dtorToCtor = dtorSources.dtorSources
642637

@@ -647,84 +642,72 @@ class Deforest(using TL, Raise, Elaborator.State):
647642
val toDeleteDtors = rm.flatMap(r => ctorToDtor.remove(r)).flatMap:
648643
case CtorDest(mat, sels, _) => mat.keySet.map(s => DtorExpr.Match(s)) ++ sels.map(s => DtorExpr.Sel(s.expr))
649644
removeDtor(toDeleteDtors)
650-
// val (newCtorDests, toDelete) = ctorDests.partition(c => !rm(c._1))
651-
// removeDtor(newCtorDests, dtorSources, toDelete.values.flatMap[DtorExpr]{ case CtorDest(mat, sels, _) =>
652-
// mat.keySet.map(s => DtorExpr.Match(s)) ++ sels.map(s => DtorExpr.Sel(s.expr))
653-
// }.toSet)
654645

655646
def removeDtor(rm: Set[DtorExpr]): Unit =
656647
if rm.isEmpty then ()
657648
else
658649
tl.log("rm dtor: " + rm.mkString(" | "))
659-
// val (newDtorSources, toDelete) = dtorSources.partition(d => !rm(d._1))
660650
val toDeleteCtors = rm.flatMap(r => dtorToCtor.remove(r)).flatMap(_.ctors)
661651
removeCtor(toDeleteCtors)
662652

663-
val removeClashes =
664-
removeCtor(
665-
ctorToDtor.filterNot { case _ -> CtorDest(dtors, sels, noCons) =>
666-
((dtors.size == 0 && sels.size == 1)
667-
|| (dtors.size == 1 && {
668-
val scrutRef@Value.Ref(scrut) = dtors.head._1.getResult
669-
sels.forall { s => s.expr.getResult match
670-
case Select(Value.Ref(l), nme) => (l === scrut) && s.inMatching.contains(scrutRef.uid) // need to be in the matching arms, and checking the scrutinee
671-
case _ => false }
672-
}))
673-
&& !noCons
674-
}.keySet.toSet
675-
)
676-
removeDtor(dtorToCtor.filter(_._2.noProd).keySet.toSet)
677-
653+
// remove clashes:
654+
removeCtor(
655+
ctorToDtor.filterNot { case _ -> CtorDest(dtors, sels, noCons) =>
656+
((dtors.size == 0 && sels.size == 1)
657+
|| (dtors.size == 1 && {
658+
val scrutRef@Value.Ref(scrut) = dtors.head._1.getResult
659+
sels.forall { s => s.expr.getResult match
660+
case Select(Value.Ref(l), nme) => (l === scrut) && s.inMatching.contains(scrutRef.uid) // need to be in the matching arms, and checking the scrutinee
661+
case _ => false }
662+
}))
663+
&& !noCons
664+
}.keySet.toSet
665+
)
666+
removeDtor(dtorToCtor.filter(_._2.noProd).keySet.toSet)
678667

679-
val removeCycle = {
680-
def getCtorInArm(ctor: ResultId, dtor: Match): Set[ResultId] =
681-
val ctorSym = getClsSymOfUid(ctor)
682-
val arm = dtor.arms.find{ case (Case.Cls(c1, _) -> body) => c1 === ctorSym }.map(_._2).orElse(dtor.dflt).get
683-
684-
object GetCtorsTraverser extends BlockTraverser:
685-
val ctors = mutable.Set.empty[ResultId]
686-
override def applyResult(r: Result): Unit =
687-
r.uid.handleCtorIds{ (id, f, clsOrMod, args) =>
688-
ctors += id
689-
args.foreach { case Arg(_, value) => applyResult(value) }
690-
} match
691-
case Some(_) => ()
692-
case None => r match
693-
case Call(_, args) =>
694-
args.foreach { case Arg(_, value) => applyResult(value) }
695-
case Instantiate(cls, args) =>
696-
args.foreach(applyResult)
697-
case _ => ()
698-
699-
GetCtorsTraverser.applyBlock(arm)
700-
GetCtorsTraverser.ctors.toSet
701-
702-
def findCycle(ctor: ResultId, dtor: Match): Set[ResultId] =
703-
val cache = mutable.Set(ctor)
704-
def go(ctorAndMatches: Set[ResultId -> Match]): Set[ResultId] =
705-
val newCtorsAndNewMatches =
706-
ctorAndMatches.flatMap((c, m) => getCtorInArm(c, m)).flatMap: c =>
707-
ctorToDtor.get(c).flatMap:
708-
case CtorDest(matches, sels, false) => matches.values.headOption.map(m => c -> m)
709-
val cycled = newCtorsAndNewMatches.filter(c => !cache.add(c._1))
710-
if newCtorsAndNewMatches.isEmpty then
711-
Set.empty
712-
else if cycled.nonEmpty then
713-
cycled.map(_._1)
714-
else
715-
go(newCtorsAndNewMatches)
716-
go(Set(ctor -> dtor))
717-
718-
val toRmCtor = ctorToDtor.flatMap:
719-
case (c, CtorDest(matches, sels, false)) =>
720-
assert(matches.size <= 1)
721-
matches.values.flatMap(m => findCycle(c, m))
668+
// remove cycle:
669+
def getCtorInArm(ctor: ResultId, dtor: Match): Set[ResultId] =
670+
val ctorSym = getClsSymOfUid(ctor)
671+
val arm = dtor.arms.find{ case (Case.Cls(c1, _) -> body) => c1 === ctorSym }.map(_._2).orElse(dtor.dflt).get
722672

723-
removeCtor(toRmCtor.toSet)
724-
}
673+
object GetCtorsTraverser extends BlockTraverser:
674+
val ctors = mutable.Set.empty[ResultId]
675+
override def applyResult(r: Result): Unit =
676+
r.uid.handleCtorIds{ (id, f, clsOrMod, args) =>
677+
ctors += id
678+
args.foreach { case Arg(_, value) => applyResult(value) }
679+
} match
680+
case Some(_) => ()
681+
case None => r match
682+
case Call(_, args) =>
683+
args.foreach { case Arg(_, value) => applyResult(value) }
684+
case Instantiate(cls, args) =>
685+
args.foreach(applyResult)
686+
case _ => ()
687+
688+
GetCtorsTraverser.applyBlock(arm)
689+
GetCtorsTraverser.ctors.toSet
690+
def findCycle(ctor: ResultId, dtor: Match): Set[ResultId] =
691+
val cache = mutable.Set(ctor)
692+
def go(ctorAndMatches: Set[ResultId -> Match]): Set[ResultId] =
693+
val newCtorsAndNewMatches =
694+
ctorAndMatches.flatMap((c, m) => getCtorInArm(c, m)).flatMap: c =>
695+
ctorToDtor.get(c).flatMap:
696+
case CtorDest(matches, sels, false) => matches.values.headOption.map(m => c -> m)
697+
val cycled = newCtorsAndNewMatches.filter(c => !cache.add(c._1))
698+
if newCtorsAndNewMatches.isEmpty then
699+
Set.empty
700+
else if cycled.nonEmpty then
701+
cycled.map(_._1)
702+
else
703+
go(newCtorsAndNewMatches)
704+
go(Set(ctor -> dtor))
705+
val toRmCtor = ctorToDtor.flatMap:
706+
case (c, CtorDest(matches, sels, false)) =>
707+
assert(matches.size <= 1)
708+
matches.values.flatMap(m => findCycle(c, m))
709+
removeCtor(toRmCtor.toSet)
725710

726-
val finalRes = removeCycle
727-
// finalRes
728711
ctorToDtor -> dtorToCtor
729712

730713

0 commit comments

Comments
 (0)