diff --git a/shared/src/main/scala/mlscript/ConstraintSolver.scala b/shared/src/main/scala/mlscript/ConstraintSolver.scala index e509a48579..37b3a5441f 100644 --- a/shared/src/main/scala/mlscript/ConstraintSolver.scala +++ b/shared/src/main/scala/mlscript/ConstraintSolver.scala @@ -620,7 +620,9 @@ class ConstraintSolver extends NormalForms { self: Typer => rec(b.inner.ub, ar.inner.ub, false) case (LhsRefined(S(b: ArrayBase), ts, r, _), _) => reportError() case (LhsRefined(S(ov: Overload), ts, r, trs), _) => - annoying(Nil, LhsRefined(S(ov.approximatePos), ts, r, trs), Nil, done_rs) // TODO remove approx. with ambiguous constraints + val t = TupleSetConstraints.mk(ov) + annoying(Nil, LhsRefined(S(t), ts, r, trs), Nil, done_rs) + // annoying(Nil, LhsRefined(S(ov.approximatePos), ts, r, trs), Nil, done_rs) // TODO remove approx. with ambiguous constraints case (LhsRefined(S(Without(b, ns)), ts, r, _), RhsBases(pts, N | S(L(_)), _)) => rec(b, done_rs.toType(), true) case (_, RhsBases(pts, S(L(Without(base, ns))), _)) => @@ -818,6 +820,11 @@ class ConstraintSolver extends NormalForms { self: Typer => val newBound = (cctx._1 ::: cctx._2.reverse).foldRight(rhs)((c, ty) => if (c.prov is noProv) ty else mkProxy(ty, c.prov)) lhs.upperBounds ::= newBound // update the bound + lhs.lbtsc.foreach { + case (tsc, i) => + tsc.filterUB(i, rhs) + if (tsc.constraints.isEmpty) reportError() + } lhs.lowerBounds.foreach(rec(_, rhs, true)) // propagate from the bound case (lhs, rhs: TypeVariable) if lhs.level <= rhs.level => @@ -825,6 +832,11 @@ class ConstraintSolver extends NormalForms { self: Typer => val newBound = (cctx._1 ::: cctx._2.reverse).foldLeft(lhs)((ty, c) => if (c.prov is noProv) ty else mkProxy(ty, c.prov)) rhs.lowerBounds ::= newBound // update the bound + rhs.ubtsc.foreach { + case (tsc, i) => + tsc.filterLB(i, lhs) + if (tsc.constraints.isEmpty) reportError() + } rhs.upperBounds.foreach(rec(lhs, _, true)) // propagate from the bound diff --git a/shared/src/main/scala/mlscript/TypeSimplifier.scala b/shared/src/main/scala/mlscript/TypeSimplifier.scala index f2a4b297bf..0e688fe8e3 100644 --- a/shared/src/main/scala/mlscript/TypeSimplifier.scala +++ b/shared/src/main/scala/mlscript/TypeSimplifier.scala @@ -76,7 +76,6 @@ trait TypeSimplifier { self: Typer => .reduceOption(_ &- _).filterNot(_.isTop).toList else Nil } - nv case ComposedType(true, l, r) => diff --git a/shared/src/main/scala/mlscript/TyperDatatypes.scala b/shared/src/main/scala/mlscript/TyperDatatypes.scala index 5fc80392a6..e94cd179c8 100644 --- a/shared/src/main/scala/mlscript/TyperDatatypes.scala +++ b/shared/src/main/scala/mlscript/TyperDatatypes.scala @@ -534,6 +534,9 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer => require(value.forall(_.level <= level)) _assignedTo = value } + + var lbtsc: Opt[(TupleSetConstraints, Int)] = N + var ubtsc: Opt[(TupleSetConstraints, Int)] = N // * Bounds should always be disregarded when `equatedTo` is defined, as they are then irrelevant: def lowerBounds: List[SimpleType] = { require(assignedTo.isEmpty, this); _lowerBounds } @@ -646,5 +649,112 @@ abstract class TyperDatatypes extends TyperHelpers { Typer: Typer => lazy val underlying: SimpleType = tt.neg() val prov = noProv } - + + class TupleSetConstraints(val constraints: MutSet[Ls[ST]], var tvs: Ls[TV])(val prov: TypeProvenance) { + def filterUB(index: Int, ub: ST)(implicit raise: Raise, ctx: Ctx): Unit = { + def go(ub: ST): Unit = ub match { + case ub: TV => + ub.upperBounds.foreach(go) + ub.lbtsc = S(this, index) + case _ => + constraints.filterInPlace { constrs => + val ty = constrs(index) + val dnf = DNF.mk(MaxLevel, Nil, ty & ub.neg(), true) + dnf.isBot || dnf.cs.forall(c => !(c.vars.isEmpty && c.nvars.isEmpty)) + } + } + go(ub) + println(s"TSC filterUB: $tvs in $constraints") + if (constraints.sizeCompare(1) === 0) { + constraints.head.zip(tvs).foreach { + case (ty, tv) => + tv.lbtsc = N + tv.ubtsc = N + constrain(tv, ty)(raise, prov, ctx) + constrain(ty, tv)(raise, prov, ctx) + } + } + } + def filterLB(index: Int, lb: ST)(implicit raise: Raise, ctx: Ctx): Unit = { + constraints.filterInPlace { constrs => + val ty = constrs(index) + val dnf = DNF.mk(MaxLevel, Nil, lb & ty.neg(), true) + dnf.isBot || dnf.cs.forall(c => !(c.vars.isEmpty && c.nvars.isEmpty)) + } + println(s"TSC filterLB: $tvs in $constraints") + if (constraints.sizeCompare(1) === 0) { + constraints.head.zip(tvs).foreach { + case (ty, tv) => + tv.lbtsc = N + tv.ubtsc = N + constrain(tv, ty)(raise, prov, ctx) + constrain(ty, tv)(raise, prov, ctx) + } + } + } + } + object TupleSetConstraints { + def lcgField(first: FieldType, rest: Ls[FieldType]) + (implicit prov: TypeProvenance, lvl: Level) + : (FieldType, Ls[TV], Ls[Ls[ST]]) = { + val (ub, tvs, constrs) = lcg(first.ub, rest.map(_.ub)) + if (first.lb.isEmpty && rest.forall(_.lb.isEmpty)) { + (FieldType(N, ub)(prov), tvs, constrs) + } else { + val (lb, ltvs, lconstrs) = lcg(first.lb.getOrElse(BotType), rest.map(_.lb.getOrElse(BotType))) + (FieldType(S(lb), ub)(prov), tvs ++ ltvs, constrs ++ lconstrs) + } + } + def lcg(first: ST, rest: Ls[ST]) + (implicit prov: TypeProvenance, lvl: Level) + : (ST, Ls[TV], Ls[Ls[ST]]) = first match { + case a: FunctionType if rest.forall(_.isInstanceOf[FunctionType]) => + val (lhss, rhss) = rest.collect { + case FunctionType(lhs, rhs) => lhs -> rhs + }.unzip + val (lhs, ltvs, lconstrs) = lcg(a.lhs, lhss) + val (rhs, rtvs, rconstrs) = lcg(a.rhs, rhss) + (FunctionType(lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs) + case a: ArrayType if rest.forall(_.isInstanceOf[ArrayType]) => + val inners = rest.collect { case b: ArrayType => b.inner } + val (t, tvs, constrs) = lcgField(a.inner, inners) + (ArrayType(t)(prov), tvs, constrs) + case a: TupleType if rest.forall { case b: TupleType => a.fields.sizeCompare(b.fields.size) === 0; case _ => false } => + val fields = rest.collect { case TupleType(fields) => fields.map(_._2) } + val (fts, tvss, constrss) = a.fields.map(_._2).zip(fields.transpose).map { case (a, bs) => lcgField(a, bs) }.unzip3 + (TupleType(fts.map(N -> _))(prov), tvss.flatten, constrss.flatten) + case a: TR if rest.forall { case b: TR => a.defn === b.defn && a.targs.sizeCompare(b.targs.size) === 0; case _ => false } => + val targs = rest.collect { case b: TR => b.targs } + val (ts, tvss, constrss) = a.targs.zip(targs.transpose).map { case (a, bs) => lcg(a, bs) }.unzip3 + (TypeRef(a.defn, ts)(prov), tvss.flatten, constrss.flatten) + case a: TV if rest.forall { case b: TV => a.compare(b) === 0; case _ => false } => (a, Nil, Nil) + case a if rest.forall(_ === a) => (a, Nil, Nil) + case _ => + val tv = freshVar(prov, N) + (tv, List(tv), List(first :: rest)) + } + def lcgFunction(first: FunctionType, rest: Ls[FunctionType])(implicit prov: TypeProvenance, lvl: Level) + : (FunctionType, Ls[TV], Ls[Ls[ST]]) = { + val (lhss, rhss) = rest.map { + case FunctionType(lhs, rhs) => lhs -> rhs + }.unzip + val (lhs, ltvs, lconstrs) = lcg(first.lhs, lhss) + val (rhs, rtvs, rconstrs) = lcg(first.rhs, rhss) + (FunctionType(lhs, rhs)(prov), ltvs ++ rtvs, lconstrs ++ rconstrs) + } + def mk(ov: Overload)(implicit lvl: Level): FunctionType = { + def unwrap(t: ST): ST = t.map(unwrap) + if (ov.alts.tail.isEmpty) ov.alts.head else { + val f = ov.mapAlts(unwrap)(unwrap) + val (t, tvs, constrs) = lcgFunction(f.alts.head, f.alts.tail)(ov.prov, lvl) + val tsc = new TupleSetConstraints(MutSet.empty ++ constrs.transpose, tvs)(ov.prov) + tvs.zipWithIndex.foreach { case (tv, i) => + tv.lbtsc = S((tsc, i)) + tv.ubtsc = S((tsc, i)) + } + println(s"TSC mk: ${tsc.tvs} in ${tsc.constraints}") + t + } + } + } } diff --git a/shared/src/main/scala/mlscript/TyperHelpers.scala b/shared/src/main/scala/mlscript/TyperHelpers.scala index 727b0001c7..2e20b650c8 100644 --- a/shared/src/main/scala/mlscript/TyperHelpers.scala +++ b/shared/src/main/scala/mlscript/TyperHelpers.scala @@ -983,13 +983,24 @@ abstract class TyperHelpers { Typer: Typer => def getVars: SortedSet[TypeVariable] = getVarsImpl(includeBounds = true) def showBounds: String = - getVars.iterator.filter(tv => tv.assignedTo.nonEmpty || (tv.upperBounds ++ tv.lowerBounds).nonEmpty).map { + getVars.iterator.filter(tv => tv.assignedTo.nonEmpty || (tv.upperBounds ++ tv.lowerBounds).nonEmpty || (tv.lbtsc.isDefined && tv.ubtsc.isEmpty)).map { case tv @ AssignedVariable(ty) => "\n\t\t" + tv.toString + " := " + ty case tv => ("\n\t\t" + tv.toString + (if (tv.lowerBounds.isEmpty) "" else " :> " + tv.lowerBounds.mkString(" | ")) - + (if (tv.upperBounds.isEmpty) "" else " <: " + tv.upperBounds.mkString(" & "))) - }.mkString - + + (if (tv.upperBounds.isEmpty) "" else " <: " + tv.upperBounds.mkString(" & ")) + + tv.lbtsc.fold(""){ case (tsc, i) => " :> " + tsc.tvs(i) } ) + }.mkString + { + val visited: MutSet[TV] = MutSet.empty + getVars.iterator.filter(tv => tv.ubtsc.isDefined).map { + case tv if visited.contains(tv) => "" + case tv => + visited ++= tv.lbtsc.fold(Nil: Ls[TV])(_._1.tvs) + tv.lbtsc.fold("") { case (tsc, _) => ("\n\t\t[ " + + tsc.tvs.mkString(", ") + + " ] in { " + tsc.constraints.mkString(", ") + " }") + } + }.mkString + } } diff --git a/shared/src/test/diff/nu/ArrayProg.mls b/shared/src/test/diff/nu/ArrayProg.mls index 857ee4e3ac..c5243ca638 100644 --- a/shared/src/test/diff/nu/ArrayProg.mls +++ b/shared/src/test/diff/nu/ArrayProg.mls @@ -155,7 +155,7 @@ module A { //│ } A.g(0) -//│ Int | Str +//│ Int //│ res //│ = 0 diff --git a/shared/src/test/diff/nu/HeungTung.mls b/shared/src/test/diff/nu/HeungTung.mls index 51f4efad8b..c810a3122b 100644 --- a/shared/src/test/diff/nu/HeungTung.mls +++ b/shared/src/test/diff/nu/HeungTung.mls @@ -66,8 +66,16 @@ fun g = h //│ fun g: (Bool | Int) -> (Int | false | true) // * In one step +:e // TODO: argument of union type fun g: (Int | Bool) -> (Int | Bool) fun g = f +//│ ╔══[ERROR] Type mismatch in definition: +//│ ║ l.71: fun g = f +//│ ║ ^^^^^ +//│ ╟── expression of type `Int | false | true` does not match type `?a` +//│ ╟── Note: constraint arises from function type: +//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool) +//│ ╙── ^^^^^^^^^^^^^^ //│ fun g: Int -> Int & Bool -> Bool //│ fun g: (Bool | Int) -> (Int | false | true) @@ -88,9 +96,11 @@ fun j = i fun j: (Int & Bool) -> (Int & Bool) fun j = f //│ ╔══[ERROR] Type mismatch in definition: -//│ ║ l.89: fun j = f +//│ ║ l.97: fun j = f //│ ║ ^^^^^ -//│ ╙── expression of type `Int` does not match type `nothing` +//│ ╟── type `?a` does not match type `nothing` +//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool) +//│ ╙── ^^^^^^^^^^^^^^ //│ fun j: Int -> Int & Bool -> Bool //│ fun j: nothing -> nothing @@ -106,7 +116,7 @@ fun g = f // * With match-type-based constraint solving, we could return Int here f(0) -//│ Int | false | true +//│ Int //│ res //│ = 0 @@ -114,15 +124,22 @@ f(0) x => f(x) -//│ (Bool | Int) -> (Int | false | true) +//│ anything -> nothing //│ res //│ = [Function: res] // : forall 'a: 'a -> case 'a of { Int => Int; Bool => Bool } where 'a <: Int | Bool - +:e f(if true then 0 else false) -//│ Int | false | true +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.134: f(if true then 0 else false) +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── expression of type `0 | false` does not match type `?a` +//│ ╟── Note: constraint arises from function type: +//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool) +//│ ╙── ^^^^^^^^^^^^^^ +//│ error //│ res //│ = 0 @@ -132,12 +149,21 @@ f(if true then 0 else false) :w f(refined if true then 0 else false) // this one can be precise again! //│ ╔══[WARNING] Paren-less applications should use the 'of' keyword -//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again! +//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again! //│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ //│ ╔══[ERROR] identifier not found: refined -//│ ║ l.133: f(refined if true then 0 else false) // this one can be precise again! +//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again! //│ ╙── ^^^^^^^ -//│ Int | false | true +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again! +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── application of type `error` does not match type `?a` +//│ ║ l.150: f(refined if true then 0 else false) // this one can be precise again! +//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +//│ ╟── Note: constraint arises from function type: +//│ ║ l.50: fun f: (Int -> Int) & (Bool -> Bool) +//│ ╙── ^^^^^^^^^^^^^^ +//│ error //│ Code generation encountered an error: //│ unresolved symbol refined @@ -193,7 +219,7 @@ type T = List[Int] :e // TODO application types type Res = M(T) //│ ╔══[ERROR] Wrong number of type arguments – expected 0, found 1 -//│ ║ l.194: type Res = M(T) +//│ ║ l.220: type Res = M(T) //│ ╙── ^^^^ //│ type Res = M @@ -216,7 +242,7 @@ fun f: Int -> Int fun f: Bool -> Bool fun f = id //│ ╔══[ERROR] A type signature for 'f' was already given -//│ ║ l.216: fun f: Bool -> Bool +//│ ║ l.242: fun f: Bool -> Bool //│ ╙── ^^^^^^^^^^^^^^^^^^^ //│ fun f: forall 'a. 'a -> 'a //│ fun f: Int -> Int @@ -224,13 +250,13 @@ fun f = id :e // TODO support f: (Int -> Int) & (Bool -> Bool) //│ ╔══[ERROR] Type mismatch in type ascription: -//│ ║ l.225: f: (Int -> Int) & (Bool -> Bool) +//│ ║ l.251: f: (Int -> Int) & (Bool -> Bool) //│ ║ ^ //│ ╟── type `Bool` is not an instance of `Int` -//│ ║ l.225: f: (Int -> Int) & (Bool -> Bool) +//│ ║ l.251: f: (Int -> Int) & (Bool -> Bool) //│ ║ ^^^^ //│ ╟── Note: constraint arises from type reference: -//│ ║ l.215: fun f: Int -> Int +//│ ║ l.241: fun f: Int -> Int //│ ╙── ^^^ //│ Int -> Int & Bool -> Bool //│ res @@ -297,14 +323,14 @@ fun test(x) = refined if x is A then 0 B then 1 //│ ╔══[WARNING] Paren-less applications should use the 'of' keyword -//│ ║ l.296: fun test(x) = refined if x is +//│ ║ l.322: fun test(x) = refined if x is //│ ║ ^^^^^^^^^^^^^^^ -//│ ║ l.297: A then 0 +//│ ║ l.323: A then 0 //│ ║ ^^^^^^^^^^ -//│ ║ l.298: B then 1 +//│ ║ l.324: B then 1 //│ ╙── ^^^^^^^^^^ //│ ╔══[ERROR] identifier not found: refined -//│ ║ l.296: fun test(x) = refined if x is +//│ ║ l.322: fun test(x) = refined if x is //│ ╙── ^^^^^^^ //│ fun test: (A | B) -> error //│ Code generation encountered an error: diff --git a/shared/src/test/diff/nu/WeirdUnions.mls b/shared/src/test/diff/nu/WeirdUnions.mls index 03ab69c023..e92c6ca504 100644 --- a/shared/src/test/diff/nu/WeirdUnions.mls +++ b/shared/src/test/diff/nu/WeirdUnions.mls @@ -47,14 +47,24 @@ fun f: (Str => Str) & ((Str, Int) => Int) //│ fun f: Str -> Str & (Str, Int) -> Int // * ...resulting in approximation at call sites (we don't handle overloading) +:e // TODO f("abc", "abc") -//│ Int | Str +//│ ╔══[ERROR] Type mismatch in application: +//│ ║ l.51: f("abc", "abc") +//│ ║ ^^^^^^^^^^^^^^^ +//│ ╟── argument list of type `["abc", "abc"]` does not match type `?a` +//│ ║ l.51: f("abc", "abc") +//│ ║ ^^^^^^^^^^^^^^ +//│ ╟── Note: constraint arises from function type: +//│ ║ l.46: fun f: (Str => Str) & ((Str, Int) => Int) +//│ ╙── ^^^^^^^^^^^^^^^^^^^ +//│ error //│ res //│ = //│ f is not implemented f("abcabc") -//│ Int | Str +//│ Str //│ res //│ = //│ f is not implemented @@ -71,19 +81,19 @@ let r = if true then id else (x, y) => [y, x] r(error) r(error, error) //│ ╔══[ERROR] Type mismatch in application: -//│ ║ l.71: r(error) +//│ ║ l.81: r(error) //│ ║ ^^^^^^^^ //│ ╟── argument of type `[nothing]` does not match type `[?a, ?b]` -//│ ║ l.71: r(error) +//│ ║ l.81: r(error) //│ ║ ^^^^^^^ //│ ╟── Note: constraint arises from tuple literal: -//│ ║ l.65: let r = if true then id else (x, y) => [y, x] +//│ ║ l.75: let r = if true then id else (x, y) => [y, x] //│ ╙── ^^^^ //│ ╔══[ERROR] Type mismatch in application: -//│ ║ l.72: r(error, error) +//│ ║ l.82: r(error, error) //│ ║ ^^^^^^^^^^^^^^^ //│ ╟── argument list of type `[nothing, nothing]` does not match type `[?a]` -//│ ║ l.72: r(error, error) +//│ ║ l.82: r(error, error) //│ ╙── ^^^^^^^^^^^^^^ //│ error | [nothing, nothing] //│ res @@ -106,7 +116,7 @@ r of [0, 1] // Also currently parses the same: let r = if true then id else [x, y] => [y, x] -//│ let r: forall 'a 'b 'c. (['b, 'c] & 'a) -> (['c, 'b] | 'a) +//│ let r: forall 'a 'b 'c. (['a, 'b] & 'c) -> (['b, 'a] | 'c) //│ r //│ = [Function: id]