@@ -535,7 +535,8 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
535
535
_assignedTo = value
536
536
}
537
537
538
- var tsc : Opt [(TupleSetConstraints , Int )] = N
538
+ var lbtsc : Opt [(TupleSetConstraints , Int )] = N
539
+ var ubtsc : Opt [(TupleSetConstraints , Int )] = N
539
540
540
541
// * Bounds should always be disregarded when `equatedTo` is defined, as they are then irrelevant:
541
542
def lowerBounds : List [SimpleType ] = { require(assignedTo.isEmpty, this ); _lowerBounds }
@@ -654,7 +655,7 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
654
655
def go (ub : ST ): Unit = ub match {
655
656
case ub : TV =>
656
657
ub.upperBounds.foreach(go)
657
- ub.tsc = S (this , index)
658
+ ub.lbtsc = S (this , index)
658
659
case _ =>
659
660
constraints.filterInPlace { constrs =>
660
661
val ty = constrs(index)
@@ -667,65 +668,89 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
667
668
if (constraints.sizeCompare(1 ) === 0 ) {
668
669
constraints.head.zip(tvs).foreach {
669
670
case (ty, tv) =>
670
- tv.tsc = N
671
+ tv.lbtsc = N
672
+ tv.ubtsc = N
673
+ constrain(tv, ty)(raise, prov, ctx)
674
+ constrain(ty, tv)(raise, prov, ctx)
675
+ }
676
+ }
677
+ }
678
+ def filterLB (index : Int , lb : ST )(implicit raise : Raise , ctx : Ctx ): Unit = {
679
+ constraints.filterInPlace { constrs =>
680
+ val ty = constrs(index)
681
+ val dnf = DNF .mk(MaxLevel , Nil , lb & ty.neg(), true )
682
+ dnf.isBot || dnf.cs.forall(c => ! (c.vars.isEmpty && c.nvars.isEmpty))
683
+ }
684
+ println(s " TSC filterLB: $tvs in $constraints" )
685
+ if (constraints.sizeCompare(1 ) === 0 ) {
686
+ constraints.head.zip(tvs).foreach {
687
+ case (ty, tv) =>
688
+ tv.lbtsc = N
689
+ tv.ubtsc = N
671
690
constrain(tv, ty)(raise, prov, ctx)
672
691
constrain(ty, tv)(raise, prov, ctx)
673
692
}
674
693
}
675
694
}
676
695
}
677
696
object TupleSetConstraints {
678
- def lcgField (a : FieldType , b : FieldType )
697
+ def lcgField (first : FieldType , rest : Ls [ FieldType ] )
679
698
(implicit prov : TypeProvenance , lvl : Level )
680
699
: (FieldType , Ls [TV ], Ls [Ls [ST ]]) = {
681
- val (ub, tvs, constrs) = lcg(a .ub, b.ub )
682
- if (a .lb.isEmpty && b. lb.isEmpty) {
700
+ val (ub, tvs, constrs) = lcg(first .ub, rest.map(_.ub) )
701
+ if (first .lb.isEmpty && rest.forall(_. lb.isEmpty) ) {
683
702
(FieldType (N , ub)(prov), tvs, constrs)
684
703
} else {
685
- val (lb, ltvs, lconstrs) = lcg(a .lb.getOrElse(BotType ), b. lb.getOrElse(BotType ))
704
+ val (lb, ltvs, lconstrs) = lcg(first .lb.getOrElse(BotType ), rest.map(_. lb.getOrElse(BotType ) ))
686
705
(FieldType (S (lb), ub)(prov), tvs ++ ltvs, constrs ++ lconstrs)
687
706
}
688
707
}
689
- def lcg (a : ST , b : ST )
708
+ def lcg (first : ST , rest : Ls [ ST ] )
690
709
(implicit prov : TypeProvenance , lvl : Level )
691
- : (ST , Ls [TV ], Ls [Ls [ST ]]) = (a, b) match {
692
- case (_, b : ProvType ) => lcg(a, b.underlying)
693
- case (a : ProvType , _) => lcg(a.underlying, b)
694
- case (a : FT , b : FT ) => lcgFunction(a, b)
695
- case (a : ArrayType , b : ArrayType ) =>
696
- val (t, tvs, constrs) = lcgField(a.inner, b.inner)
710
+ : (ST , Ls [TV ], Ls [Ls [ST ]]) = first match {
711
+ case a : FunctionType if rest.forall(_.isInstanceOf [FunctionType ]) =>
712
+ val (lhss, rhss) = rest.collect {
713
+ case FunctionType (lhs, rhs) => lhs -> rhs
714
+ }.unzip
715
+ val (lhs, ltvs, lconstrs) = lcg(a.lhs, lhss)
716
+ val (rhs, rtvs, rconstrs) = lcg(a.rhs, rhss)
717
+ (FunctionType (lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs)
718
+ case a : ArrayType if rest.forall(_.isInstanceOf [ArrayType ]) =>
719
+ val inners = rest.collect { case b : ArrayType => b.inner }
720
+ val (t, tvs, constrs) = lcgField(a.inner, inners)
697
721
(ArrayType (t)(prov), tvs, constrs)
698
- case (a : TupleType , b : TupleType ) if a.fields.sizeCompare(b.fields.size) === 0 =>
699
- val (fts, tvss, constrss) = a.fields.map(_._2).zip(b.fields.map(_._2)).map {
700
- case (a, b) => lcgField(a, b)
701
- }.unzip3
722
+ case a : TupleType if rest.forall { case b : TupleType => a.fields.sizeCompare(b.fields.size) === 0 ; case _ => false } =>
723
+ val fields = rest.collect { case TupleType (fields) => fields.map(_._2) }
724
+ val (fts, tvss, constrss) = a.fields.map(_._2).zip(fields.transpose).map { case (a, bs) => lcgField(a, bs) }.unzip3
702
725
(TupleType (fts.map(N -> _))(prov), tvss.flatten, constrss.flatten)
703
- case (a : TR , b : TR ) if a.defn === b.defn && a.targs.sizeCompare(b.targs.size) === 0 =>
704
- val (ts, tvss, constrss) = a.targs.zip(b.targs).map {
705
- case (a, b) => lcg(a, b)
706
- }.unzip3
726
+ case a : TR if rest.forall { case b : TR => a.defn === b.defn && a.targs.sizeCompare(b.targs.size) === 0 ; case _ => false } =>
727
+ val targs = rest.collect { case b : TR => b.targs }
728
+ val (ts, tvss, constrss) = a.targs.zip(targs.transpose).map { case (a, bs) => lcg(a, bs) }.unzip3
707
729
(TypeRef (a.defn, ts)(prov), tvss.flatten, constrss.flatten)
708
- case ( a : TV , b : TV ) if a.compare(b) === 0 => (a, Nil , Nil )
709
- case ( a : ExtrType , b : ExtrType ) if a.pol === b.pol => (a, Nil , Nil )
730
+ case a : TV if rest.forall { case b : TV => a.compare(b) === 0 ; case _ => false } => (a, Nil , Nil )
731
+ case a if rest.forall(_ === a) => (a, Nil , Nil )
710
732
case _ =>
711
733
val tv = freshVar(prov, N )
712
- (tv, List (tv), List (List (a, b) ))
734
+ (tv, List (tv), List (first :: rest ))
713
735
}
714
- def lcgFunction (a : FT , b : FT )
715
- (implicit prov : TypeProvenance , lvl : Level )
716
- : (FT , Ls [TV ], Ls [Ls [ST ]]) = {
717
- val (lhs, ltvs, lconstrs) = lcg(a.lhs, b.lhs)
718
- val (rhs, rtvs, rconstrs) = lcg(a.rhs, b.rhs)
736
+ def lcgFunction (first : FunctionType , rest : Ls [FunctionType ])(implicit prov : TypeProvenance , lvl : Level )
737
+ : (FunctionType , Ls [TV ], Ls [Ls [ST ]]) = {
738
+ val (lhss, rhss) = rest.map {
739
+ case FunctionType (lhs, rhs) => lhs -> rhs
740
+ }.unzip
741
+ val (lhs, ltvs, lconstrs) = lcg(first.lhs, lhss)
742
+ val (rhs, rtvs, rconstrs) = lcg(first.rhs, rhss)
719
743
(FunctionType (lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs)
720
744
}
721
745
def mk (ov : Overload )(implicit lvl : Level ): FunctionType = {
722
- val (t, tvs, constrs) =
723
- ov.alts.tail.foldLeft((ov.alts.head, Nil : Ls [TV ], Nil : Ls [Ls [ST ]])) {
724
- case ((a, tvs, constrs), b) => lcgFunction(a, b)(ov.prov, lvl)
725
- }
726
- // val (t, tvs, constrs) = lcgFunction(ov.alts.head, ov.alts.tail)(ov.prov, lvl)
746
+ def unwrap (t : ST ): ST = t.map(unwrap)
747
+ val f = ov.mapAlts(unwrap)(unwrap)
748
+ val (t, tvs, constrs) = lcgFunction(f.alts.head, f.alts.tail)(ov.prov, lvl)
727
749
val tsc = new TupleSetConstraints (MutSet .empty ++ constrs.transpose, tvs)(ov.prov)
728
- tvs.zipWithIndex.foreach { case (tv, i) => tv.tsc = S ((tsc, i)) }
750
+ tvs.zipWithIndex.foreach { case (tv, i) =>
751
+ tv.lbtsc = S ((tsc, i))
752
+ tv.ubtsc = S ((tsc, i))
753
+ }
729
754
println(s " TSC mk: ${tsc.tvs} in ${tsc.constraints}" )
730
755
t
731
756
}
0 commit comments