Skip to content

Commit c7ad33c

Browse files
Fix constraint cache and unexpected generalizations in BbML (hkust-taco#305)
* Fix cache * Improve * WIP: Cache properly and improve nf show * Fix unexpected generalizations * Add more test cases * Fix missing level check * Rerun test * Add comments
1 parent 12479ae commit c7ad33c

20 files changed

+358
-229
lines changed

hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,17 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
8787
else
8888
val bd = if v.lvl >= rest.lvl then rest else extrude(rest)(using v.lvl, true, mutable.HashMap.empty)
8989
if pol then
90-
val nc = Type.mkNegType(bd)
90+
val nc = Type.mkNegType(bd).toDnf // always cache the normal form to avoid unexpected cache misses
9191
log(s"New bound: ${v.showDbg} <: ${nc.showDbg}")
9292
cctx.nest(v -> nc) givenIn:
9393
v.state.upperBounds ::= nc
9494
v.state.lowerBounds.foreach(lb => constrainImpl(lb, nc))
9595
else
96-
log(s"New bound: ${v.showDbg} :> ${bd.showDbg}")
97-
cctx.nest(bd -> v) givenIn:
98-
v.state.lowerBounds ::= bd
99-
v.state.upperBounds.foreach(ub => constrainImpl(bd, ub))
96+
val c = bd.toDnf // always cache the normal form to avoid unexpected cache misses
97+
log(s"New bound: ${v.showDbg} :> ${c.showDbg}")
98+
cctx.nest(c -> v) givenIn:
99+
v.state.lowerBounds ::= c
100+
v.state.upperBounds.foreach(ub => constrainImpl(c, ub))
100101
case Conj(i, u, Nil) => (conj.i, conj.u) match
101102
case (_, Union(N, Nil)) =>
102103
// raise(ErrorReport(msg"Cannot solve ${conj.i.toString()} ∧ ¬⊥" -> N :: Nil))
@@ -137,9 +138,10 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
137138
case _: ClassLikeType | _: FunType | _: InfVar | Top | Bot => ty
138139

139140
private def constrainImpl(lhs: Type, rhs: Type)(using BbCtx, CCtx, TL): Unit =
140-
if cctx.cache((lhs, rhs)) then log(s"Cached!")
141+
val p = lhs.toDnf -> rhs.toDnf
142+
if cctx.cache(p) then log(s"Cached!")
141143
else trace(s"CONSTRAINT ${lhs.showDbg} <: ${rhs.showDbg}"):
142-
cctx.nest(lhs -> rhs) givenIn:
144+
cctx.nest(p) givenIn:
143145
val ty = dnf(inlineSkolemBounds(lhs & rhs.!, true)(using Set.empty))
144146
constrainDNF(ty)
145147
def constrain(lhs: Type, rhs: Type)(using BbCtx, CCtx, TL): Unit =

hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,22 @@ extends NormalForm with CachedBasicType:
5454
})
5555
def toDnf(using TL): Disj = Disj(this :: Nil)
5656
override def show(using Scope): Str =
57-
((i :: Nil).filterNot(_.isTop).map(_.show) :::
57+
val s = ((i :: Nil).filterNot(_.isTop).map(_.show) :::
5858
(u :: Nil).filterNot(_.isBot).map("¬{"+_.show+"}") :::
5959
vars.map:
6060
case (tv, true) => tv.show
6161
case (tv, false) => "¬" + tv.show
6262
).mkString("")
63+
if s.isEmpty then "" else s
6364

6465
override def showDbg: Str =
65-
((i :: Nil).filterNot(_.isTop).map(_.showDbg) :::
66+
val s = ((i :: Nil).filterNot(_.isTop).map(_.showDbg) :::
6667
(u :: Nil).filterNot(_.isBot).map("~{"+_.showDbg+"}") :::
6768
vars.map:
6869
case (tv, true) => tv.showDbg
6970
case (tv, false) => "~" + tv.showDbg
7071
).mkString(" && ")
72+
if s.isEmpty then "" else s
7173
object Conj:
7274
// * Conj objects cannot be created with `new` except in this file.
7375
// * This is because we want to sort the vars in the apply function.

hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -267,23 +267,23 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
267267
case _ =>
268268
(error(msg"Cannot quote ${code.toString}" -> code.toLoc :: Nil), Bot, Bot)
269269

270-
private def typeFunDef(sym: Symbol, lam: Term, sig: Opt[Term], pctx: BbCtx)(using ctx: BbCtx, cctx: CCtx, scope: Scope) = lam match
270+
private def typeFunDef(sym: Symbol, lam: Term, sig: Opt[Term])(using ctx: BbCtx, cctx: CCtx, scope: Scope) = lam match
271271
case Term.Lam(params, body) => sig match
272272
case S(sig) =>
273273
val sigTy = typeType(sig)(using ctx)
274-
pctx += sym -> sigTy
274+
ctx += sym -> sigTy
275275
ascribe(lam, sigTy)
276276
()
277277
case N =>
278278
val outer = freshOuter(new TempSymbol(S(lam), "outer"))(using ctx)
279279
given BbCtx = ctx.nestWithOuter(outer)
280280
val funTyV = freshVar(sym)
281-
pctx += sym -> funTyV // for recursive functions
281+
ctx += sym -> funTyV // for recursive functions
282282
val (res, _) = typeCheck(lam)
283283
val funTy = tryMkMono(res, lam)
284284
given CCtx = CCtx.init(lam, N)
285285
constrain(funTy, funTyV)(using ctx)
286-
pctx += sym -> PolyType.generalize(funTy, S(outer), 1)
286+
ctx += sym -> PolyType.generalize(funTy, S(outer), ctx.lvl + 1)
287287
case _ => error(msg"Function definition shape not yet supported for ${sym.nme}" -> lam.toLoc :: Nil)
288288

289289
private def typeSplit
@@ -443,10 +443,10 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
443443
ctx += sym -> rhsTy
444444
goStats(stats)
445445
case (td @ TermDefinition(k = Fun, params = ps :: Nil, sign = sig, body = S(body))) :: stats =>
446-
typeFunDef(td.sym, Term.Lam(ps, body), sig, ctx)
446+
typeFunDef(td.sym, Term.Lam(ps, body), sig)
447447
goStats(stats)
448448
case (td @ TermDefinition(k = Fun, params = Nil, sign = sig, body = S(body))) :: stats =>
449-
typeFunDef(td.sym, body, sig, ctx) // * may be a case expressions
449+
typeFunDef(td.sym, body, sig) // * may be a case expressions
450450
goStats(stats)
451451
case (td1 @ TermDefinition(k = Fun, sign = S(sig), body = None)) :: (td2 @ TermDefinition(k = Fun, body = S(body))) :: stats
452452
if td1.sym === td2.sym => goStats(td2 :: stats) // * avoid type check signatures twice
@@ -568,7 +568,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
568568
constrain(tryMkMono(refTy, ref), BbCtx.refTy(ctnt, sk))
569569
(ctnt, sk | refEff)
570570
case Term.Quoted(body) =>
571-
val nestCtx = ctx.nextLevel
571+
val nestCtx = ctx.nest
572572
given BbCtx = nestCtx
573573
val (ty, ctxTy, eff) = typeCode(body)
574574
(BbCtx.codeTy(ty, ctxTy), eff)

hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,6 @@ trait CachedBasicType extends Type:
8080
abstract class TypeExt extends Type:
8181
override def hashCode: Int =
8282
toBasic.hashCode
83-
override def equals(that: Any): Bool =
84-
toBasic === that
8583

8684
sealed abstract class Type extends GeneralType with TypeArg:
8785

@@ -356,7 +354,7 @@ object PolyType:
356354
visited.toSet
357355

358356
def generalize(ty: GeneralType, outer: Opt[InfVar], lvl: Int): PolyType =
359-
PolyType(collectTVs(ty).filter(v => outer.map(_.uid != v.uid).getOrElse(true)).toList.sorted, outer, ty)
357+
PolyType(collectTVs(ty).filter(v => v.lvl == lvl && outer.map(_.uid != v.uid).getOrElse(true)).toList.sorted, outer, ty)
360358

361359
// * Functions that accept/return a polymorphic type.
362360
// * Note that effects are always monomorphic

hkmc2/shared/src/test/mlscript/bbml/bbBasics.mls

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,9 @@ let ip = new Printer(foofoo) in ip.Printer#f("42")
164164
//│ ║ ^^^^
165165
//│ ╟── because: cannot constrain Str <: 'T
166166
//│ ╟── because: cannot constrain Str <: 'T
167-
//│ ╟── because: cannot constrain Str <: ¬(¬'T)
168167
//│ ╟── because: cannot constrain Str <: 'T
169-
//│ ╙── because: cannot constrain Str <: ¬(¬{Int})
168+
//│ ╟── because: cannot constrain Str <: 'T
169+
//│ ╙── because: cannot constrain Str <: Int
170170
//│ Type: Str
171171

172172
data class TFun[T](f: T -> T)
@@ -191,9 +191,9 @@ let tf = new TFun(inc) in tf.TFun#f("1")
191191
//│ ║ ^^^
192192
//│ ╟── because: cannot constrain Str <: 'T
193193
//│ ╟── because: cannot constrain Str <: 'T
194-
//│ ╟── because: cannot constrain Str <: ¬(¬'T)
195194
//│ ╟── because: cannot constrain Str <: 'T
196-
//│ ╙── because: cannot constrain Str <: ¬(¬{Int})
195+
//│ ╟── because: cannot constrain Str <: 'T
196+
//│ ╙── because: cannot constrain Str <: Int
197197
//│ Type: Str ∨ Int
198198

199199
data class Pair[A, B](fst: A, snd: B)

hkmc2/shared/src/test/mlscript/bbml/bbBorrowing.mls

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ letreg of r =>
9191
//│ ║ l.75: k()
9292
//│ ║ ^^^^^
9393
//│ ╟── because: cannot constrain 'E <: ⊥
94-
//│ ╟── because: cannot constrain 'E <: ¬()
95-
//│ ╟── because: cannot constrain ¬⊥ ∧ ¬'Rg <: ¬()
96-
//│ ╟── because: cannot constrain <: 'Rg
97-
//│ ╙── because: cannot constrain <: ¬()
94+
//│ ╟── because: cannot constrain 'E <:
95+
//│ ╟── because: cannot constrain ¬'Rg <:
96+
//│ ╟── because: cannot constrain <: 'Rg
97+
//│ ╙── because: cannot constrain <:
9898
//│ Type: ⊥
9999

100100
:e
@@ -127,10 +127,10 @@ letreg of r =>
127127
//│ ║ l.109: r
128128
//│ ║ ^^^
129129
//│ ╟── because: cannot constrain 'E <: ⊥
130-
//│ ╟── because: cannot constrain 'E <: ¬()
131-
//│ ╟── because: cannot constrain ¬⊥ ∧ 'Rg <: ¬()
132-
//│ ╟── because: cannot constrain 'Rg <: ¬()
133-
//│ ╙── because: cannot constrain <: ¬()
130+
//│ ╟── because: cannot constrain 'E <:
131+
//│ ╟── because: cannot constrain 'Rg <:
132+
//│ ╟── because: cannot constrain 'Rg <:
133+
//│ ╙── because: cannot constrain <:
134134
//│ Type: Reg[?, 'E]
135135
//│ Where:
136136
//│ ⊤ <: 'E
@@ -171,10 +171,10 @@ letreg of r =>
171171
//│ ║ l.157: r
172172
//│ ║ ^^^
173173
//│ ╟── because: cannot constrain 'E <: ⊥
174-
//│ ╟── because: cannot constrain 'E <: ¬()
175-
//│ ╟── because: cannot constrain ¬⊥ ∧ 'Rg <: ¬()
176-
//│ ╟── because: cannot constrain 'Rg <: ¬()
177-
//│ ╙── because: cannot constrain <: ¬()
174+
//│ ╟── because: cannot constrain 'E <:
175+
//│ ╟── because: cannot constrain 'Rg <:
176+
//│ ╟── because: cannot constrain 'Rg <:
177+
//│ ╙── because: cannot constrain <:
178178
//│ Type: Reg[?, 'E]
179179
//│ Where:
180180
//│ ⊤ <: 'E

hkmc2/shared/src/test/mlscript/bbml/bbBorrowing2.mls

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ letreg of r =>
4848
//│ ║ l.36: write(r)
4949
//│ ║ ^^^^^^^^^^
5050
//│ ╟── because: cannot constrain 'E <: ⊥
51-
//│ ╟── because: cannot constrain 'E <: ¬()
52-
//│ ╟── because: cannot constrain ¬⊥ ∧ 'Rg <: ¬()
53-
//│ ╟── because: cannot constrain 'Rg <: ¬()
54-
//│ ╙── because: cannot constrain <: ¬()
51+
//│ ╟── because: cannot constrain 'E <:
52+
//│ ╟── because: cannot constrain 'Rg <:
53+
//│ ╟── because: cannot constrain 'Rg <:
54+
//│ ╙── because: cannot constrain <:
5555
//│ Type: Int
5656

5757

@@ -87,10 +87,10 @@ letreg of r =>
8787
//│ ║ l.75: write(r)
8888
//│ ║ ^^^^^^^^^^
8989
//│ ╟── because: cannot constrain 'E <: ⊥
90-
//│ ╟── because: cannot constrain 'E <: ¬()
91-
//│ ╟── because: cannot constrain ¬⊥ ∧ 'Rg <: ¬()
92-
//│ ╟── because: cannot constrain 'Rg <: ¬()
93-
//│ ╙── because: cannot constrain <: ¬()
90+
//│ ╟── because: cannot constrain 'E <:
91+
//│ ╟── because: cannot constrain 'Rg <:
92+
//│ ╟── because: cannot constrain 'Rg <:
93+
//│ ╙── because: cannot constrain <:
9494
//│ Type: Int
9595

9696

hkmc2/shared/src/test/mlscript/bbml/bbCheck.mls

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ foo(42, 1)
180180
//│ ║ l.178: foo(42, 1)
181181
//│ ║ ^
182182
//│ ╟── because: cannot constrain Int <: ¬'A
183-
//│ ╟── because: cannot constrain 'A <: ¬(Int)
184-
//│ ╙── because: cannot constrain Int <: ¬(Int)
183+
//│ ╟── because: cannot constrain 'A <: ¬{Int}
184+
//│ ╙── because: cannot constrain Int <: ¬{Int}
185185
//│ Type: Int
186186

187187
foo(42, false)

hkmc2/shared/src/test/mlscript/bbml/bbCodeGen.mls

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,40 @@ region x in let y = x.ref 42 in y := 0
268268
//│ let x5, y1, tmp5; x5 = new this.Region(); tmp5 = new this.Ref(x5, 42); y1 = tmp5; y1.value = 0; 0
269269
//│ = 0
270270
//│ Type: Int
271+
272+
273+
274+
data class LsAlg[A, E](nil:() -> E, cons: (A, E) -> E)
275+
//│ Type: ⊤
276+
277+
fun nil(x) = x.LsAlg#nil()
278+
fun cons(x, y, z) = x.LsAlg#cons(y, z)
279+
//│ Type: ⊤
280+
281+
data class Nil()
282+
data class Cons[A, B](val x: A,val y: B)
283+
//│ Type: ⊤
284+
285+
fun mk() = new LsAlg(() => new Nil, (x, y) => new Cons(x, y))
286+
//│ Type: ⊤
287+
288+
// fun xs: [E] -> LsAlg[in Int, E] -> E
289+
fun xs(x) = x cons(1, x nil())
290+
//│ Type: ⊤
291+
292+
fun ys: [E] -> LsAlg[in Nothing, E] -> E
293+
fun ys(x) = x nil()
294+
//│ Type: ⊤
295+
296+
fun zs: [E] -> LsAlg[in Int | E, E] -> E
297+
fun zs(x) = x cons(xs(x), x cons(ys(x),x nil()))
298+
//│ Type: ⊤
299+
300+
mk() zs()
301+
//│ = Cons(Cons(1, Nil()), Cons(Nil(), Nil()))
302+
//│ Type: Nil ∨ Cons['A, 'B]
303+
//│ Where:
304+
//│ ('E ∨ Cons['A, 'B]) ∨ Nil <: 'B
305+
//│ Nil <: 'E
306+
//│ Cons['A, 'B] <: 'E
307+
//│ 'E ∨ Int <: 'A

hkmc2/shared/src/test/mlscript/bbml/bbCyclicExtrude.mls

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@ f => (let g = x => f(x(x)) in g) as [A] -> A -> A
2323
//│ ╔══[ERROR] Type error in block with expected type (A) ->{⊥} A
2424
//│ ║ l.22: f => (let g = x => f(x(x)) in g) as [A] -> A -> A
2525
//│ ║ ^^^^^^^^^^^^^^^^^
26-
//│ ╟── because: cannot constrain 'x ->{'eff ∨ 'eff1} 'app <: A -> A
26+
//│ ╟── because: cannot constrain ('x) ->{'eff ∨ 'eff1} ('app) <: (A) ->{⊥} (A)
2727
//│ ╟── because: cannot constrain A <: 'x
2828
//│ ╟── because: cannot constrain A <: 'x
29-
//│ ╙── because: cannot constrain A <: ¬(¬{'x ->{'eff1} 'app1})
29+
//│ ╙── because: cannot constrain A <: ('x) ->{'eff1} ('app1)
3030
//│ Type: (⊥ -> ⊥) ->{⊥} ['A] -> ('A) ->{⊥} 'A
3131

3232
f => (x => f(x(x)) as [A] -> A -> A)

0 commit comments

Comments
 (0)