Skip to content

Support tuple patterns #232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
dcbe630
No longer share built-in symbols among tests
chengluyu Nov 13, 2024
f33235c
Rename `LitPat` to `Lit`
chengluyu Nov 6, 2024
ae34afc
Remove useless patterns and methods
chengluyu Nov 6, 2024
691ba3d
Support fixed-length tuple patterns
chengluyu Nov 7, 2024
d767708
Add `:js` to some tests so we can check the output
chengluyu Nov 8, 2024
7e0c86c
Add `:expect` command and improve a test with it
chengluyu Nov 8, 2024
735c2fb
Desugar variable-length tuple patterns
chengluyu Nov 11, 2024
627ff04
Show the weakness of current tuple normalization
chengluyu Nov 12, 2024
7317483
Memorization of tuple subscrutinees
chengluyu Nov 12, 2024
526dbde
Add built-in tuple primitives to the prelude
chengluyu Nov 12, 2024
eeeaf9d
Recognize overlapping tuple patterns
chengluyu Nov 12, 2024
37fe75d
Use meaningful function names in tuple pattern tests
chengluyu Nov 13, 2024
cfe1f58
Add a case to `showAsTree` for `Loc`
chengluyu Nov 13, 2024
bb9e916
Parse and desugar tuple patterns with `...`
chengluyu Nov 13, 2024
05d4658
Fix wrong casing in paths
chengluyu Nov 13, 2024
ab8ff6b
Support pattern matching on imported symbols
chengluyu Nov 14, 2024
e5a8311
Temporarily disable a TODO in VarianceTraverser
chengluyu Nov 14, 2024
4b3fa86
Deduplicate predefined libraries
chengluyu Nov 14, 2024
3c32ef4
Move the logic of getting ctor params to trees
chengluyu Nov 14, 2024
f4aeddd
Select `tupleSlice` and `tupleGet` in the right way
chengluyu Nov 14, 2024
f10cb6d
Remove duplicated test files for primes in names
chengluyu Nov 14, 2024
0aef97c
Hide long-winded parse trees in a test file
chengluyu Nov 14, 2024
5d8f452
Generate `Array` with qualified `globalThis` prefix
chengluyu Nov 14, 2024
e0ead12
Compute `ClassSymbol`'s arity directly from the tree
chengluyu Nov 14, 2024
387684c
Reduce the abuse of UCS debug commands
chengluyu Nov 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions hkmc2/jvm/src/test/scala/hkmc2/JSBackendDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker:
val sjs = NullaryCommand("sjs")
val showRepl = NullaryCommand("showRepl")
val silent = NullaryCommand("silent")
val expect = Command("expect"): ln =>
ln.trim

private val baseScp: codegen.js.Scope =
codegen.js.Scope.empty
Expand Down Expand Up @@ -61,6 +63,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker:
output(s"JS:")
output(jsStr)
def mkQuery(prefix: Str, jsStr: Str) =
import hkmc2.Message.MessageContext
val queryStr = jsStr.replaceAll("\n", " ")
val (reply, stderr) = host.query(queryStr, !expectRuntimeOrCodeGenErrors && fixme.isUnset && todo.isUnset)
reply match
Expand All @@ -73,11 +76,15 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker:
output(s"> ${line}")
content match
case "undefined" =>
case _ => output(s"$prefix= ${content}")
case _ =>
expect.get match
case S(expected) if content != expected => raise:
ErrorReport(msg"Expected: ${expected}, got: ${content}" -> N :: Nil,
source = Diagnostic.Source.Runtime)
case _ => output(s"$prefix= ${content}")
case ReplHost.Empty =>
case ReplHost.Unexecuted(message) => ???
case ReplHost.Error(isSyntaxError, message) =>
import hkmc2.Message.MessageContext
if (isSyntaxError) then
// If there is a syntax error in the generated code,
// it should be a code generation error.
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
val res = freshVar(using ctx)
constrain(bodyCtx, sk | res)
(bodyTy, rhsCtx | res, rhsEff | bodyEff)
case Term.IfLike(Keyword.`if`, Split.Cons(Branch(cond, Pattern.LitPat(BoolLit(true)), Split.Else(cons)), Split.Else(alts))) =>
case Term.IfLike(Keyword.`if`, Split.Cons(Branch(cond, Pattern.Lit(BoolLit(true)), Split.Else(cons)), Split.Else(alts))) =>
val (condTy, condCtx, condEff) = typeCode(cond)
val (consTy, consCtx, consEff) = typeCode(cons)
val (altsTy, altsCtx, altsEff) = typeCode(alts)
Expand Down
1 change: 1 addition & 0 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ case class End(msg: Str = "") extends BlockTail with ProductWithTail
enum Case:
case Lit(lit: Literal)
case Cls(cls: ClassSymbol | ModuleSymbol, path: Path)
case Tup(len: Int, inf: Bool)

sealed abstract class Result

Expand Down
12 changes: 7 additions & 5 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,15 @@ class Lowering(using TL, Raise, Elaborator.State):
End()
)
pat match
case Pattern.LitPat(lit) => mkMatch(Case.Lit(lit) -> go(tail, topLevel = false))
case Pattern.Lit(lit) => mkMatch(Case.Lit(lit) -> go(tail, topLevel = false))
case Pattern.ClassLike(cls, trm, args0, _refined) =>
subTerm(trm): st =>
val args = args0.getOrElse(Nil)
val clsDefn = cls.defn.getOrElse(die)
val clsParams = clsDefn.paramsOpt.getOrElse(Nil)
val clsParams = cls match
case cls: ClassSymbol => cls.tree.params
case _: ModuleSymbol => Nil
assert(args0.isEmpty || clsParams.length === args.length)
def mkArgs(args: Ls[Param -> BlockLocalSymbol])(using Subst): Case -> Block = args match
def mkArgs(args: Ls[(LocalSymbol & NamedSymbol) -> BlockLocalSymbol])(using Subst): Case -> Block = args match
// def mkArgs(args: Ls[Param -> BlockLocalSymbol])(using Subst): Block = args match
case Nil =>
// mkMatch(Case.Cls(cls, st) -> go(tail, topLevel = false))
Expand All @@ -232,13 +233,14 @@ class Lowering(using TL, Raise, Elaborator.State):
// Assign(arg, Select(sr, Tree.Ident("head")), mkArgs(args))

val (cse, blk) = mkArgs(args)
(cse, Assign(arg, Select(sr, param.sym.id/*FIXME incorrect Ident?*/), blk))
(cse, Assign(arg, Select(sr, param.id/*FIXME incorrect Ident?*/), blk))
// mkMatch(cse -> Assign(arg, Select(sr, param.sym.id/*FIXME incorrect Ident?*/), blk))
// Assign(arg, Select(sr, param.sym.id/*FIXME incorrect Ident?*/), mkArgs(args))

// val (cse, blk) =
// mkMatch(cse -> blk)
mkMatch(mkArgs(clsParams.zip(args)))
case Pattern.Tuple(len, inf) => mkMatch(Case.Tup(len, inf) -> go(tail, topLevel = false))
// Match(sr, cse :: Nil,
// S(go(restSplit, topLevel = true)),
// End()
Expand Down
14 changes: 12 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,8 @@ class JSBuilder extends CodeBuilder:
case N => doc""
t :: e :: returningTerm(rest)
case Match(scrut, Case.Cls(cls, pth) -> trm :: Nil, els, rest) =>
val test = cls.defn.getOrElse(die).kind match
// case syntax.Mod => doc"=== ${result(pth)}"
val test = cls match
// case _: semantics.ModuleSymbol => doc"=== ${result(pth)}"
case _ => doc"instanceof ${result(pth)}"
val t = doc" # if (${ result(scrut) } $test) { #{ ${
returningTerm(trm)
Expand All @@ -255,6 +255,16 @@ class JSBuilder extends CodeBuilder:
doc" else { #{ ${ returningTerm(el) } #} # }"
case N => doc""
t :: e :: returningTerm(rest)
case Match(scrut, Case.Tup(len, inf) -> trm :: Nil, els, rest) =>
val test = doc"globalThis.Array.isArray(${ result(scrut) }) && ${ result(scrut) }.length ${if inf then ">=" else "==="} ${len}"
val t = doc" # if (${ test }) { #{ ${
returningTerm(trm)
} #} # }"
val e = els match
case S(el) =>
doc" else { #{ ${ returningTerm(el) } #} # }"
case N => doc""
t :: e :: returningTerm(rest)

case Begin(sub, thn) =>
doc"${returningTerm(sub)} # ${returningTerm(thn).stripBreaks}"
Expand Down
94 changes: 83 additions & 11 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Desugarer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import Message.MessageContext
import utils.TraceLogger
import hkmc2.syntax.Literal
import Keyword.{as, and, `else`, is, let, `then`}
import collection.mutable.HashMap
import Elaborator.{ctx, Ctxl}
import hkmc2.semantics.Elaborator.Ctx.globalThisSymbol

object Desugarer:
extension (op: Keyword.Infix)
Expand All @@ -20,6 +23,11 @@ object Desugarer:
case lhs and rhs => S((lhs, L(rhs)))
case lhs `then` rhs => S((lhs, R(rhs)))
case _ => N

class ScrutineeData:
val classes: HashMap[ClassSymbol, List[BlockLocalSymbol]] = HashMap.empty
val tupleLead: HashMap[Int, BlockLocalSymbol] = HashMap.empty
val tupleLast: HashMap[Int, BlockLocalSymbol] = HashMap.empty
end Desugarer

class Desugarer(tl: TraceLogger, elaborator: Elaborator)
Expand Down Expand Up @@ -58,16 +66,20 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
case Split.Let(name, term, tail) => Split.Let(name, term, tail ++ fallback)
case Split.Else(_) /* impossible */ | Split.End => fallback)

import collection.mutable.HashMap

private val subScrutineeMap = HashMap.empty[BlockLocalSymbol, HashMap[ClassSymbol, List[BlockLocalSymbol]]]
private val subScrutineeMap = HashMap.empty[BlockLocalSymbol, ScrutineeData]

extension (symbol: BlockLocalSymbol)
def getSubScrutinees(cls: ClassSymbol): List[BlockLocalSymbol] =
subScrutineeMap.getOrElseUpdate(symbol, HashMap.empty).getOrElseUpdate(cls, {
val arity = cls.defn.flatMap(_.paramsOpt.map(_.length)).getOrElse(0)
(0 until arity).map(i => TempSymbol(nextUid, N, s"param$i")).toList
subScrutineeMap.getOrElseUpdate(symbol, new ScrutineeData).classes.getOrElseUpdate(cls, {
(0 until cls.arity).map(i => TempSymbol(nextUid, N, s"param$i")).toList
})
def getTupleLeadSubScrutinee(index: Int): BlockLocalSymbol =
val data = subScrutineeMap.getOrElseUpdate(symbol, new ScrutineeData)
data.tupleLead.getOrElseUpdate(index, TempSymbol(nextUid, N, s"first$index"))
def getTupleLastSubScrutinee(index: Int): BlockLocalSymbol =
val data = subScrutineeMap.getOrElseUpdate(symbol, new ScrutineeData)
data.tupleLast.getOrElseUpdate(index, TempSymbol(nextUid, N, s"last$index"))


def default: Split => Sequel = split => _ => split

Expand Down Expand Up @@ -349,6 +361,12 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
raise(ErrorReport(msg"Unrecognized pattern split." -> tree.toLoc :: Nil))
_ => _ => Split.default(Term.Error)

private lazy val tupleSlice =
term(Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("tupleSlice")))

private lazy val tupleGet =
term(Sel(Sel(Ident("globalThis"), Ident("Predef")), Ident("tupleGet")))

/** Elaborate a single match (a scrutinee and a pattern) and forms a split
* with an innermost split as the sequel of the match.
* @param scrutSymbol the symbol representing the scrutinee
Expand Down Expand Up @@ -383,6 +401,54 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
// Raise an error and discard `sequel`. Use `fallback` instead.
raise(ErrorReport(msg"Cannot use this ${ctor.describe} as a pattern" -> ctor.toLoc :: Nil))
fallback
case Tree.Tup(args) => fallback => ctx => trace(
pre = s"expandMatch <<< ${args.mkString(", ")}",
post = (r: Split) => s"expandMatch >>> ${r.showDbg}"
):
// Break tuple into three parts:
// 1. A fixed number of leading patterns.
// 2. A variable number of middle patterns indicated by `..`.
// 3. A fixed number of trailing patterns.
val (lead, rest) = args.foldLeft[(Ls[Tree], Opt[(Opt[Tree], Ls[Tree])])]((Nil, N)):
case ((lead, N), Spread(_, _, patOpt)) => (lead, S((patOpt, Nil)))
case ((lead, N), pat) => (lead :+ pat, N)
case ((lead, S((rest, last))), pat) => (lead, S((rest, last :+ pat)))
// Some helper functions. TODO: deduplicate
def int(i: Int) = Term.Lit(IntLit(BigInt(i)))
def fld(t: Term) = Fld(FldFlags.empty, t, N)
def tup(xs: Fld*) = Term.Tup(xs.toList)(Tup(Nil))
def app(lhs: Term, rhs: Term, sym: FlowSymbol) = Term.App(lhs, rhs)(Tree.App(Tree.Empty(), Tree.Empty()), sym)
def getLast(i: Int) = TempSymbol(nextUid, N, s"last$i")
// `wrap`: add let bindings for tuple elements
// `matches`: pairs of patterns and symbols to be elaborated
val (wrapRest, restMatches) = rest match
case S((rest, last)) =>
val (wrapLast, reversedLastMatches) = last.reverseIterator.zipWithIndex
.foldLeft[(Split => Split, Ls[(BlockLocalSymbol, Tree)])]((identity, Nil)):
case ((wrapInner, matches), (pat, lastIndex)) =>
val sym = scrutSymbol.getTupleLastSubScrutinee(lastIndex)
val wrap = (split: Split) =>
Split.Let(sym, app(tupleGet, tup(fld(ref), fld(int(-1 - lastIndex))), sym), wrapInner(split))
(wrap, (sym, pat) :: matches)
val lastMatches = reversedLastMatches.reverse
rest match
case N => (wrapLast, lastMatches)
case S(pat) =>
val sym = TempSymbol(nextUid, N, "rest")
val wrap = (split: Split) =>
Split.Let(sym, app(tupleSlice, tup(fld(ref), fld(int(lead.length)), fld(int(last.length))), sym), wrapLast(split))
(wrap, (sym, pat) :: lastMatches)
case N => (identity: Split => Split, Nil)
val (wrap, matches) = lead.zipWithIndex.foldRight((wrapRest, restMatches)):
case ((pat, i), (wrapInner, matches)) =>
val sym = scrutSymbol.getTupleLeadSubScrutinee(i)
val wrap = (split: Split) => Split.Let(sym, Term.Sel(ref, Ident(s"$i"))(N), wrapInner(split))
(wrap, (sym, pat) :: matches)
Branch(
ref,
Pattern.Tuple(lead.length + rest.fold(0)(_._2.length), rest.isDefined),
wrap(subMatches(matches, sequel)(Split.End)(ctx))
) ~: fallback
// A single constructor pattern.
case pat @ App(ctor @ (_: Ident | _: Sel), Tup(args)) => fallback => ctx => trace(
pre = s"expandMatch <<< ${ctor}(${args.iterator.map(_.showDbg).mkString(", ")})",
Expand All @@ -391,11 +457,14 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
val clsTrm = elaborator.cls(ctor)
clsTrm.symbol.flatMap(_.asClsLike) match
case S(cls: ClassSymbol) =>
val arity = cls.defn.flatMap(_.paramsOpt.map(_.length)).getOrElse(0)
if args.length =/= arity then
val n = arity.toString
val arity = cls.arity
if arity =/= args.length then
val m = args.length.toString
raise(ErrorReport(msg"mismatched arity: expect $n, found $m" -> pat.toLoc :: Nil))
ErrorReport:
if arity == 0 then
msg"the constructor does not take any arguments but found $m" -> pat.toLoc :: Nil
else
msg"mismatched arity: expect ${arity.toString}, found $m" -> pat.toLoc :: Nil
val params = scrutSymbol.getSubScrutinees(cls)
Branch(
ref,
Expand All @@ -411,11 +480,14 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)
pre = s"expandMatch: literal <<< $literal",
post = (r: Split) => s"expandMatch: literal >>> ${r.showDbg}"
):
Branch(ref, Pattern.LitPat(literal), sequel(ctx)) ~: fallback
Branch(ref, Pattern.Lit(literal), sequel(ctx)) ~: fallback
// A single pattern in conjunction with more conditions
case pattern and consequent => fallback => ctx =>
val innerSplit = termSplit(consequent, identity)(Split.End)
expandMatch(scrutSymbol, pattern, innerSplit)(fallback)(ctx)
case Jux(Ident(".."), Ident(_)) => fallback => _ =>
raise(ErrorReport(msg"Illgeal rest pattern." -> pattern.toLoc :: Nil))
fallback
case _ => fallback => _ =>
// Raise an error and discard `sequel`. Use `fallback` instead.
raise(ErrorReport(msg"Unrecognized pattern." -> pattern.toLoc :: Nil))
Expand Down
65 changes: 32 additions & 33 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,20 @@ import Keyword.{`let`, `set`}

object Elaborator:

val builtinOpsMap: Map[Str, BuiltinSymbol] =
val binOps: Ls[Str] = Ls(
",",
"+", "-", "*", "/", "%",
"==", "!=", "<", "<=", ">", ">=",
"===",
"&&", "||")
val isUnary: Str => Bool = Set("-", "+", "!", "~").contains
val baseBuiltins = binOps.map: op =>
op -> BuiltinSymbol(op, binary = true, unary = isUnary(op), nullary = false)
.toMap
baseBuiltins
+ (";" -> baseBuiltins(","))
+ ("+." -> baseBuiltins("+"))
+ ("-." -> baseBuiltins("-"))
+ ("*." -> baseBuiltins("*"))
val reservedNames = builtinOpsMap.keySet + "NaN" + "Infinity"
private val binaryOps = Ls(
",",
"+", "-", "*", "/", "%",
"==", "!=", "<", "<=", ">", ">=",
"===",
"&&", "||")
private val unaryOps = Set("-", "+", "!", "~")
private val aliasOps = Map(
";" -> ",",
"+." -> "+",
"-." -> "-",
"*." -> "*")

val reservedNames = binaryOps.toSet ++ aliasOps.keySet + "NaN" + "Infinity"

case class Ctx(outer: Opt[InnerSymbol], parent: Opt[Ctx], env: Map[Str, Ctx.Elem]):
def +(local: Str -> Symbol): Ctx = copy(outer, env = env + local.mapSecond(Ctx.RefElem(_)))
Expand Down Expand Up @@ -103,6 +100,13 @@ extends Importer:
private val allocSkolemSym = VarSymbol(Ident("Alloc"), allocSkolemUID)
private val allocSkolemDef = TyParam(FldFlags.empty, N, allocSkolemSym)
allocSkolemSym.decl = S(allocSkolemDef)

private val builtinOpsMap =
val baseBuiltins = binaryOps.map: op =>
op -> BuiltinSymbol(op, binary = true, unary = unaryOps(op), nullary = false)
.toMap
baseBuiltins ++ aliasOps.map:
case (alias, base) => alias -> baseBuiltins(base)

def mkLetBinding(sym: LocalSymbol, rhs: Term): Ls[Statement] =
LetDecl(sym) :: DefineVar(sym, rhs) :: Nil
Expand Down Expand Up @@ -625,13 +629,8 @@ extends Importer:
if ctx.outer.isDefined then TermSymbol(k, ctx.outer, id)
else VarSymbol(id, nextUid)

def param(t: Tree): Ctxl[Ls[Param]] = t match
case id: Ident =>
Param(FldFlags.empty, fieldOrVarSym(ParamBind, id), N) :: Nil
case InfixApp(lhs: Ident, Keyword.`:`, rhs) =>
Param(FldFlags.empty, fieldOrVarSym(ParamBind, lhs), S(term(rhs))) :: Nil
case App(Ident(","), list) => params(list)._1
case TermDef(ImmutVal, inner, _) => param(inner)
def param(t: Tree): Ctxl[Ls[Param]] = t.param.map: (p, t) =>
Param(FldFlags.empty, fieldOrVarSym(ParamBind, p), t.map(term))

def params(t: Tree): Ctxl[(Ls[Param], Ctx)] = t match
case Tup(ps) =>
Expand All @@ -654,14 +653,14 @@ extends Importer:
case id @ Ident(name) =>
val sym = boundVars.getOrElseUpdate(name, VarSymbol(id, nextUid))
Pattern.Var(sym)
case Tup(fields) =>
val pats = fields.map(
f => pattern(f) match
case (pat, vars) =>
boundVars ++= vars
pat
)
Pattern.Tuple(pats)
// case Tup(fields) =>
// val pats = fields.map(
// f => pattern(f) match
// case (pat, vars) =>
// boundVars ++= vars
// pat
// )
// Pattern.Tuple(pats)
case _ =>
???
(go(t), boundVars.toList)
Expand Down Expand Up @@ -713,7 +712,7 @@ extends Importer:
if !tp.isContravariant then traverseType(pol)(targ)
if !tp.isCovariant then traverseType(pol.!)(targ)
case N =>
TODO(sym->sym.uid)
// TODO(sym->sym.uid)
case S(sym: TypeAliasSymbol) =>
// TODO dedup with above...
sym.defn match
Expand Down
Loading
Loading