Skip to content

Commit 32226d4

Browse files
authored
Preliminary implementation of refining patterns (hkust-taco#262)
1 parent 5a5509b commit 32226d4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+2044
-199
lines changed

hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
209209
val clsTmp = summon[Scope].allocateName(new semantics.TempSymbol(N, sym.nme+"$"+"class"))
210210
clsDefn.owner match
211211
case S(owner) =>
212-
assert(clsDefn.paramsOpt.isEmpty)
212+
assert((clsDefn.kind is syntax.Pat) || clsDefn.paramsOpt.isEmpty)
213213
// doc"${mkThis(owner)}.${sym.nme} = new ${clsJS}"
214214
doc"const $clsTmp = ${clsJS}; # ${mkThis(owner)}.${sym.nme} = new ${clsTmp
215215
}; # ${mkThis(owner)}.${sym.nme}.class = $clsTmp;"
@@ -250,7 +250,6 @@ class JSBuilder(using Elaborator.State, Elaborator.Ctx) extends CodeBuilder:
250250
case Match(scrut, hd :: tl, els, rest) =>
251251
val sd = result(scrut)
252252
def cond(cse: Case) = cse match
253-
case Case.Lit(syntax.Tree.BoolLit(true)) => sd
254253
case Case.Lit(lit) => doc"$sd === ${lit.idStr}"
255254
case Case.Cls(cls, pth) => cls match
256255
// case _: semantics.ModuleSymbol => doc"=== ${result(pth)}"

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,14 @@ extends Importer:
162162
N
163163
case _ => N
164164

165+
/** To perform a reverse lookup for a term that references a symbol in the current context. */
166+
def reference(target: ClassSymbol | ModuleSymbol): Ctxl[Opt[Term]] =
167+
def go(ctx: Ctx): Opt[Term] =
168+
ctx.env.values.collectFirst:
169+
case elem if elem.symbol.flatMap(_.asClsLike).contains(target) => elem.ref(target.id)
170+
.orElse(ctx.parent.flatMap(go))
171+
go(ctx).map(Term.SynthSel(_, Ident("class"))(S(target)))
172+
165173
def cls(tree: Tree, inAppPrefix: Bool): Ctxl[Term] = trace[Term](s"Elab class ${tree.showDbg}", r => s"~> $r"):
166174
val trm = term(tree, inAppPrefix)
167175
trm.symbol match
@@ -289,8 +297,12 @@ extends Importer:
289297
case InfixApp(lhs, Keyword.`:`, rhs) =>
290298
Term.Asc(term(lhs), term(rhs))
291299
case tree @ InfixApp(lhs, Keyword.`is` | Keyword.`and`, rhs) =>
292-
val des = new Desugarer(this)(tree)
293-
val nor = new ucs.Normalization(tl)(des)
300+
val des = new ucs.Desugarer(this)(tree)
301+
scoped("ucs:desugared"):
302+
log(s"Desugared:\n${Split.display(des)}")
303+
val nor = new ucs.Normalization(this)(des)
304+
scoped("ucs:normalized"):
305+
log(s"Normalized:\n${Split.display(nor)}")
294306
Term.IfLike(Keyword.`if`, des)(nor)
295307
case app @ PartialApp(lhs, args) =>
296308
var params: Ls[Param] = Nil
@@ -428,21 +440,21 @@ extends Importer:
428440
// case _ =>
429441
// raise(ErrorReport(msg"Illegal new expression." -> tree.toLoc :: Nil))
430442
case tree @ Tree.IfLike(kw, _, split) =>
431-
val desugared = new Desugarer(this)(tree)
443+
val desugared = new ucs.Desugarer(this)(tree)
432444
scoped("ucs:desugared"):
433445
log(s"Desugared:\n${Split.display(desugared)}")
434-
val normalized = new ucs.Normalization(tl)(desugared)
446+
val normalized = new ucs.Normalization(this)(desugared)
435447
scoped("ucs:normalized"):
436448
log(s"Normalized:\n${Split.display(normalized)}")
437449
Term.IfLike(kw, desugared)(normalized)
438450
case Tree.Quoted(body) => Term.Quoted(term(body))
439451
case Tree.Unquoted(body) => Term.Unquoted(term(body))
440452
case tree @ Tree.Case(_, branches) =>
441453
val scrut = VarSymbol(Ident("caseScrut"))
442-
val des = new Desugarer(this)(tree, scrut)
454+
val des = new ucs.Desugarer(this)(tree, scrut)
443455
scoped("ucs:desugared"):
444456
log(s"Desugared:\n${Split.display(des)}")
445-
val nor = new ucs.Normalization(tl)(des)
457+
val nor = new ucs.Normalization(this)(des)
446458
scoped("ucs:normalized"):
447459
log(s"Normalized:\n${Split.display(nor)}")
448460
Term.Lam(PlainParamList(
@@ -851,9 +863,29 @@ extends Importer:
851863
val owner = ctx.outer
852864
newCtx.nest(S(patSym)).givenIn:
853865
assert(body.isEmpty)
866+
td.extension match
867+
case N => raise(ErrorReport(msg"Pattern definitions must have a body." -> td.toLoc :: Nil))
868+
case S(tree) =>
869+
val (patternParams, extractionParams) = ps match // Filter out pattern parameters.
870+
case S(ParamList(_, params, _)) => params.partition:
871+
case param @ Param(FldFlags(false, false, false, false, true), _, _) => true
872+
case param @ Param(FldFlags(_, _, _, _, false), _, _) => false
873+
case N => (Nil, Nil)
874+
// TODO: Implement extraction parameters.
875+
if extractionParams.nonEmpty then
876+
raise(ErrorReport(msg"Pattern extraction parameters are not yet supported." ->
877+
Loc(extractionParams.iterator.map(_.sym)) :: Nil))
878+
log(s"pattern parameters: ${patternParams.mkString("{ ", ", ", " }")}")
879+
patSym.patternParams = patternParams
880+
val split = ucs.DeBrujinSplit.elaborate(patternParams, tree, this)
881+
scoped("ucs:rp:elaborated"):
882+
log(s"elaborated ${patSym.nme}:\n${split.display}")
883+
patSym.split = split
854884
log(s"pattern body is ${td.extension}")
855885
val translate = new ucs.Translator(this)
856-
val bod = translate(ps.map(_.params).getOrElse(Nil), td.extension.getOrElse(die))
886+
val bod = translate(patSym.patternParams,
887+
Nil, // ps.map(_.params).getOrElse(Nil), // TODO: remove pattern parameters
888+
td.extension.getOrElse(die))
857889
val pd = PatternDef(owner, patSym, tps, ps, ObjBody(Term.Blk(bod, Term.Lit(UnitLit(true)))), annotations)
858890
patSym.defn = S(pd)
859891
pd
@@ -924,6 +956,8 @@ extends Importer:
924956
raise(ErrorReport(msg"Module parameters must have concrete types." -> t.toLoc :: Nil))
925957
case _ => ()
926958
ps
959+
case TypeDef(Pat, inner, N, N) =>
960+
param(inner).map(_.mapSecond(p => p.copy(flags = p.flags.copy(pat = true))))
927961
case _ =>
928962
t.asParam.map: (isSpd, p, t) =>
929963
isSpd -> Param(FldFlags.empty, fieldOrVarSym(ParamBind, p), t.map(term(_)))

hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,30 @@ package semantics
33

44
import mlscript.utils.*, shorthands.*
55
import syntax.*, Tree.Ident
6+
import ucs.DeBrujinSplit
67

78
/** Flat patterns for pattern matching */
89
enum Pattern extends AutoLocated:
910
case Lit(literal: Literal)
1011
case Var(sym: BlockLocalSymbol)
1112
case ClassLike(sym: ClassSymbol | ModuleSymbol, trm: Term, parameters: Opt[List[BlockLocalSymbol]], var refined: Bool)(val tree: Tree)
13+
case Synonym(symbol: PatternSymbol, patternArguments: Ls[(split: DeBrujinSplit, tree: Tree)])
1214
case Tuple(size: Int, inf: Bool)
1315
case Record(entries: List[(Ident -> BlockLocalSymbol)])
1416

1517
def subTerms: Ls[Term] = this match
1618
case Lit(_) => Nil
1719
case Var(_) => Nil
1820
case ClassLike(_, t, _, _) => t :: Nil
21+
case Synonym(_, _) => Nil
1922
case Tuple(_, _) => Nil
2023
case Record(_) => Nil
2124

2225
def children: Ls[Located] = this match
2326
case Lit(literal) => literal :: Nil
2427
case Var(nme) => Nil
2528
case ClassLike(_, t, parameters, _) => t :: parameters.toList.flatten
29+
case Synonym(_, arguments) => arguments.map(_.tree)
2630
case Tuple(fields, _) => Nil
2731
case Record(entries) => entries.flatMap { case (nme, als) => nme :: als :: Nil }
2832

@@ -31,9 +35,9 @@ enum Pattern extends AutoLocated:
3135
case Var(sym) => sym.nme
3236
case ClassLike(sym, t, ps, rfd) => (if rfd then "refined " else "") +
3337
sym.nme + ps.fold("")(_.mkString("(", ", ", ")"))
38+
case Synonym(symbol, arguments) =>
39+
symbol.nme + arguments.iterator.map(_.tree.showDbg).mkString("(", ", ", ")")
3440
case Tuple(size, inf) => "[]" + (if inf then ">=" else "=") + size
3541
case Record(Nil) => "{}"
3642
case Record(entries) =>
3743
entries.iterator.map(_.name + ": " + _).mkString("{ ", ", ", " }")
38-
39-

hkmc2/shared/src/main/scala/hkmc2/semantics/Split.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ enum Split extends AutoLocated with ProductWithTail:
6161
var isFallback: Bool = false
6262
end Split
6363

64+
extension (split: Split)
65+
def ~~:(fallback: Split): Split =
66+
if fallback == Split.End || split.isFull then
67+
split
68+
else (split match
69+
case Split.Cons(head, tail) => Split.Cons(head, tail ~~: fallback)
70+
case Split.Let(name, term, tail) => Split.Let(name, term, tail ~~: fallback)
71+
case Split.Else(_) /* impossible */ | Split.End => fallback)
72+
6473
object Split:
6574
def default(term: Term): Split = Split.Else(term)
6675

hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,15 +177,17 @@ type FieldSymbol = TermSymbol | MemberSymbol[?]
177177
sealed trait InnerSymbol extends Symbol
178178

179179
class ClassSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using State)
180-
extends MemberSymbol[ClassDef] with CtorSymbol with InnerSymbol:
180+
extends MemberSymbol[ClassDef] with CtorSymbol with InnerSymbol with NamedSymbol:
181+
def name: Str = nme
181182
def nme = id.name
182183
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of classe here
183184
override def toString: Str = s"class:$nme${State.dbgUid(uid)}"
184185
/** Compute the arity. */
185186
def arity: Int = tree.paramLists.headOption.fold(0)(_.fields.length)
186187

187188
class ModuleSymbol(val tree: Tree.TypeDef, val id: Tree.Ident)(using State)
188-
extends MemberSymbol[ModuleDef] with CtorSymbol with InnerSymbol:
189+
extends MemberSymbol[ModuleDef] with CtorSymbol with InnerSymbol with NamedSymbol:
190+
def name: Str = nme
189191
def nme = id.name
190192
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of module here
191193
override def toString: Str = s"module:${id.name}${State.dbgUid(uid)}"
@@ -195,11 +197,20 @@ class TypeAliasSymbol(val id: Tree.Ident)(using State) extends MemberSymbol[Type
195197
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of type alias here
196198
override def toString: Str = s"module:${id.name}${State.dbgUid(uid)}"
197199

198-
class PatternSymbol(val id: Tree.Ident)(using State)
200+
class PatternSymbol(val id: Tree.Ident, val params: Opt[Tree.Tup], val body: Tree)(using State)
199201
extends MemberSymbol[PatternDef] with CtorSymbol with InnerSymbol:
200202
def nme = id.name
201203
def toLoc: Option[Loc] = id.toLoc // TODO track source tree of pattern here
202204
override def toString: Str = s"pattern:${id.name}"
205+
/** The desugared nameless split. */
206+
private var _split: Opt[ucs.DeBrujinSplit] = N
207+
def split_=(split: ucs.DeBrujinSplit): Unit = _split = S(split)
208+
def split: ucs.DeBrujinSplit = _split.getOrElse:
209+
lastWords(s"found unelaborated pattern: $nme")
210+
/** The list of pattern parameters, for example,
211+
* `T` in `pattern Nullable(pattern T) = null | T`.
212+
*/
213+
var patternParams: Ls[Param] = Nil
203214

204215
class TopLevelSymbol(blockNme: Str)(using State)
205216
extends MemberSymbol[ModuleDef] with InnerSymbol:

hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,18 +342,19 @@ case class TypeDef(
342342

343343

344344
// TODO Store optional source locations for the flags instead of booleans
345-
final case class FldFlags(mut: Bool, spec: Bool, genGetter: Bool, mod: Bool):
345+
final case class FldFlags(mut: Bool, spec: Bool, genGetter: Bool, mod: Bool, pat: Bool):
346346
def showDbg: Str =
347347
val flags = Buffer.empty[String]
348348
if mut then flags += "mut"
349349
if spec then flags += "spec"
350350
if genGetter then flags += "gen"
351351
if mod then flags += "module"
352+
if pat then flags += "pattern"
352353
flags.mkString(" ")
353354
override def toString: String = "" + showDbg + ""
354355

355356
object FldFlags:
356-
val empty: FldFlags = FldFlags(false, false, false, false)
357+
val empty: FldFlags = FldFlags(false, false, false, false, false)
357358
object benign:
358359
// * Some flags like `mut` and `module` are "benign" in the sense that they don't affect code-gen
359360
def unapply(flags: FldFlags): Bool =

0 commit comments

Comments
 (0)