Skip to content

Commit efb66c1

Browse files
Patches for BbML (#240)
* Fix extrusion cache * Fix boolean if type check * Fix quotes * Fix run function type * Fix context envs * Slightly improve error messages * Fix missing operators * Cache skolem extrusion * Rename prelude file
1 parent 719113d commit efb66c1

29 files changed

+566
-494
lines changed

hkmc2/jvm/src/test/scala/hkmc2/BbmlDiffMaker.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ import hkmc2.bbml.*
88

99
abstract class BbmlDiffMaker extends JSBackendDiffMaker:
1010

11-
val bbPredefFile = file / os.up / os.RelPath("bbPredef.mls")
11+
val bbPreludeFile = file / os.up / os.RelPath("bbPrelude.mls")
1212

1313
val bbmlOpt = new NullaryCommand("bbml"):
1414
override def onSet(): Unit =
1515
super.onSet()
1616
if isGlobal then typeCheck.disable.isGlobal = true
1717
typeCheck.disable.setCurrentValue(())
18-
if file =/= bbPredefFile then
19-
importFile(bbPredefFile, verbose = false)
18+
if file =/= bbPreludeFile then
19+
importFile(bbPreludeFile, verbose = false)
2020

2121

2222
lazy val bbCtx =

hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ class ConstraintSolver(infVarState: InfVarUid.State, tl: TraceLogger):
3737

3838
import hkmc2.bbml.NormalForm.*
3939

40-
private def freshXVar(lvl: Int): InfVar = InfVar(lvl, infVarState.nextUid, new VarState(), false)
40+
private def freshXVar(lvl: Int, hint: Option[Str]): InfVar = InfVar(lvl, infVarState.nextUid, new VarState(), false)(hint)
4141

42-
def extrude(ty: Type)(using lvl: Int, pol: Bool, cache: ExtrudeCache): Type =
42+
def extrude(ty: Type)(using lvl: Int, pol: Bool, cache: ExtrudeCache, bbctx: BbCtx, cctx: CCtx, tl: TL): Type =
4343
trace[Type](s"Extruding[${printPol(pol)}] $ty", r => s"~> $r"):
4444
if ty.lvl <= lvl then ty else ty.toBasic/*TODO improve extrude directly*/ match
4545
case ClassLikeType(sym, targs) =>
@@ -49,13 +49,18 @@ class ConstraintSolver(infVarState: InfVarUid.State, tl: TraceLogger):
4949
case t: Type => Wildcard(extrude(t)(using lvl, !pol), extrude(t))
5050
})
5151
case v @ InfVar(_, uid, state, true) => // * skolem
52-
if pol then
53-
state.upperBounds.foldLeft[Type](Top)(_ & _)
54-
else
55-
state.lowerBounds.foldLeft[Type](Bot)(_ | _)
52+
cache.getOrElse(uid -> pol, {
53+
val nv = freshXVar(lvl, v.hint)
54+
cache += uid -> pol -> nv
55+
if pol then
56+
constrainImpl(state.upperBounds.foldLeft[Type](Top)(_ & _), nv)
57+
else
58+
constrainImpl(nv, state.lowerBounds.foldLeft[Type](Bot)(_ | _))
59+
nv
60+
})
5661
case v @ InfVar(_, uid, _, false) =>
5762
cache.getOrElse(uid -> pol, {
58-
val nv = freshXVar(lvl)
63+
val nv = freshXVar(lvl, v.hint)
5964
cache += uid -> pol -> nv
6065
if pol then
6166
v.state.upperBounds ::= nv

hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ extends NormalForm with CachedBasicType:
3434
}.foldLeft[Opt[Ls[(InfVar, Bool)]]](S(Nil))((res, p) => (res, p) match { // * None -> bot
3535
case (N, _) => N
3636
case (S(Nil), p) => S(p :: Nil)
37-
case (S((InfVar(v, uid1, s, k), p1) :: tail), (InfVar(_, uid2, _, _), p2)) if uid1 === uid2 =>
38-
if p1 === p2 then S((InfVar(v, uid1, s, k), p1) :: tail) else N
37+
case (S((lhs @ InfVar(v, uid1, s, k), p1) :: tail), (InfVar(_, uid2, _, _), p2)) if uid1 === uid2 =>
38+
if p1 === p2 then S((InfVar(v, uid1, s, k)(lhs.hint), p1) :: tail) else N
3939
case (S(head :: tail), p) => S(p :: head :: tail)
4040
})
4141
vars match

hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,10 @@ final case class BbCtx(
2323
ctx: Ctx,
2424
parent: Option[BbCtx],
2525
lvl: Int,
26-
clsDefs: HashMap[Str, ClassDef],
27-
env: HashMap[Uid[Symbol], GeneralType],
28-
quoteSkolemEnv: HashMap[Uid[Symbol], InfVar], // * SkolemTag for variables in quasiquotes
26+
env: HashMap[Uid[Symbol], GeneralType]
2927
):
3028
def +=(p: Symbol -> GeneralType): Unit = env += p._1.uid -> p._2
3129
def get(sym: Symbol): Option[GeneralType] = env.get(sym.uid) orElse parent.dlof(_.get(sym))(None)
32-
def *=(cls: ClassDef): Unit = clsDefs += cls.sym.id.name -> cls
3330
def getCls(name: Str): Option[TypeSymbol] =
3431
for
3532
elem <- ctx.get(name)
@@ -38,10 +35,8 @@ final case class BbCtx(
3835
yield cls
3936
def &=(p: (Symbol, Type, InfVar)): Unit =
4037
env += p._1.uid -> BbCtx.varTy(p._2, p._3)(using this)
41-
quoteSkolemEnv += p._1.uid -> p._3
42-
def getSk(sym: Symbol): Option[Type] = quoteSkolemEnv.get(sym.uid) orElse parent.dlof(_.getSk(sym))(None)
43-
def nest: BbCtx = copy(parent = Some(this))
44-
def nextLevel: BbCtx = copy(lvl = lvl + 1)
38+
def nest: BbCtx = copy(parent = Some(this), env = HashMap.empty)
39+
def nextLevel: BbCtx = copy(parent = Some(this), lvl = lvl + 1, env = HashMap.empty)
4540

4641
given (using ctx: BbCtx): Raise = ctx.raise
4742

@@ -62,7 +57,7 @@ object BbCtx:
6257
def refTy(ct: Type, sk: Type)(using ctx: BbCtx): Type =
6358
ClassLikeType(ctx.getCls("Ref").get, Wildcard(ct, ct) :: Wildcard.out(sk) :: Nil)
6459
def init(raise: Raise)(using Elaborator.State, Elaborator.Ctx): BbCtx =
65-
new BbCtx(raise, summon, None, 1, HashMap.empty, HashMap.empty, HashMap.empty)
60+
new BbCtx(raise, summon, None, 1, HashMap.empty)
6661
end BbCtx
6762

6863

@@ -72,13 +67,13 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
7267
private val infVarState = new InfVarUid.State()
7368
private val solver = new ConstraintSolver(infVarState, tl)
7469

75-
private def freshSkolem(using ctx: BbCtx): InfVar =
76-
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), true)
77-
private def freshVar(using ctx: BbCtx): InfVar =
78-
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), false)
79-
private def freshWildcard(using ctx: BbCtx) =
80-
val in = freshVar
81-
val out = freshVar
70+
private def freshSkolem(hint: Option[Str])(using ctx: BbCtx): InfVar =
71+
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), true)(hint)
72+
private def freshVar(hint: Option[Str])(using ctx: BbCtx): InfVar =
73+
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), false)(hint)
74+
private def freshWildcard(hint: Option[Str])(using ctx: BbCtx) =
75+
val in = freshVar(hint)
76+
val out = freshVar(hint)
8277
// in.state.upperBounds ::= out // * Not needed for soundness; complicates inferred types
8378
Wildcard(in, out)
8479

@@ -157,7 +152,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
157152
private def genPolyType(tvs: Ls[QuantVar], body: => GeneralType)(using ctx: BbCtx, cctx: CCtx) =
158153
val bds = tvs.map:
159154
case qv @ QuantVar(sym, ub, lb) =>
160-
val tv = freshVar
155+
val tv = freshVar(S(sym.name))
161156
ctx += sym -> tv // TODO: a type var symbol may be better...
162157
tv -> qv
163158
bds.foreach:
@@ -176,7 +171,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
176171

177172
private def instantiate(ty: PolyType)(using ctx: BbCtx): GeneralType = ty.instantiate(infVarState.nextUid, ctx.lvl)(tl)
178173

179-
private def extrude(ty: GeneralType)(using ctx: BbCtx, pol: Bool): GeneralType = ty match
174+
private def extrude(ty: GeneralType)(using ctx: BbCtx, pol: Bool, cctx: CCtx): GeneralType = ty match
180175
case ty: Type => solver.extrude(ty)(using ctx.lvl, pol, HashMap.empty)
181176
case PolyType(tvs, body) => PolyType(tvs, extrude(body))
182177
case PolyFunType(args, ret, eff) =>
@@ -185,7 +180,6 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
185180
private def constrain(lhs: Type, rhs: Type)(using ctx: BbCtx, cctx: CCtx): Unit =
186181
solver.constrain(lhs, rhs)
187182

188-
// TODO: content type
189183
private def typeCode(code: Term)(using ctx: BbCtx): (Type, Type, Type) =
190184
given CCtx = CCtx.init(code, N)
191185
code match
@@ -201,12 +195,12 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
201195
given BbCtx = nestCtx
202196
val bds = params.map:
203197
case Param(_, sym, _) =>
204-
val tv = freshVar
205-
val sk = freshSkolem
198+
val tv = freshVar(S(sym.name))
199+
val sk = freshSkolem(S(sym.name))
206200
nestCtx &= (sym, tv, sk)
207201
(tv, sk)
208202
val (bodyTy, ctxTy, eff) = typeCode(body)
209-
val res = freshVar(using ctx)
203+
val res = freshVar(N)(using ctx)
210204
constrain(ctxTy, bds.foldLeft[Type](res)((res, bd) => res | bd._2))
211205
(FunType(bds.map(_._1), bodyTy, Bot), res, eff)
212206
case Term.App(lhs, Term.Tup(rhs)) =>
@@ -215,26 +209,29 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
215209
case (res, p: Fld) =>
216210
val (ty, ctx, eff) = typeCode(p.term)
217211
(ty :: res._1, res._2 | ctx, res._3 | eff)
218-
val resTy = freshVar
212+
val resTy = freshVar(N)
219213
constrain(lhsTy, FunType(rhsTy.reverse, resTy, Bot)) // TODO: right
220214
(resTy, lhsCtx | rhsCtx, lhsEff | rhsEff)
215+
case sel @ Term.Sel(Term.Ref(_: TopLevelSymbol), _) if sel.symbol.isDefined =>
216+
val (opTy, eff) = typeCheck(Ref(sel.symbol.get)(sel.nme, 666)) // FIXME 666
217+
(tryMkMono(opTy, sel), Bot, eff)
221218
case Term.Unquoted(body) =>
222219
val (ty, eff) = typeCheck(body)
223-
val tv = freshVar
224-
val cr = freshVar
220+
val tv = freshVar(N)
221+
val cr = freshVar(N)
225222
constrain(tryMkMono(ty, body), BbCtx.codeTy(tv, cr))
226223
(tv, cr, eff)
227224
case Term.Blk(LetDecl(sym) :: DefineVar(sym2, rhs) :: Nil, body) if sym2 is sym => // TODO: more than one!!
228225
val (rhsTy, rhsCtx, rhsEff) = typeCode(rhs)(using ctx)
229226
val nestCtx = ctx.nextLevel
230227
given BbCtx = nestCtx
231-
val sk = freshSkolem
228+
val sk = freshSkolem(S(sym.nme))
232229
nestCtx &= (sym, rhsTy, sk)
233230
val (bodyTy, bodyCtx, bodyEff) = typeCode(body)
234-
val res = freshVar(using ctx)
231+
val res = freshVar(N)(using ctx)
235232
constrain(bodyCtx, sk | res)
236233
(bodyTy, rhsCtx | res, rhsEff | bodyEff)
237-
case Term.IfLike(Keyword.`if`, Split.Cons(Branch(cond, Pattern.Lit(BoolLit(true)), Split.Else(cons)), Split.Else(alts))) =>
234+
case Term.IfLike(Keyword.`if`, Split.Let(_, cond, Split.Cons(Branch(_, Pattern.Lit(BoolLit(true)), Split.Else(cons)), Split.Else(alts)))) =>
238235
val (condTy, condCtx, condEff) = typeCode(cond)
239236
val (consTy, consCtx, consEff) = typeCode(cons)
240237
val (altsTy, altsCtx, altsEff) = typeCode(alts)
@@ -252,7 +249,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
252249
()
253250
case N =>
254251
given BbCtx = ctx.nextLevel
255-
val funTyV = freshVar
252+
val funTyV = freshVar(S(sym.nme))
256253
pctx += sym -> funTyV // for recursive functions
257254
val (res, _) = typeCheck(lam)
258255
val funTy = tryMkMono(res, lam)
@@ -266,10 +263,10 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
266263
: (GeneralType, Type) =
267264
split match
268265
case Split.Cons(Branch(scrutinee, Pattern.ClassLike(sym, _, _, _), cons), alts) =>
269-
// * Pattern matching
266+
// * Pattern matching for classes
270267
val (clsTy, tv, emptyTy) = ctx.getCls(sym.nme).flatMap(_.defn) match
271268
case S(cls) =>
272-
(ClassLikeType(sym, cls.tparams.map(_ => freshWildcard)), freshVar, ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty)))
269+
(ClassLikeType(sym, cls.tparams.map(_ => freshWildcard(N))), (freshVar(N)), ClassLikeType(sym, cls.tparams.map(_ => Wildcard.empty)))
273270
case _ =>
274271
error(msg"Cannot match ${scrutinee.toString} as ${sym.toString}" -> split.toLoc :: Nil)
275272
(Bot, Bot, Bot)
@@ -286,6 +283,23 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
286283
val (altsTy, altsEff) = typeSplit(alts, sign)(using nestCtx2)
287284
val allEff = scrutineeEff | (consEff | altsEff)
288285
(sign.getOrElse(tryMkMono(consTy, cons) | tryMkMono(altsTy, alts)), allEff)
286+
// * Pattern matching for literals
287+
case Split.Cons(Branch(scrutinee, Pattern.Lit(lit), cons), alts) =>
288+
val (scrutineeTy, scrutineeEff) = typeCheck(scrutinee)
289+
val litTy = lit match
290+
case _: Tree.BoolLit => BbCtx.boolTy
291+
case _: Tree.IntLit => BbCtx.intTy
292+
case _: Tree.DecLit => BbCtx.numTy
293+
case _: Tree.StrLit => BbCtx.strTy
294+
case _: Tree.UnitLit => Top
295+
296+
constrain(tryMkMono(scrutineeTy, scrutinee), litTy)
297+
val nestCtx1 = ctx.nest
298+
val nestCtx2 = ctx.nest
299+
val (consTy, consEff) = typeSplit(cons, sign)(using nestCtx1)
300+
val (altsTy, altsEff) = typeSplit(alts, sign)(using nestCtx2)
301+
val allEff = scrutineeEff | (consEff | altsEff)
302+
(sign.getOrElse(tryMkMono(consTy, cons) | tryMkMono(altsTy, alts)), allEff)
289303
case Split.Let(name, term, tail) =>
290304
val nestCtx = ctx.nest
291305
given BbCtx = nestCtx
@@ -361,8 +375,8 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
361375
val (ty, eff) = typeCheck(f.term)
362376
Left(ty) :: Right(eff) :: Nil
363377
.partitionMap(x => x)
364-
val effVar = freshVar
365-
val retVar = freshVar
378+
val effVar = freshVar(N)
379+
val retVar = freshVar(N)
366380
constrain(tryMkMono(funTy, t), FunType(argTy.map((tryMkMono(_, t))), retVar, effVar))
367381
(retVar, argEff.foldLeft[Type](effVar | lhsEff)((res, e) => res | e))
368382

@@ -394,8 +408,6 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
394408
(error(msg"Variable not found: ${sym.nme}"
395409
-> t.toLoc :: Nil), Bot)
396410
case Blk(stats, res) =>
397-
val nestCtx = ctx.nest
398-
given BbCtx = nestCtx
399411
val effBuff = ListBuffer.empty[Type]
400412
def goStats(stats: Ls[Statement]): Unit = stats match
401413
case Nil => ()
@@ -406,7 +418,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
406418
require(sym2 is sym)
407419
val (rhsTy, eff) = typeCheck(rhs)
408420
effBuff += eff
409-
nestCtx += sym -> rhsTy
421+
ctx += sym -> rhsTy
410422
goStats(stats)
411423
case TermDefinition(_, Fun, sym, ParamList(_, ps) :: Nil, sig, Some(body), _, _) :: stats =>
412424
typeFunDef(sym, Term.Lam(ps, body), sig, ctx)
@@ -418,7 +430,6 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
418430
ctx += sym -> typeType(sig)
419431
goStats(stats)
420432
case (clsDef: ClassDef) :: stats =>
421-
ctx *= clsDef
422433
goStats(stats)
423434
goStats(stats)
424435
val (ty, eff) = typeCheck(res)
@@ -434,7 +445,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
434445
given BbCtx = nestCtx
435446
val tvs = params.map:
436447
case Param(_, sym, sign) =>
437-
val ty = sign.map(s => typeType(s)(using nestCtx)).getOrElse(freshVar)
448+
val ty = sign.map(s => typeType(s)(using nestCtx)).getOrElse(freshVar(S(sym.nme)))
438449
nestCtx += sym -> ty
439450
ty
440451
val (bodyTy, eff) = typeCheck(body)
@@ -446,7 +457,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
446457
val map = HashMap[Uid[Symbol], TypeArg]()
447458
val targs = clsDfn.tparams.map {
448459
case TyParam(_, _, targ) =>
449-
val ty = freshWildcard
460+
val ty = freshWildcard(N)
450461
map += targ.uid -> ty
451462
ty
452463
}
@@ -471,12 +482,12 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
471482
val map = HashMap[Uid[Symbol], TypeArg]()
472483
val targs = clsDfn.tparams.map {
473484
case TyParam(_, S(_), targ) =>
474-
val ty = freshVar
485+
val ty = freshVar(N)
475486
map += targ.uid -> ty
476487
ty
477488
case TyParam(_, N, targ) =>
478489
// val ty = freshWildcard // FIXME probably not correct
479-
val ty = freshVar
490+
val ty = freshVar(N)
480491
map += targ.uid -> ty
481492
ty
482493
}
@@ -498,28 +509,28 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
498509
case Term.Region(sym, body) =>
499510
val nestCtx = ctx.nextLevel
500511
given BbCtx = nestCtx
501-
val sk = freshSkolem
512+
val sk = freshSkolem(S(sym.nme))
502513
nestCtx += sym -> BbCtx.regionTy(sk)
503514
val (res, eff) = typeCheck(body)
504-
val tv = freshVar(using ctx)
515+
val tv = freshVar(N)(using ctx)
505516
constrain(eff, tv | sk)
506517
(extrude(res)(using ctx, true), tv | allocType)
507518
case Term.RegRef(reg, value) =>
508519
val (regTy, regEff) = typeCheck(reg)
509520
val (valTy, valEff) = typeCheck(value)
510-
val sk = freshVar
521+
val sk = freshVar(N)
511522
constrain(tryMkMono(regTy, reg), BbCtx.regionTy(sk))
512523
(BbCtx.refTy(tryMkMono(valTy, value), sk), sk | (regEff | valEff))
513524
case Term.Assgn(lhs, rhs) =>
514525
val (lhsTy, lhsEff) = typeCheck(lhs)
515526
val (rhsTy, rhsEff) = typeCheck(rhs)
516-
val sk = freshVar
527+
val sk = freshVar(N)
517528
constrain(tryMkMono(lhsTy, lhs), BbCtx.refTy(tryMkMono(rhsTy, rhs), sk))
518529
(tryMkMono(rhsTy, rhs), sk | (lhsEff | rhsEff))
519530
case Term.Deref(ref) =>
520531
val (refTy, refEff) = typeCheck(ref)
521-
val sk = freshVar
522-
val ctnt = freshVar
532+
val sk = freshVar(N)
533+
val ctnt = freshVar(N)
523534
constrain(tryMkMono(refTy, ref), BbCtx.refTy(ctnt, sk))
524535
(ctnt, sk | refEff)
525536
case Term.Quoted(body) =>

0 commit comments

Comments
 (0)