Skip to content

Commit afb08c5

Browse files
committed
wf check and typing function intersection
1 parent 9e13a34 commit afb08c5

File tree

5 files changed

+249
-144
lines changed

5 files changed

+249
-144
lines changed

hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala

Lines changed: 50 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -169,24 +169,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
169169
case Neg(rhs) =>
170170
mono(rhs, !pol).!
171171
case CompType(lhs, rhs, pol) =>
172-
val (l, r) = (typeMonoType(lhs), typeMonoType(rhs))
173-
if !pol then
174-
val lfa = l.toDnf.conjs.flatMap(_.i.v).collect:
175-
case (f :: fs) =>
176-
val fd = Type.discriminant(f.args).flatten
177-
fs.foldLeft(fd: Type, fd.fields.keys.toSet): (x, y) =>
178-
val d = Type.discriminant(y.args).flatten
179-
(x._1 | d, x._2 & d.fields.keys.toSet)
180-
val rfa = r.toDnf.conjs.flatMap(_.i.v).collect:
181-
case (f :: fs) =>
182-
val fd = Type.discriminant(f.args).flatten
183-
fs.foldLeft(fd: Type, fd.fields.keys.toSet): (x, y) =>
184-
val d = Type.discriminant(y.args).flatten
185-
(x._1 | d, x._2 & d.fields.keys.toSet)
186-
val d = lfa.iterator.flatMap(x => rfa.iterator.map(y => (x, y))).exists:
187-
case ((x, u), (y, w)) => (u & w).isEmpty || Type.disjoint(x, y) =/= S(Set.empty)
188-
if d then error(msg"Ill-formed functions intersection" -> ty.toLoc :: Nil)
189-
Type.mkComposedType(l, r, pol)
172+
Type.mkComposedType(typeMonoType(lhs), typeMonoType(rhs), pol)
190173
case _ =>
191174
ty.symbol.flatMap(_.asTpe) match
192175
case S(cls: (ClassSymbol | TypeAliasSymbol)) => typeAndSubstType(Term.TyApp(ty, Nil)(N), pol)
@@ -201,17 +184,39 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
201184
tv -> qv
202185
bds.foreach:
203186
case (tv, QuantVar(_, ub, lb)) =>
204-
ub.foreach(ub => tv.state.upperBounds ::= typeMonoType(ub))
205-
lb.foreach(lb => tv.state.lowerBounds ::= typeMonoType(lb))
187+
ub.foreach(ub => tv.state.upperBounds ::= monoOrErr(typeType(ub), ub))
188+
lb.foreach(lb => tv.state.lowerBounds ::= monoOrErr(typeType(lb), lb))
206189
val lbty = tv.state.lowerBounds.foldLeft[Type](Bot)(_ | _)
207190
val ubty = tv.state.upperBounds.foldLeft[Type](Top)(_ & _)
208191
solver.constrain(lbty, ubty)
209192
PolyType(bds.map(_._1), S(outer), body)
210193

211-
private def typeMonoType(ty: Term)(using ctx: BbCtx, cctx: CCtx): Type = monoOrErr(typeType(ty), ty)
212-
213-
private def typeType(ty: Term)(using ctx: BbCtx, cctx: CCtx): GeneralType =
214-
typeAndSubstType(ty, pol = true)(using Map.empty)
194+
private def typeMonoType(ty: Term)(using ctx: BbCtx, cctx: CCtx): Type = monoOrErr(typeAndSubstType(ty, true)(using Map.empty), ty)
195+
196+
private def wffuns(fs: Ls[FunType]) =
197+
val wf = fs.forall(f => (f.ret :: f.eff :: f.args).forall(wftype))
198+
wf && fs.combinations(2).forall: u =>
199+
Type.disjoint(Type.discriminant(u.head.args), Type.discriminant(u.tail.head.args)).exists(_.isEmpty)
200+
private def wfrcds(rs: Ls[RcdType]) =
201+
val wf = rs.forall(_.fields.forall(u => wftype(u._2)))
202+
wf && rs.combinations(2).forall: u =>
203+
Type.disjoint(u.head, u.tail.head).exists(_.isEmpty)
204+
private def wfcls(cs: Ls[ClassLikeType]) =
205+
cs.forall(_.targs.forall(u => wftype(u.posPart) && wftype(u.negPart)))
206+
private def wftype(ty: GeneralType): Bool = ty match
207+
case t: PolyType => wftype(t.body)
208+
case t: PolyFunType => (t.ret :: t.eff :: t.args).forall(wftype)
209+
case t: Type =>
210+
val n = t.!.toDnf.conjs
211+
val nf = n.iterator.map(_.i.v).forall:
212+
case S(f: Ls[FunType]) => false
213+
case _ => true
214+
nf && wffuns(n.flatMap(_.u.fun)) && n.forall(c => wfrcds(c.u.rcd) && wfcls(c.u.cls))
215+
216+
private def typeType(ty: Term, map: Map[Uid[Symbol], TypeArg] = Map.empty)(using ctx: BbCtx, cctx: CCtx): GeneralType =
217+
val t = typeAndSubstType(ty, pol = true)(using map)
218+
if !wftype(t) then error(msg"Ill-formed type" -> ty.toLoc :: Nil)
219+
t
215220

216221
private def instantiate(ty: PolyType)(using ctx: BbCtx): GeneralType =
217222
ty.instantiate(infVarState.nextUid, freshEnv(new TempSymbol(N, "env")), ctx.lvl)
@@ -486,13 +491,29 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
486491
ascribe(term, typeType(ty))
487492
ascribe(term, rhs)
488493
case _ =>
489-
val (lhsTy, eff) = typeCheck(lhs)
490494
rhs match
491495
case pf: PolyFunType if pf.isPoly =>
492496
(error(msg"Cannot type non-function term ${lhs.toString} as ${rhs.show}" -> lhs.toLoc :: Nil), Bot)
493497
case _ =>
494-
constrain(tryMkMono(lhsTy, lhs), monoOrErr(rhs, lhs))
495-
(rhs, eff)
498+
val r = monoOrErr(rhs, lhs)
499+
val n = r.!.toDnf.conjs
500+
if n.flatMap(_.u.fun).length > 1 then
501+
val eff = n.foldLeft(Bot: Type): (e, c) =>
502+
(c.i.v, c.u, c.vars) match
503+
case (i, u, v) =>
504+
val n = i.fold(Bot: Type):
505+
case x: (ClassLikeType | RcdType) => x.!
506+
case x: Ls[FunType] => x.reduce[Type](_ & _).!
507+
val k = Ls(u.fun, u.cls, u.cls).flatten.foldLeft(Bot: Type)(_ | _)
508+
val vs = v.foldLeft(Bot: Type): (x, y) =>
509+
x | (if y._2 then y._1.! else y._1)
510+
val (_, eff) = ascribe(lhs, n | k | vs)
511+
e | eff
512+
(rhs, eff)
513+
else
514+
val (lhsTy, eff) = typeCheck(lhs)
515+
constrain(tryMkMono(lhsTy, lhs), r)
516+
(rhs, eff)
496517

497518
// TODO: t -> loc when toLoc is implemented
498519
private def app(lhs: (GeneralType, Type), rhs: Ls[Elem], t: Term)
@@ -620,7 +641,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
620641
case Param(_, sym, sign, _) =>
621642
if sym.nme === field.name then sign else N
622643
}.filter(_.isDefined)) match
623-
case S(res) :: Nil => (typeAndSubstType(res, pol = true)(using map.toMap), eff)
644+
case S(res) :: Nil => (typeType(res, map.toMap), eff)
624645
case _ => (error(msg"${field.name} is not a valid member in class ${clsSym.nme}" -> t.toLoc :: Nil), Bot)
625646
case N =>
626647
(error(msg"Not a valid class: ${cls.describe}" -> cls.toLoc :: Nil), Bot)
@@ -650,7 +671,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL)(using Config):
650671
require(clsDfn.paramsOpt.forall(_.restParam.isEmpty))
651672
args.iterator.zip(clsDfn.params.params).foreach {
652673
case (arg, Param(sign = S(sign))) =>
653-
val (ty, eff) = ascribe(arg, typeAndSubstType(sign, pol = true)(using map.toMap))
674+
val (ty, eff) = ascribe(arg, typeType(sign, map.toMap))
654675
effBuff += eff
655676
case _ => ???
656677
}

hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,11 +273,6 @@ case class RcdType(fields: Ls[Str -> Type]) extends BasicType with CachedNorm[Rc
273273
RcdType(fields.mapValues(_.subst))
274274
def & (that: RcdType): RcdType =
275275
RcdType((fields ++ that.fields).groupMapReduce(_._1)(_._2)(_ & _).toList)
276-
def flatten: RcdType =
277-
RcdType(fields.flatMap: u =>
278-
(u._1, u._2.toBasic.simp.toBasic) match
279-
case (a, r: RcdType) => r.flatten.fields.map(u => (s"$a.${u._1}", u._2))
280-
case u => Ls(u))
281276

282277
case class ComposedType(lhs: Type, rhs: Type, pol: Bool) extends BasicType: // * Positive -> union
283278
override def subst(using map: Map[Uid[InfVar], InfVar]): ThisType =

hkmc2/shared/src/test/mlscript/logicsub/DisjSub.mls

Lines changed: 1 addition & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ new Pair(1, 2)
105105
// WF check on intersection types
106106
:e
107107
fun idIIBB: (Pair[Int, Int] -> Int) & (Pair[Bool, Bool] -> Bool)
108-
//│ ╔══[ERROR] Ill-formed functions intersection
108+
//│ ╔══[ERROR] Ill-formed type
109109
//│ ║ l.107: fun idIIBB: (Pair[Int, Int] -> Int) & (Pair[Bool, Bool] -> Bool)
110110
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
111111
//│ Type: ⊤
@@ -271,94 +271,3 @@ x => k(x, x)
271271
//│ Type: ('x) ->{'eff} 'app
272272
//│ Where:
273273
//│ 'x#'x ∨ ( 'x#Int ∨ Int <: 'app ∧ ⊥ <: 'eff) ∧ ( 'x#Bool ∨ Bool <: 'app ∧ ⊥ <: 'eff){0: 'x, 1: 'x} <: {0: Int, 1: Int} | {0: Bool, 1: Bool}
274-
275-
:e
276-
fun ill: ((Int | Bool, Int | Bool) -> Str) & ((Bool | Str, Bool | Str) -> Str)
277-
//│ ╔══[ERROR] Ill-formed functions intersection
278-
//│ ║ l.276: fun ill: ((Int | Bool, Int | Bool) -> Str) & ((Bool | Str, Bool | Str) -> Str)
279-
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
280-
//│ Type: ⊤
281-
282-
:e
283-
fun id: [A] -> (A -> A) & (Int -> Int)
284-
//│ ╔══[ERROR] Ill-formed functions intersection
285-
//│ ║ l.283: fun id: [A] -> (A -> A) & (Int -> Int)
286-
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^
287-
//│ Type: ⊤
288-
289-
fun fst: [A] -> (Pair[A, Int] -> Pair[A, Int]) & (Ls[A] -> A)
290-
//│ Type: ⊤
291-
292-
fst(new Pair(true, 1))
293-
//│ Type: Pair[out Bool, out Int]
294-
295-
fst(cons(1, nil()))
296-
//│ Type: Int
297-
298-
:e
299-
fun idInt: (Int -> Int) & (Int -> Int)
300-
//│ ╔══[ERROR] Ill-formed functions intersection
301-
//│ ║ l.299: fun idInt: (Int -> Int) & (Int -> Int)
302-
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^
303-
//│ Type: ⊤
304-
305-
fun wf: ((Int | Bool -> Int) | (Bool -> Bool)) & (Str -> Str)
306-
//│ Type: ⊤
307-
308-
wf("")
309-
//│ Type: Str
310-
311-
wf(true)
312-
//│ Type: Bool | Int
313-
314-
:e
315-
wf(1)
316-
//│ ╔══[ERROR] Type error in application
317-
//│ ║ l.315: wf(1)
318-
//│ ║ ^^^^^
319-
//│ ╟── because: cannot constrain (((Int | Bool) -> Int) | (Bool -> Bool)) & (Str -> Str) <: Int ->{'eff} 'app
320-
//│ ╟── because: cannot constrain {0: Int} <: {0: Bool} | {0: Str}
321-
//│ ╙── because: cannot constrain ⊤ <: ⊥
322-
//│ Type: Int
323-
324-
class
325-
Z()
326-
S()
327-
//│ Type: ⊤
328-
329-
fun wf: ((Z, Z) -> Z) & ((Z, S) -> S) & ((S, Z | S) -> S)
330-
//│ Type: ⊤
331-
332-
fun zs: Z | S
333-
//│ Type: ⊤
334-
335-
wf(zs, zs)
336-
//│ Type: S | Z
337-
338-
:e
339-
fun ill:
340-
(([a: Int, b:Int]) -> Int) &
341-
(([b: Bool, c: Bool]) -> Bool) &
342-
(([a: Str, c: Str]) -> Str)
343-
//│ ╔══[ERROR] Ill-formed functions intersection
344-
//│ ║ l.340: (([a: Int, b:Int]) -> Int) &
345-
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^
346-
//│ ║ l.341: (([b: Bool, c: Bool]) -> Bool) &
347-
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
348-
//│ ║ l.342: (([a: Str, c: Str]) -> Str)
349-
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
350-
//│ Type: ⊤
351-
352-
:e
353-
fun nested:
354-
(([tag: [a: Int, b:Int]]) -> Int) &
355-
(([tag: [b: Bool, c: Bool]]) -> Bool) &
356-
(([tag: [a: Str, c: Str]]) -> Str)
357-
//│ ╔══[ERROR] Ill-formed functions intersection
358-
//│ ║ l.354: (([tag: [a: Int, b:Int]]) -> Int) &
359-
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
360-
//│ ║ l.355: (([tag: [b: Bool, c: Bool]]) -> Bool) &
361-
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
362-
//│ ║ l.356: (([tag: [a: Str, c: Str]]) -> Str)
363-
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
364-
//│ Type: ⊤

hkmc2/shared/src/test/mlscript/logicsub/elixir.mls

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,11 @@ fun subtract(a, b) = if
5555
//│ ╙── because: cannot constrain (Int | Tru) | Fls <: Int
5656
//│ Type: ⊤
5757

58-
// TODO impl
5958
fun negate: (Int -> Int) & (Tru | Fls -> Tru | Fls)
59+
fun negate(x) = if x is
60+
Int then 0 - x
61+
Tru then not(x)
62+
Fls then not(x)
6063
//│ Type: ⊤
6164

6265
fun subtract: (Int, Int) -> Int
@@ -91,15 +94,15 @@ fun _in(x) = if x is
9194
fun _out(x) = x.Ls#prim(() => new Nil, (x, y) => new Cons(x, y))
9295
//│ Type: ⊤
9396

94-
// fun map: [A, B] -> (Ls[A], A -> B) -> Ls[B]
97+
fun map: [A, B] -> (Ls[A], A -> B) -> Ls[B]
9598
fun map(xs, f) =
9699
let x = _out(xs)
97100
if x is
98101
Cons then cons(f(car(x)), map(cdr(x), f))
99102
Nil then nil()
100103
//│ Type: ⊤
101104

102-
// fun reduce: [A, B] -> (Ls[A], B, (A, B) -> B) -> B
105+
fun reduce: [A, B] -> (Ls[A], B, (A, B) -> B) -> B
103106
fun reduce(xs, acc, f) =
104107
let x = _out(xs)
105108
if x is
@@ -111,11 +114,9 @@ let xs = cons(1, cons(4, nil()))
111114
//│ Type: ⊤
112115

113116
xs map(negate)
114-
//│ Type: Ls['A] | Ls['A1]
117+
//│ Type: Ls['B]
115118
//│ Where:
116-
//│ Int <: 'A1
117-
//│ 'A1 <: 'A
118-
//│ 'A <: 'A1
119+
//│ Int <: 'B
119120

120121
fun map(xs, f) = if xs is
121122
Cons then cons(f(car(xs)), map(_out(cdr(xs)), f))
@@ -141,21 +142,11 @@ xs _out() map(negate)
141142
//│ 'A1 <: 'A
142143
//│ 'A <: 'A1
143144

144-
145-
146-
// TODO
147-
// fun negate: (Int -> Int) & (Bool -> Bool) & (~(Int | Bool) -> ~(Int | Bool))
145+
fun negate: (Int -> Int) & (Bool -> Bool) & (~(Int | Bool) -> ~(Int | Bool))
148146
fun negate(x) = if x is
149147
Int then 0 - x
150148
Tru then not(x)
151149
Fls then not(x)
152150
else
153151
x
154152
//│ Type: ⊤
155-
156-
negate as Int -> Int
157-
negate as Bool -> Bool
158-
negate as ~(Int | Bool) -> ~(Int | Bool)
159-
//│ Type: (¬Int & ¬Bool) ->{⊥} ¬Int & ¬Bool
160-
161-
// type tree(a) = (a and not list()) or [tree(a)]

0 commit comments

Comments
 (0)