Skip to content

Commit e1b7fba

Browse files
committed
fix provtype
1 parent 09b5082 commit e1b7fba

File tree

7 files changed

+89
-79
lines changed

7 files changed

+89
-79
lines changed

shared/src/main/scala/mlscript/ConstraintSolver.scala

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -632,8 +632,6 @@ class ConstraintSolver extends NormalForms { self: Typer =>
632632
case S(tsc) => if (!tsc.tvs.isEmpty && tsc.constraints.isEmpty) reportError()
633633
case N => reportError()
634634
}
635-
// val t = TupleSetConstraints.mk(ov)
636-
// annoying(Nil, LhsRefined(S(t), ts, r, trs), Nil, done_rs)
637635
case (LhsRefined(S(ov: Overload), ts, r, trs), _) =>
638636
annoying(Nil, LhsRefined(S(ov.approximatePos), ts, r, trs), Nil, done_rs) // TODO remove approx. with ambiguous constraints
639637
case (LhsRefined(S(Without(b, ns)), ts, r, _), RhsBases(pts, N | S(L(_)), _)) =>
@@ -849,9 +847,8 @@ class ConstraintSolver extends NormalForms { self: Typer =>
849847
}
850848
val u = lhs.tsc.filter(_._1.constraints.sizeCompare(1) === 0)
851849
u.foreachEntry { case (k, _) =>
852-
k.tvs.mapValues(_.unwrapProxies).foreach { // TODO less inefficient; remove useless case
850+
k.tvs.mapValues(_.unwrapProxies).foreach {
853851
case (_,tv: TV) => tv.tsc.remove(k)
854-
case (_,ProvType(tv: TV)) => tv.tsc.remove(k)
855852
case _ => ()
856853
}
857854
}
@@ -877,9 +874,8 @@ class ConstraintSolver extends NormalForms { self: Typer =>
877874
}
878875
val u = rhs.tsc.filter(_._1.constraints.sizeCompare(1) === 0)
879876
u.foreachEntry { case (k, _) =>
880-
k.tvs.mapValues(_.unwrapProxies).foreach { // TODO less inefficient; remove useless case
877+
k.tvs.mapValues(_.unwrapProxies).foreach {
881878
case (_,tv: TV) => tv.tsc.remove(k)
882-
case (_,ProvType(tv: TV)) => tv.tsc.remove(k)
883879
case _ => ()
884880
}
885881
}
@@ -1612,24 +1608,20 @@ class ConstraintSolver extends NormalForms { self: Typer =>
16121608
lvl
16131609
})
16141610
val freshentsc = tv.tsc.flatMap { case (tsc,_) =>
1615-
if (tsc.tvs.forall {
1616-
case (_,tv: TV) => !freshened.contains(tv)
1617-
case (_,ProvType(tv: TV)) => !freshened.contains(tv)
1611+
if (tsc.tvs.map(_._2.unwrapProxies).forall {
1612+
case tv: TV => !freshened.contains(tv)
16181613
case _ => true
16191614
}) S(tsc) else N
16201615
}
16211616
freshened += tv -> v
16221617
v.lowerBounds = tv.lowerBounds.mapConserve(freshen)
16231618
v.upperBounds = tv.upperBounds.mapConserve(freshen)
16241619
freshentsc.foreach { tsc =>
1625-
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)(tsc.prov)
1620+
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
16261621
t.constraints = t.constraints.map(_.map(freshen))
16271622
t.tvs = t.tvs.map(x => (x._1,freshen(x._2)))
1628-
t.tvs.zipWithIndex.foreach {
1629-
case ((pol, tv: TV), i) =>
1630-
tv.tsc.updateWith(t)(_.map(_ + i).orElse(S(Set(i))))
1631-
case ((pol, ProvType(tv: TV)), i) =>
1632-
tv.tsc.updateWith(t)(_.map(_ + i).orElse(S(Set(i))))
1623+
t.tvs.map(_._2.unwrapProxies).zipWithIndex.foreach {
1624+
case (tv: TV, i) => tv.tsc.updateWith(t)(_.map(_ + i).orElse(S(Set(i))))
16331625
case _ => ()
16341626
}
16351627
}

shared/src/main/scala/mlscript/TypeSimplifier.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ trait TypeSimplifier { self: Typer =>
8383
case S(tsc) => (tsc, i)
8484
case N if inPlace => (tsc, i)
8585
case N =>
86-
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)(tsc.prov)
86+
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
8787
renewedtsc += tsc -> t
8888
t.tvs = t.tvs.map(x => (x._1, process(x._2, N)))
8989
(t, i)
@@ -1041,7 +1041,7 @@ trait TypeSimplifier { self: Typer =>
10411041
res.tsc ++= tv.tsc.map { case (tsc, i) => renewaltsc.get(tsc) match {
10421042
case S(tsc) => (tsc, i)
10431043
case N =>
1044-
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)(tsc.prov)
1044+
val t = new TupleSetConstraints(tsc.constraints, tsc.tvs)
10451045
renewaltsc += tsc -> t
10461046
t.tvs = t.tvs.map(x => (x._1, transform(x._2, PolMap.neu, Set.empty)))
10471047
(t, i)

shared/src/main/scala/mlscript/Typer.scala

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -686,10 +686,9 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
686686
}
687687
tscs.foreach { case (typevars, constrs) =>
688688
val tvs = typevars.map(x => (x._1, rec(x._2)))
689-
val tsc = new TupleSetConstraints(constrs.map(_.map(rec)), tvs)(res.prov)
690-
tvs.zipWithIndex.foreach {
691-
case ((_, tv: TV), i) => tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i))))
692-
case ((_, ProvType(tv: TV)), i) => tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i))))
689+
val tsc = new TupleSetConstraints(constrs.map(_.map(rec)), tvs)
690+
tvs.map(_._2.unwrapProxies).zipWithIndex.foreach {
691+
case (tv: TV, i) => tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i))))
693692
case _ => ()
694693
}
695694
}
@@ -1992,11 +1991,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
19921991
tv.tsc.foreachEntry {
19931992
case (tsc, i) =>
19941993
if (seenTscs.add(tsc)) {
1995-
val tvs = tsc.tvs.map {
1996-
case (pol, tv: TV) => (pol, tv.asTypeVar)
1997-
case (pol, ProvType(tv: TV)) => (pol, tv.asTypeVar)
1998-
case (pol, t) => (pol, go(t))
1999-
}
1994+
val tvs = tsc.tvs.map(x => (x._1,go(x._2)))
20001995
val constrs = tsc.constraints.map(_.map(go))
20011996
tscs ::= tvs -> constrs
20021997
}

shared/src/main/scala/mlscript/TyperDatatypes.scala

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -673,7 +673,7 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
673673
val prov = noProv
674674
}
675675

676-
class TupleSetConstraints(var constraints: Ls[Ls[ST]], var tvs: Ls[(Bool, ST)])(val prov: TypeProvenance) {
676+
class TupleSetConstraints(var constraints: Ls[Ls[ST]], var tvs: Ls[(Bool, ST)]) {
677677
def updateImpl(index: Int, bound: ST)(implicit raise: Raise, ctx: Ctx) : Unit = {
678678
val u0 = constraints.flatMap { c =>
679679
TupleSetConstraints.lcg(tvs(index)._1, bound, c(index)).map(tvs.zip(c)++_)
@@ -683,17 +683,15 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
683683
(u,l.reduce((x,y) => ComposedType(!p,x,y)(noProv)))
684684
}
685685
}
686-
tvs.foreach {
687-
case (_, tv: TV) => tv.tsc += this -> Set.empty
688-
case (_, ProvType(tv: TV)) => tv.tsc += this -> Set.empty
686+
tvs.map(_._2.unwrapProxies).foreach {
687+
case tv: TV => tv.tsc += this -> Set.empty
689688
case _ => ()
690689
}
691690
if (!u.isEmpty) {
692691
tvs = u.flatMap(_.keys).distinct
693692
constraints = tvs.map(x => u.map(_.getOrElse(x,if (x._1) TopType else BotType))).transpose
694-
tvs.zipWithIndex.foreach {
695-
case ((pol, tv: TV), i) => tv.tsc.updateWith(this)(_.map(_ + i).orElse(S(Set(i))))
696-
case ((pol, ProvType(tv: TV)), i) => tv.tsc.updateWith(this)(_.map(_ + i).orElse(S(Set(i))))
693+
tvs.map(_._2.unwrapProxies).zipWithIndex.foreach {
694+
case (tv: TV, i) => tv.tsc.updateWith(this)(_.map(_ + i).orElse(S(Set(i))))
697695
case _ => ()
698696
}
699697
} else {
@@ -779,8 +777,8 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
779777
if (u.isEmpty) { return N }
780778
val tvs = u.flatMap(_.keys).distinct
781779
val m = tvs.map(x => u.map(_.getOrElse(x,if (x._1) TopType else BotType)))
782-
val tsc = new TupleSetConstraints(m.transpose, tvs)(ov.prov)
783-
tvs.map(x => (x._1,x._2.unwrapProxies)).zipWithIndex.foreach {
780+
val tsc = new TupleSetConstraints(m.transpose, tvs)
781+
tvs.mapValues(_.unwrapProxies).zipWithIndex.foreach {
784782
case ((true, tv: TV), i) =>
785783
tv.tsc.updateWith(tsc)(_.map(_ + i).orElse(S(Set(i))))
786784
tv.lowerBounds.foreach(tsc.updateImpl(i, _))
@@ -791,9 +789,8 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer =>
791789
}
792790
println(s"TSC mk: ${tsc.tvs} in ${tsc.constraints}")
793791
if (tsc.constraints.sizeCompare(1) === 0) {
794-
tvs.foreach {
795-
case (_, tv: TV) => tv.tsc.remove(tsc)
796-
case (_, ProvType(tv: TV)) => tv.tsc.remove(tsc)
792+
tvs.map(_._2.unwrapProxies).foreach {
793+
case tv: TV => tv.tsc.remove(tsc)
797794
case _ => ()
798795
}
799796
tsc.constraints.head.zip(tvs).foreach {

shared/src/main/scala/mlscript/TyperHelpers.scala

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,7 @@ abstract class TyperHelpers { Typer: Typer =>
917917
}
918918
def children(includeBounds: Bool): List[SimpleType] = this match {
919919
case tv @ AssignedVariable(ty) => if (includeBounds) ty :: Nil else Nil
920-
case tv: TypeVariable => if (includeBounds) tv.lowerBounds ::: tv.upperBounds else Nil
920+
case tv: TypeVariable => if (includeBounds) tv.lowerBounds ::: tv.upperBounds ++ tv.tsc.keys.flatMap(_.tvs.map(_._2)) else Nil
921921
case FunctionType(l, r) => l :: r :: Nil
922922
case Overload(as) => as
923923
case ComposedType(_, l, r) => l :: r :: Nil
@@ -1014,9 +1014,7 @@ abstract class TyperHelpers { Typer: Typer =>
10141014
val couldBeDistribbed = bod.varsBetween(polymLevel, MaxLevel)
10151015
println(s"could be distribbed: $couldBeDistribbed")
10161016
if (couldBeDistribbed.isEmpty) return N
1017-
val cannotBeDistribbed = par.varsBetween(polymLevel, MaxLevel).flatMap { v =>
1018-
v :: v.tsc.keys.flatMap(_.tvs.flatMap(_._2.getVars)).toList
1019-
}
1017+
val cannotBeDistribbed = par.varsBetween(polymLevel, MaxLevel)
10201018
println(s"cannot be distribbed: $cannotBeDistribbed")
10211019
val canBeDistribbed = couldBeDistribbed -- cannotBeDistribbed
10221020
if (canBeDistribbed.isEmpty) return N // TODO

shared/src/test/diff/fcp/Overloads.mls

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -94,23 +94,22 @@ IISS 0
9494
def f = fun x -> (if true then IISS else BBNN) x
9595
//│ f: 'a -> 'b
9696
//│ where
97-
//│ [+'a, -'b] in {[bool, bool], [number, number]}
9897
//│ [+'a, -'b] in {[int, int], [string, string]}
98+
//│ [+'a, -'b] in {[bool, bool], [number, number]}
9999

100100
f(0)
101101
//│ res: number
102102

103-
104103
// FIXME
105104
f(0) + 1
106105
//│ ╔══[ERROR] Type mismatch in operator application:
107-
//│ ║ l.105: f(0) + 1
106+
//│ ║ l.104: f(0) + 1
108107
//│ ║ ^^^^^^
109108
//│ ╟── type `number` is not an instance of type `int`
110109
//│ ║ l.13: def BBNN: bool -> bool & number -> number
111110
//│ ║ ^^^^^^
112111
//│ ╟── but it flows into application with expected type `int`
113-
//│ ║ l.105: f(0) + 1
112+
//│ ║ l.104: f(0) + 1
114113
//│ ╙── ^^^^
115114
//│ res: error | int
116115

@@ -120,10 +119,10 @@ f : int -> number
120119
:e
121120
f : number -> int
122121
//│ ╔══[ERROR] Type mismatch in type ascription:
123-
//│ ║ l.121: f : number -> int
122+
//│ ║ l.120: f : number -> int
124123
//│ ║ ^
125124
//│ ╟── type `number` does not match type `?a`
126-
//│ ║ l.121: f : number -> int
125+
//│ ║ l.120: f : number -> int
127126
//│ ╙── ^^^^^^
128127
//│ res: number -> int
129128

@@ -140,16 +139,16 @@ if true then IISS else BBNN
140139
:e
141140
(if true then IISS else BBNN) : (0 | 1 | true) -> number
142141
//│ ╔══[ERROR] Type mismatch in type ascription:
143-
//│ ║ l.141: (if true then IISS else BBNN) : (0 | 1 | true) -> number
142+
//│ ║ l.140: (if true then IISS else BBNN) : (0 | 1 | true) -> number
144143
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
145144
//│ ╟── type `int -> int & string -> string` is not a function
146145
//│ ║ l.12: def IISS: int -> int & string -> string
147146
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
148147
//│ ╟── but it flows into reference with expected type `(0 | 1 | true) -> number`
149-
//│ ║ l.141: (if true then IISS else BBNN) : (0 | 1 | true) -> number
148+
//│ ║ l.140: (if true then IISS else BBNN) : (0 | 1 | true) -> number
150149
//│ ║ ^^^^
151150
//│ ╟── Note: constraint arises from function type:
152-
//│ ║ l.141: (if true then IISS else BBNN) : (0 | 1 | true) -> number
151+
//│ ║ l.140: (if true then IISS else BBNN) : (0 | 1 | true) -> number
153152
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^
154153
//│ res: (0 | 1 | true) -> number
155154

@@ -168,13 +167,13 @@ not test
168167
//│ <: test:
169168
//│ ~(int -> int)
170169
//│ ╔══[ERROR] Type mismatch in application:
171-
//│ ║ l.166: not test
170+
//│ ║ l.165: not test
172171
//│ ║ ^^^^^^^^
173172
//│ ╟── type `~(int -> int)` is not an instance of type `bool`
174-
//│ ║ l.160: def test: ~(int -> int)
173+
//│ ║ l.159: def test: ~(int -> int)
175174
//│ ║ ^^^^^^^^^^^^^
176175
//│ ╟── but it flows into reference with expected type `bool`
177-
//│ ║ l.166: not test
176+
//│ ║ l.165: not test
178177
//│ ╙── ^^^^
179178
//│ res: bool | error
180179

0 commit comments

Comments
 (0)