@@ -14,11 +14,9 @@ type StratVar
14
14
type StratVarId = Uid [StratVar ]
15
15
type ClsOrModSymbol = ClassLikeSymbol
16
16
17
- sealed abstract class Strat
17
+ sealed abstract class ProdStrat
18
18
19
- sealed abstract class ProdStrat extends Strat
20
-
21
- sealed abstract class ConsStrat extends Strat
19
+ sealed abstract class ConsStrat
22
20
23
21
class StratVarState (val uid : StratVarId , val name : Str = " " ):
24
22
lazy val asProdStrat = ProdVar (this )
@@ -325,8 +323,6 @@ class Deforest(using TL, Raise, Elaborator.State):
325
323
else
326
324
S (p) -> s " 0 fusion opportunity " -> 0
327
325
328
- // these are never considered as free vars (because of their symbol type)
329
- // consider TopLevelSym, BlockMemberSymbols and BuiltInSyms as globally defined...
330
326
object globallyDefinedVars :
331
327
val store = mutable.Set .from[Symbol ](State .globalThisSymbol :: State .runtimeSymbol :: Nil )
332
328
@@ -391,7 +387,7 @@ class Deforest(using TL, Raise, Elaborator.State):
391
387
def constrain (p : ProdStrat , c : ConsStrat ) = constraints ::= p -> c
392
388
393
389
def processBlock (b : Block )(using
394
- inArm : LinkedHashMap [ProdVar , ClsOrModSymbol ] = LinkedHashMap .empty[ProdVar , ClsOrModSymbol ],
390
+ inArm : Map [ProdVar , ClsOrModSymbol ] = Map .empty[ProdVar , ClsOrModSymbol ],
395
391
matching : LinkedHashMap [ResultId , ClsOrModSymbol ] = LinkedHashMap .empty[ResultId , ClsOrModSymbol ]
396
392
): ProdStrat = b match
397
393
case m@ Match (scrut, arms, dflt, rest) =>
@@ -443,7 +439,7 @@ class Deforest(using TL, Raise, Elaborator.State):
443
439
freshVar(" throw" )._1
444
440
445
441
def constrFun (params : Ls [Param ], body : Block )(using
446
- inArm : LinkedHashMap [ProdVar , ClsOrModSymbol ],
442
+ inArm : Map [ProdVar , ClsOrModSymbol ],
447
443
matching : LinkedHashMap [ResultId , ClsOrModSymbol ]
448
444
) =
449
445
val paramSyms = params.map:
@@ -455,7 +451,7 @@ class Deforest(using TL, Raise, Elaborator.State):
455
451
ProdFun (paramStrats.map(s => s.asConsStrat), res._1)
456
452
457
453
def processResult (r : Result )(using
458
- inArm : LinkedHashMap [ProdVar , ClsOrModSymbol ],
454
+ inArm : Map [ProdVar , ClsOrModSymbol ],
459
455
matching : LinkedHashMap [ResultId , ClsOrModSymbol ]
460
456
): ProdStrat =
461
457
def handleCallLike (f : Path , args : Ls [Path ], c : Result ) =
@@ -636,7 +632,6 @@ class Deforest(using TL, Raise, Elaborator.State):
636
632
// ======== after resolving constraints ======
637
633
638
634
lazy val resolveClashes =
639
-
640
635
val ctorToDtor = ctorDests.ctorDests
641
636
val dtorToCtor = dtorSources.dtorSources
642
637
@@ -647,84 +642,72 @@ class Deforest(using TL, Raise, Elaborator.State):
647
642
val toDeleteDtors = rm.flatMap(r => ctorToDtor.remove(r)).flatMap:
648
643
case CtorDest (mat, sels, _) => mat.keySet.map(s => DtorExpr .Match (s)) ++ sels.map(s => DtorExpr .Sel (s.expr))
649
644
removeDtor(toDeleteDtors)
650
- // val (newCtorDests, toDelete) = ctorDests.partition(c => !rm(c._1))
651
- // removeDtor(newCtorDests, dtorSources, toDelete.values.flatMap[DtorExpr]{ case CtorDest(mat, sels, _) =>
652
- // mat.keySet.map(s => DtorExpr.Match(s)) ++ sels.map(s => DtorExpr.Sel(s.expr))
653
- // }.toSet)
654
645
655
646
def removeDtor (rm : Set [DtorExpr ]): Unit =
656
647
if rm.isEmpty then ()
657
648
else
658
649
tl.log(" rm dtor: " + rm.mkString(" | " ))
659
- // val (newDtorSources, toDelete) = dtorSources.partition(d => !rm(d._1))
660
650
val toDeleteCtors = rm.flatMap(r => dtorToCtor.remove(r)).flatMap(_.ctors)
661
651
removeCtor(toDeleteCtors)
662
652
663
- val removeClashes =
664
- removeCtor(
665
- ctorToDtor.filterNot { case _ -> CtorDest (dtors, sels, noCons) =>
666
- ((dtors.size == 0 && sels.size == 1 )
667
- || (dtors.size == 1 && {
668
- val scrutRef @ Value .Ref (scrut) = dtors.head._1.getResult
669
- sels.forall { s => s.expr.getResult match
670
- case Select (Value .Ref (l), nme) => (l === scrut) && s.inMatching.contains(scrutRef.uid) // need to be in the matching arms, and checking the scrutinee
671
- case _ => false }
672
- }))
673
- && ! noCons
674
- }.keySet.toSet
675
- )
676
- removeDtor(dtorToCtor.filter(_._2.noProd).keySet.toSet)
677
-
653
+ // remove clashes:
654
+ removeCtor(
655
+ ctorToDtor.filterNot { case _ -> CtorDest (dtors, sels, noCons) =>
656
+ ((dtors.size == 0 && sels.size == 1 )
657
+ || (dtors.size == 1 && {
658
+ val scrutRef @ Value .Ref (scrut) = dtors.head._1.getResult
659
+ sels.forall { s => s.expr.getResult match
660
+ case Select (Value .Ref (l), nme) => (l === scrut) && s.inMatching.contains(scrutRef.uid) // need to be in the matching arms, and checking the scrutinee
661
+ case _ => false }
662
+ }))
663
+ && ! noCons
664
+ }.keySet.toSet
665
+ )
666
+ removeDtor(dtorToCtor.filter(_._2.noProd).keySet.toSet)
678
667
679
- val removeCycle = {
680
- def getCtorInArm (ctor : ResultId , dtor : Match ): Set [ResultId ] =
681
- val ctorSym = getClsSymOfUid(ctor)
682
- val arm = dtor.arms.find{ case (Case .Cls (c1, _) -> body) => c1 === ctorSym }.map(_._2).orElse(dtor.dflt).get
683
-
684
- object GetCtorsTraverser extends BlockTraverser :
685
- val ctors = mutable.Set .empty[ResultId ]
686
- override def applyResult (r : Result ): Unit =
687
- r.uid.handleCtorIds{ (id, f, clsOrMod, args) =>
688
- ctors += id
689
- args.foreach { case Arg (_, value) => applyResult(value) }
690
- } match
691
- case Some (_) => ()
692
- case None => r match
693
- case Call (_, args) =>
694
- args.foreach { case Arg (_, value) => applyResult(value) }
695
- case Instantiate (cls, args) =>
696
- args.foreach(applyResult)
697
- case _ => ()
698
-
699
- GetCtorsTraverser .applyBlock(arm)
700
- GetCtorsTraverser .ctors.toSet
701
-
702
- def findCycle (ctor : ResultId , dtor : Match ): Set [ResultId ] =
703
- val cache = mutable.Set (ctor)
704
- def go (ctorAndMatches : Set [ResultId -> Match ]): Set [ResultId ] =
705
- val newCtorsAndNewMatches =
706
- ctorAndMatches.flatMap((c, m) => getCtorInArm(c, m)).flatMap: c =>
707
- ctorToDtor.get(c).flatMap:
708
- case CtorDest (matches, sels, false ) => matches.values.headOption.map(m => c -> m)
709
- val cycled = newCtorsAndNewMatches.filter(c => ! cache.add(c._1))
710
- if newCtorsAndNewMatches.isEmpty then
711
- Set .empty
712
- else if cycled.nonEmpty then
713
- cycled.map(_._1)
714
- else
715
- go(newCtorsAndNewMatches)
716
- go(Set (ctor -> dtor))
717
-
718
- val toRmCtor = ctorToDtor.flatMap:
719
- case (c, CtorDest (matches, sels, false )) =>
720
- assert(matches.size <= 1 )
721
- matches.values.flatMap(m => findCycle(c, m))
668
+ // remove cycle:
669
+ def getCtorInArm (ctor : ResultId , dtor : Match ): Set [ResultId ] =
670
+ val ctorSym = getClsSymOfUid(ctor)
671
+ val arm = dtor.arms.find{ case (Case .Cls (c1, _) -> body) => c1 === ctorSym }.map(_._2).orElse(dtor.dflt).get
722
672
723
- removeCtor(toRmCtor.toSet)
724
- }
673
+ object GetCtorsTraverser extends BlockTraverser :
674
+ val ctors = mutable.Set .empty[ResultId ]
675
+ override def applyResult (r : Result ): Unit =
676
+ r.uid.handleCtorIds{ (id, f, clsOrMod, args) =>
677
+ ctors += id
678
+ args.foreach { case Arg (_, value) => applyResult(value) }
679
+ } match
680
+ case Some (_) => ()
681
+ case None => r match
682
+ case Call (_, args) =>
683
+ args.foreach { case Arg (_, value) => applyResult(value) }
684
+ case Instantiate (cls, args) =>
685
+ args.foreach(applyResult)
686
+ case _ => ()
687
+
688
+ GetCtorsTraverser .applyBlock(arm)
689
+ GetCtorsTraverser .ctors.toSet
690
+ def findCycle (ctor : ResultId , dtor : Match ): Set [ResultId ] =
691
+ val cache = mutable.Set (ctor)
692
+ def go (ctorAndMatches : Set [ResultId -> Match ]): Set [ResultId ] =
693
+ val newCtorsAndNewMatches =
694
+ ctorAndMatches.flatMap((c, m) => getCtorInArm(c, m)).flatMap: c =>
695
+ ctorToDtor.get(c).flatMap:
696
+ case CtorDest (matches, sels, false ) => matches.values.headOption.map(m => c -> m)
697
+ val cycled = newCtorsAndNewMatches.filter(c => ! cache.add(c._1))
698
+ if newCtorsAndNewMatches.isEmpty then
699
+ Set .empty
700
+ else if cycled.nonEmpty then
701
+ cycled.map(_._1)
702
+ else
703
+ go(newCtorsAndNewMatches)
704
+ go(Set (ctor -> dtor))
705
+ val toRmCtor = ctorToDtor.flatMap:
706
+ case (c, CtorDest (matches, sels, false )) =>
707
+ assert(matches.size <= 1 )
708
+ matches.values.flatMap(m => findCycle(c, m))
709
+ removeCtor(toRmCtor.toSet)
725
710
726
- val finalRes = removeCycle
727
- // finalRes
728
711
ctorToDtor -> dtorToCtor
729
712
730
713
0 commit comments