Skip to content

Commit d40c9b7

Browse files
committed
fix lcg and add constraints filtering by lowerbound
1 parent 3ad402f commit d40c9b7

File tree

5 files changed

+130
-60
lines changed

5 files changed

+130
-60
lines changed

shared/src/main/scala/mlscript/ConstraintSolver.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,14 +820,23 @@ class ConstraintSolver extends NormalForms { self: Typer =>
820820
val newBound = (cctx._1 ::: cctx._2.reverse).foldRight(rhs)((c, ty) =>
821821
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
822822
lhs.upperBounds ::= newBound // update the bound
823-
lhs.tsc.foreach { case (tsc, i) => tsc.filterUB(i, rhs) }
823+
lhs.lbtsc.foreach {
824+
case (tsc, i) =>
825+
tsc.filterUB(i, rhs)
826+
if (tsc.constraints.isEmpty) reportError()
827+
}
824828
lhs.lowerBounds.foreach(rec(_, rhs, true)) // propagate from the bound
825829

826830
case (lhs, rhs: TypeVariable) if lhs.level <= rhs.level =>
827831
println(s"NEW $rhs LB (${lhs.level})")
828832
val newBound = (cctx._1 ::: cctx._2.reverse).foldLeft(lhs)((ty, c) =>
829833
if (c.prov is noProv) ty else mkProxy(ty, c.prov))
830834
rhs.lowerBounds ::= newBound // update the bound
835+
rhs.ubtsc.foreach {
836+
case (tsc, i) =>
837+
tsc.filterLB(i, lhs)
838+
if (tsc.constraints.isEmpty) reportError()
839+
}
831840
rhs.upperBounds.foreach(rec(lhs, _, true)) // propagate from the bound
832841

833842

shared/src/main/scala/mlscript/TypeSimplifier.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ trait TypeSimplifier { self: Typer =>
7676
.reduceOption(_ &- _).filterNot(_.isTop).toList
7777
else Nil
7878
}
79-
8079
nv
8180

8281
case ComposedType(true, l, r) =>

shared/src/main/scala/mlscript/TyperDatatypes.scala

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -535,7 +535,8 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
535535
_assignedTo = value
536536
}
537537

538-
var tsc: Opt[(TupleSetConstraints, Int)] = N
538+
var lbtsc: Opt[(TupleSetConstraints, Int)] = N
539+
var ubtsc: Opt[(TupleSetConstraints, Int)] = N
539540

540541
// * Bounds should always be disregarded when `equatedTo` is defined, as they are then irrelevant:
541542
def lowerBounds: List[SimpleType] = { require(assignedTo.isEmpty, this); _lowerBounds }
@@ -654,7 +655,7 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
654655
def go(ub: ST): Unit = ub match {
655656
case ub: TV =>
656657
ub.upperBounds.foreach(go)
657-
ub.tsc = S(this, index)
658+
ub.lbtsc = S(this, index)
658659
case _ =>
659660
constraints.filterInPlace { constrs =>
660661
val ty = constrs(index)
@@ -667,65 +668,89 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
667668
if (constraints.sizeCompare(1) === 0) {
668669
constraints.head.zip(tvs).foreach {
669670
case (ty, tv) =>
670-
tv.tsc = N
671+
tv.lbtsc = N
672+
tv.ubtsc = N
673+
constrain(tv, ty)(raise, prov, ctx)
674+
constrain(ty, tv)(raise, prov, ctx)
675+
}
676+
}
677+
}
678+
def filterLB(index: Int, lb: ST)(implicit raise: Raise, ctx: Ctx): Unit = {
679+
constraints.filterInPlace { constrs =>
680+
val ty = constrs(index)
681+
val dnf = DNF.mk(MaxLevel, Nil, lb & ty.neg(), true)
682+
dnf.isBot || dnf.cs.forall(c => !(c.vars.isEmpty && c.nvars.isEmpty))
683+
}
684+
println(s"TSC filterLB: $tvs in $constraints")
685+
if (constraints.sizeCompare(1) === 0) {
686+
constraints.head.zip(tvs).foreach {
687+
case (ty, tv) =>
688+
tv.lbtsc = N
689+
tv.ubtsc = N
671690
constrain(tv, ty)(raise, prov, ctx)
672691
constrain(ty, tv)(raise, prov, ctx)
673692
}
674693
}
675694
}
676695
}
677696
object TupleSetConstraints {
678-
def lcgField(a: FieldType, b: FieldType)
697+
def lcgField(first: FieldType, rest: Ls[FieldType])
679698
(implicit prov: TypeProvenance, lvl: Level)
680699
: (FieldType, Ls[TV], Ls[Ls[ST]]) = {
681-
val (ub, tvs, constrs) = lcg(a.ub, b.ub)
682-
if (a.lb.isEmpty && b.lb.isEmpty) {
700+
val (ub, tvs, constrs) = lcg(first.ub, rest.map(_.ub))
701+
if (first.lb.isEmpty && rest.forall(_.lb.isEmpty)) {
683702
(FieldType(N, ub)(prov), tvs, constrs)
684703
} else {
685-
val (lb, ltvs, lconstrs) = lcg(a.lb.getOrElse(BotType), b.lb.getOrElse(BotType))
704+
val (lb, ltvs, lconstrs) = lcg(first.lb.getOrElse(BotType), rest.map(_.lb.getOrElse(BotType)))
686705
(FieldType(S(lb), ub)(prov), tvs ++ ltvs, constrs ++ lconstrs)
687706
}
688707
}
689-
def lcg(a: ST, b: ST)
708+
def lcg(first: ST, rest: Ls[ST])
690709
(implicit prov: TypeProvenance, lvl: Level)
691-
: (ST, Ls[TV], Ls[Ls[ST]]) = (a, b) match {
692-
case (_, b: ProvType) => lcg(a, b.underlying)
693-
case (a: ProvType, _) => lcg(a.underlying, b)
694-
case (a: FT, b: FT) => lcgFunction(a, b)
695-
case (a: ArrayType, b: ArrayType) =>
696-
val (t, tvs, constrs) = lcgField(a.inner, b.inner)
710+
: (ST, Ls[TV], Ls[Ls[ST]]) = first match {
711+
case a: FunctionType if rest.forall(_.isInstanceOf[FunctionType]) =>
712+
val (lhss, rhss) = rest.collect {
713+
case FunctionType(lhs, rhs) => lhs -> rhs
714+
}.unzip
715+
val (lhs, ltvs, lconstrs) = lcg(a.lhs, lhss)
716+
val (rhs, rtvs, rconstrs) = lcg(a.rhs, rhss)
717+
(FunctionType(lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs)
718+
case a: ArrayType if rest.forall(_.isInstanceOf[ArrayType]) =>
719+
val inners = rest.collect { case b: ArrayType => b.inner }
720+
val (t, tvs, constrs) = lcgField(a.inner, inners)
697721
(ArrayType(t)(prov), tvs, constrs)
698-
case (a: TupleType, b: TupleType) if a.fields.sizeCompare(b.fields.size) === 0 =>
699-
val (fts, tvss, constrss) = a.fields.map(_._2).zip(b.fields.map(_._2)).map {
700-
case (a, b) => lcgField(a, b)
701-
}.unzip3
722+
case a: TupleType if rest.forall { case b: TupleType => a.fields.sizeCompare(b.fields.size) === 0; case _ => false } =>
723+
val fields = rest.collect { case TupleType(fields) => fields.map(_._2) }
724+
val (fts, tvss, constrss) = a.fields.map(_._2).zip(fields.transpose).map { case (a, bs) => lcgField(a, bs) }.unzip3
702725
(TupleType(fts.map(N -> _))(prov), tvss.flatten, constrss.flatten)
703-
case (a: TR, b: TR) if a.defn === b.defn && a.targs.sizeCompare(b.targs.size) === 0 =>
704-
val (ts, tvss, constrss) = a.targs.zip(b.targs).map {
705-
case (a, b) => lcg(a, b)
706-
}.unzip3
726+
case a: TR if rest.forall { case b: TR => a.defn === b.defn && a.targs.sizeCompare(b.targs.size) === 0; case _ => false } =>
727+
val targs = rest.collect { case b: TR => b.targs }
728+
val (ts, tvss, constrss) = a.targs.zip(targs.transpose).map { case (a, bs) => lcg(a, bs) }.unzip3
707729
(TypeRef(a.defn, ts)(prov), tvss.flatten, constrss.flatten)
708-
case (a: TV, b: TV) if a.compare(b) === 0 => (a, Nil, Nil)
709-
case (a: ExtrType, b: ExtrType) if a.pol === b.pol => (a, Nil, Nil)
730+
case a: TV if rest.forall { case b: TV => a.compare(b) === 0; case _ => false } => (a, Nil, Nil)
731+
case a if rest.forall(_ === a) => (a, Nil, Nil)
710732
case _ =>
711733
val tv = freshVar(prov, N)
712-
(tv, List(tv), List(List(a, b)))
734+
(tv, List(tv), List(first :: rest))
713735
}
714-
def lcgFunction(a: FT, b: FT)
715-
(implicit prov: TypeProvenance, lvl: Level)
716-
: (FT, Ls[TV], Ls[Ls[ST]]) = {
717-
val (lhs, ltvs, lconstrs) = lcg(a.lhs, b.lhs)
718-
val (rhs, rtvs, rconstrs) = lcg(a.rhs, b.rhs)
736+
def lcgFunction(first: FunctionType, rest: Ls[FunctionType])(implicit prov: TypeProvenance, lvl: Level)
737+
: (FunctionType, Ls[TV], Ls[Ls[ST]]) = {
738+
val (lhss, rhss) = rest.map {
739+
case FunctionType(lhs, rhs) => lhs -> rhs
740+
}.unzip
741+
val (lhs, ltvs, lconstrs) = lcg(first.lhs, lhss)
742+
val (rhs, rtvs, rconstrs) = lcg(first.rhs, rhss)
719743
(FunctionType(lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs)
720744
}
721745
def mk(ov: Overload)(implicit lvl: Level): FunctionType = {
722-
val (t, tvs, constrs) =
723-
ov.alts.tail.foldLeft((ov.alts.head, Nil: Ls[TV], Nil: Ls[Ls[ST]])) {
724-
case ((a, tvs, constrs), b) => lcgFunction(a, b)(ov.prov, lvl)
725-
}
726-
// val (t, tvs, constrs) = lcgFunction(ov.alts.head, ov.alts.tail)(ov.prov, lvl)
746+
def unwrap(t: ST): ST = t.map(unwrap)
747+
val f = ov.mapAlts(unwrap)(unwrap)
748+
val (t, tvs, constrs) = lcgFunction(f.alts.head, f.alts.tail)(ov.prov, lvl)
727749
val tsc = new TupleSetConstraints(MutSet.empty ++ constrs.transpose, tvs)(ov.prov)
728-
tvs.zipWithIndex.foreach { case (tv, i) => tv.tsc = S((tsc, i)) }
750+
tvs.zipWithIndex.foreach { case (tv, i) =>
751+
tv.lbtsc = S((tsc, i))
752+
tv.ubtsc = S((tsc, i))
753+
}
729754
println(s"TSC mk: ${tsc.tvs} in ${tsc.constraints}")
730755
t
731756
}

shared/src/main/scala/mlscript/TyperHelpers.scala

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -983,13 +983,24 @@ abstract class TyperHelpers { Typer: Typer =>
983983
def getVars: SortedSet[TypeVariable] = getVarsImpl(includeBounds = true)
984984

985985
def showBounds: String =
986-
getVars.iterator.filter(tv => tv.assignedTo.nonEmpty || (tv.upperBounds ++ tv.lowerBounds).nonEmpty).map {
986+
getVars.iterator.filter(tv => tv.assignedTo.nonEmpty || (tv.upperBounds ++ tv.lowerBounds).nonEmpty || (tv.lbtsc.fold(false)(!_._1.tvs.contains(tv)))).map {
987987
case tv @ AssignedVariable(ty) => "\n\t\t" + tv.toString + " := " + ty
988988
case tv => ("\n\t\t" + tv.toString
989989
+ (if (tv.lowerBounds.isEmpty) "" else " :> " + tv.lowerBounds.mkString(" | "))
990-
+ (if (tv.upperBounds.isEmpty) "" else " <: " + tv.upperBounds.mkString(" & ")))
991-
}.mkString
992-
990+
+ (if (tv.upperBounds.isEmpty) "" else " <: " + tv.upperBounds.mkString(" & "))
991+
+ tv.lbtsc.fold(""){ case (tsc, i) => " :> " + tsc.tvs(i) } )
992+
}.mkString + {
993+
val visited: MutSet[TV] = MutSet.empty
994+
getVars.iterator.filter(tv => tv.lbtsc.fold(false)(_._1.tvs.contains(tv))).map {
995+
case tv if visited.contains(tv) => ""
996+
case tv =>
997+
visited ++= tv.lbtsc.fold(Nil: Ls[TV])(_._1.tvs)
998+
tv.lbtsc.fold("") { case (tsc, _) => ("\n\t\t[ "
999+
+ tsc.tvs.mkString(", ")
1000+
+ " ] in { " + tsc.constraints.mkString(", ") + " }")
1001+
}
1002+
}.mkString
1003+
}
9931004
}
9941005

9951006

shared/src/test/diff/nu/HeungTung.mls

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,16 @@ fun g = h
6666
//│ fun g: (Bool | Int) -> (Int | false | true)
6767

6868
// * In one step
69+
:e // TODO: argument of union type
6970
fun g: (Int | Bool) -> (Int | Bool)
7071
fun g = f
72+
//│ ╔══[ERROR] Type mismatch in definition:
73+
//│ ║ l.71: fun g = f
74+
//│ ║ ^^^^^
75+
//│ ╟── expression of type `Int | false | true` does not match type `?a`
76+
//│ ╟── Note: constraint arises from function type:
77+
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
78+
//│ ╙── ^^^^^^^^^^^^^^
7179
//│ fun g: Int -> Int & Bool -> Bool
7280
//│ fun g: (Bool | Int) -> (Int | false | true)
7381

@@ -88,9 +96,11 @@ fun j = i
8896
fun j: (Int & Bool) -> (Int & Bool)
8997
fun j = f
9098
//│ ╔══[ERROR] Type mismatch in definition:
91-
//│ ║ l.89: fun j = f
99+
//│ ║ l.97: fun j = f
92100
//│ ║ ^^^^^
93-
//│ ╙── expression of type `Int` does not match type `nothing`
101+
//│ ╟── type `?a` does not match type `nothing`
102+
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
103+
//│ ╙── ^^^^^^^^^^^^^^
94104
//│ fun j: Int -> Int & Bool -> Bool
95105
//│ fun j: nothing -> nothing
96106

@@ -106,23 +116,30 @@ fun g = f
106116
// * With match-type-based constraint solving, we could return Int here
107117

108118
f(0)
109-
//│ Int | false | true
119+
//│ Int
110120
//│ res
111121
//│ = 0
112122

113123
// f(0) : case 0 of { Int => Int; Bool => Bool } == Int
114124

115125

116126
x => f(x)
117-
//│ (Bool | Int) -> (Int | false | true)
127+
//│ anything -> nothing
118128
//│ res
119129
//│ = [Function: res]
120130

121131
// : forall 'a: 'a -> case 'a of { Int => Int; Bool => Bool } where 'a <: Int | Bool
122132

123-
133+
:e
124134
f(if true then 0 else false)
125-
//│ Int | false | true
135+
//│ ╔══[ERROR] Type mismatch in application:
136+
//│ ║ l.134: f(if true then 0 else false)
137+
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
138+
//│ ╟── expression of type `0 | false` does not match type `?a`
139+
//│ ╟── Note: constraint arises from function type:
140+
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
141+
//│ ╙── ^^^^^^^^^^^^^^
142+
//│ error
126143
//│ res
127144
//│ = 0
128145

@@ -132,12 +149,21 @@ f(if true then 0 else false)
132149
:w
133150
f(refined if true then 0 else false) // this one can be precise again!
134151
//│ ╔══[WARNING] Paren-less applications should use the 'of' keyword
135-
//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again!
152+
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
136153
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
137154
//│ ╔══[ERROR] identifier not found: refined
138-
//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again!
155+
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
139156
//│ ╙── ^^^^^^^
140-
//│ Int | false | true
157+
//│ ╔══[ERROR] Type mismatch in application:
158+
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
159+
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
160+
//│ ╟── application of type `error` does not match type `?a`
161+
//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again!
162+
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
163+
//│ ╟── Note: constraint arises from function type:
164+
//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool)
165+
//│ ╙── ^^^^^^^^^^^^^^
166+
//│ error
141167
//│ Code generation encountered an error:
142168
//│ unresolved symbol refined
143169

@@ -193,7 +219,7 @@ type T = List[Int]
193219
:e // TODO application types
194220
type Res = M(T)
195221
//│ ╔══[ERROR] Wrong number of type arguments – expected 0, found 1
196-
//│ ║ l.194: type Res = M(T)
222+
//│ ║ l.220: type Res = M(T)
197223
//│ ╙── ^^^^
198224
//│ type Res = M
199225

@@ -216,21 +242,21 @@ fun f: Int -> Int
216242
fun f: Bool -> Bool
217243
fun f = id
218244
//│ ╔══[ERROR] A type signature for 'f' was already given
219-
//│ ║ l.216: fun f: Bool -> Bool
245+
//│ ║ l.242: fun f: Bool -> Bool
220246
//│ ╙── ^^^^^^^^^^^^^^^^^^^
221247
//│ fun f: forall 'a. 'a -> 'a
222248
//│ fun f: Int -> Int
223249

224250
:e // TODO support
225251
f: (Int -> Int) & (Bool -> Bool)
226252
//│ ╔══[ERROR] Type mismatch in type ascription:
227-
//│ ║ l.225: f: (Int -> Int) & (Bool -> Bool)
253+
//│ ║ l.251: f: (Int -> Int) & (Bool -> Bool)
228254
//│ ║ ^
229255
//│ ╟── type `Bool` is not an instance of `Int`
230-
//│ ║ l.225: f: (Int -> Int) & (Bool -> Bool)
256+
//│ ║ l.251: f: (Int -> Int) & (Bool -> Bool)
231257
//│ ║ ^^^^
232258
//│ ╟── Note: constraint arises from type reference:
233-
//│ ║ l.215: fun f: Int -> Int
259+
//│ ║ l.241: fun f: Int -> Int
234260
//│ ╙── ^^^
235261
//│ Int -> Int & Bool -> Bool
236262
//│ res
@@ -297,14 +323,14 @@ fun test(x) = refined if x is
297323
A then 0
298324
B then 1
299325
//│ ╔══[WARNING] Paren-less applications should use the 'of' keyword
300-
//│ ║ l.296: fun test(x) = refined if x is
326+
//│ ║ l.322: fun test(x) = refined if x is
301327
//│ ║ ^^^^^^^^^^^^^^^
302-
//│ ║ l.297: A then 0
328+
//│ ║ l.323: A then 0
303329
//│ ║ ^^^^^^^^^^
304-
//│ ║ l.298: B then 1
330+
//│ ║ l.324: B then 1
305331
//│ ╙── ^^^^^^^^^^
306332
//│ ╔══[ERROR] identifier not found: refined
307-
//│ ║ l.296: fun test(x) = refined if x is
333+
//│ ║ l.322: fun test(x) = refined if x is
308334
//│ ╙── ^^^^^^^
309335
//│ fun test: (A | B) -> error
310336
//│ Code generation encountered an error:

0 commit comments

Comments
 (0)