Skip to content

Commit dbc6bcb

Browse files
chengluyu and LPTK authored
Support the use of do as a connective in UCS (hkust-taco#246)
Co-authored-by: Lionel Parreaux <lionel.parreaux@gmail.com>
1 parent 76b129a commit dbc6bcb

File tree

20 files changed

+285
-87
lines changed

20 files changed

+285
-87
lines changed

hkmc2/shared/src/main/scala/hkmc2/semantics/Desugarer.scala

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ import mlscript.utils.*, shorthands.*
66
import Message.MessageContext
77
import utils.TraceLogger
88
import hkmc2.syntax.Literal
9-
import Keyword.{as, and, `else`, is, let, `then`}
10-
import collection.mutable.HashMap
9+
import Keyword.{as, and, `do`, `else`, is, let, `then`}
10+
import collection.mutable.{HashMap, SortedSet}
1111
import Elaborator.{ctx, Ctxl}
1212
import ucs.DesugaringBase
1313

@@ -16,30 +16,52 @@ object Desugarer:
1616
infix def unapply(tree: Tree): Opt[(Tree, Tree)] = tree match
1717
case InfixApp(lhs, `op`, rhs) => S((lhs, rhs))
1818
case _ => N
19-
20-
/** An extractor that accepts either `A and B` or `A then B`. */
21-
object `~>`:
22-
infix def unapply(tree: Tree): Opt[(Tree, Tree \/ Tree)] = tree match
23-
case lhs and rhs => S((lhs, L(rhs)))
24-
case lhs `then` rhs => S((lhs, R(rhs)))
25-
case _ => N
26-
19+
2720
class ScrutineeData:
2821
val classes: HashMap[ClassSymbol, List[BlockLocalSymbol]] = HashMap.empty
2922
val tupleLead: HashMap[Int, BlockLocalSymbol] = HashMap.empty
3023
val tupleLast: HashMap[Int, BlockLocalSymbol] = HashMap.empty
3124
end Desugarer
3225

33-
class Desugarer(tl: TraceLogger, val elaborator: Elaborator)
26+
class Desugarer(val elaborator: Elaborator)
3427
(using raise: Raise, state: Elaborator.State, c: Elaborator.Ctx) extends DesugaringBase:
3528
import Desugarer.*
3629
import Elaborator.Ctx
37-
import elaborator.term
38-
import tl.*
30+
import elaborator.term, elaborator.tl.*
31+
32+
/** Orders `Loc`s lexicographically by the pair `(spanStart, spanEnd)`. */
given Ordering[Loc] = Ordering.by(loc => (loc.spanStart, loc.spanEnd))
34+
35+
/** Keep track of the locations where `do` (first set) and `then` (second set) are used as connectives. */
36+
private val kwLocSets = (SortedSet.empty[Loc], SortedSet.empty[Loc])
37+
38+
/** Reports an error when both `do` and `then` have been recorded as connectives.
  *
  * Takes the first location of each set in `kwLocSets`; if both sets are
  * non-empty, a single `ErrorReport` is raised that points at `kwLoc` and at
  * those two locations. Otherwise nothing is reported.
  *
  * @param kw    the keyword of the enclosing expression; its name appears in the message
  * @param kwLoc the location attached to the headline message, if known
  */
private def reportInconsistentConnectives(kw: Keyword, kwLoc: Opt[Loc]): Unit =
  log(kwLocSets)
  val (doLocs, thenLocs) = kwLocSets
  for
    doLoc <- doLocs.headOption
    thenLoc <- thenLocs.headOption
  do
    raise(ErrorReport(
      msg"Mixed use of `do` and `then` in the `${kw.name}` expression." -> kwLoc
        :: msg"Keyword `then` is used here." -> S(thenLoc)
        :: msg"Keyword `do` is used here." -> S(doLoc) :: Nil
    ))
48+
49+
/** The split appended after the topmost desugared split.
  *
  * When no `do` connective has been recorded this is the empty split;
  * otherwise it is a default branch holding the literal `UnitLit(true)`.
  */
private def topmostDefault: Split =
  if kwLocSets._1.isEmpty then Split.End
  else Split.Else(Term.Lit(UnitLit(true)))
51+
52+
/** An extractor that accepts any of `A and B`, `A then B`, or `A do B`.
  *
  * For `and`, the right operand is returned wrapped in `L`; for `then` and
  * `do`, it is returned wrapped in `R`. As a side effect, the location of
  * each matched `then` tree is added to the second set of `kwLocSets` and
  * the location of each matched `do` tree to the first set.
  */
object `~>`:
  infix def unapply(tree: Tree): Opt[(Tree, Tree \/ Tree)] = tree match
    case test and rest =>
      S((test, L(rest)))
    case test `then` body =>
      kwLocSets._2 ++= tree.toLoc
      S((test, R(body)))
    case test `do` body =>
      kwLocSets._1 ++= tree.toLoc
      S((test, R(body)))
    case _ =>
      N
3961

4062
// We're working on composing continuations in the UCS translation.
4163
// The type of continuation is `Split => Ctx => Split`.
42-
// The first parameter represents the fallback split, which does not have
64+
// The first parameter represents the "backup" split, which does not have
4365
// access to the bindings in the current match. The second parameter
4466
// represents the context with bindings in the current match.
4567

@@ -82,10 +104,6 @@ class Desugarer(tl: TraceLogger, val elaborator: Elaborator)
82104

83105
def default: Split => Sequel = split => _ => split
84106

85-
/** Desugar UCS shorthands. */
86-
def shorthands(tree: Tree): Sequel = termSplitShorthands(tree, identity):
87-
Split.default(Term.Lit(Tree.BoolLit(false)))
88-
89107
private def termSplitShorthands(tree: Tree, finish: Term => Term): Split => Sequel = tree match
90108
case Block(branches) => branches match
91109
case Nil => lastWords("encountered empty block")
@@ -166,6 +184,12 @@ class Desugarer(tl: TraceLogger, val elaborator: Elaborator)
166184
val sym = VarSymbol(ident)
167185
val fallbackCtx = ctx + (ident.name -> sym)
168186
Split.Let(sym, term(termTree)(using ctx), elabFallback(fallback)(fallbackCtx)).withLocOf(t)
187+
// A `do` clause inside a term split: the computation is elaborated under the
// current context and bound to a fresh temporary symbol, and the split
// continues with whatever `elabFallback` produces for the same context.
case Modified(Keyword.`do`, doLoc, computation) => fallback => ctx => trace(
  pre = s"termSplit: do $computation",
  // The label previously read "else", copied from the `else` clause below.
  post = (res: Split) => s"termSplit: do >>> $res"
):
  val sym = TempSymbol(N, "doTemp")
  Split.Let(sym, term(computation)(using ctx), elabFallback(fallback)(ctx)).withLocOf(t)
169193
case Modified(Keyword.`else`, elsLoc, default) => fallback => ctx => trace(
170194
pre = s"termSplit: else $default",
171195
post = (res: Split) => s"termSplit: else >>> $res"
@@ -241,6 +265,12 @@ class Desugarer(tl: TraceLogger, val elaborator: Elaborator)
241265
val sym = VarSymbol(ident)
242266
val fallbackCtx = ctx + (ident.name -> sym)
243267
Split.Let(sym, term(termTree)(using ctx), elabFallback(fallbackCtx))
268+
// A `do` clause with an empty left-hand side: the computation is elaborated
// under the current context and bound to a fresh temporary symbol, and the
// split continues with whatever `elabFallback` produces for the same context.
case (Tree.Empty(), Modified(Keyword.`do`, doLoc, computation)) => ctx => trace(
  pre = s"termSplit: do $computation",
  // The label previously read "else", copied from the `else` clause below.
  post = (res: Split) => s"termSplit: do >>> $res"
):
  val sym = TempSymbol(N, "doTemp")
  Split.Let(sym, term(computation)(using ctx), elabFallback(ctx))
244274
case (Tree.Empty(), Modified(Keyword.`else`, elsLoc, default)) => ctx =>
245275
// TODO: report `rest` as unreachable
246276
Split.default(term(default)(using ctx))
@@ -322,6 +352,12 @@ class Desugarer(tl: TraceLogger, val elaborator: Elaborator)
322352
val sym = VarSymbol(ident)
323353
val fallbackCtx = ctx + (ident.name -> sym)
324354
Split.Let(sym, term(termTree)(using ctx), elabFallback(backup)(fallbackCtx))
355+
// A `do` clause inside a pattern split: the computation is elaborated under
// the current context and bound to a fresh temporary symbol, and the split
// continues with whatever `elabFallback` produces for the same context.
case Modified(Keyword.`do`, doLoc, computation) => fallback => ctx => trace(
  pre = s"patternSplit (do) <<< $computation",
  // The label previously read "patternSplit: else", which did not match `pre`.
  post = (res: Split) => s"patternSplit (do) >>> $res"
):
  val sym = TempSymbol(N, "doTemp")
  Split.Let(sym, term(computation)(using ctx), elabFallback(fallback)(ctx))
325361
case Modified(Keyword.`else`, elsLoc, body) => backup => ctx => trace(
326362
pre = s"patternSplit (else) <<< $tree",
327363
post = (res: Split) => s"patternSplit (else) >>> ${res.showDbg}"
@@ -499,4 +535,20 @@ class Desugarer(tl: TraceLogger, val elaborator: Elaborator)
499535
):
500536
val innermostSplit = subMatches(rest, sequel)(fallback)
501537
expandMatch(scrutinee, pattern, innermostSplit)(fallback)
538+
539+
/** Desugar `case` expressions.
  *
  * The branches are desugared against `scrut` first; only afterwards are the
  * recorded connectives checked for consistency and the trailing default
  * split (see `topmostDefault`) appended, since both read `kwLocSets`.
  */
def apply(tree: Case, scrut: VarSymbol)(using Ctx): Split =
  val branches = patternSplit(tree.branches, scrut)(Split.End)(ctx)
  reportInconsistentConnectives(Keyword.`case`, tree.kwLoc)
  val trailing = topmostDefault
  branches ++ trailing
544+
545+
/** Desugar `if` and `while` expressions.
  *
  * The split is desugared first; only afterwards are the recorded
  * connectives checked for consistency and the trailing default split (see
  * `topmostDefault`) appended, since both read `kwLocSets`.
  */
def apply(tree: IfLike)(using Ctx): Split =
  val branches = termSplit(tree.split, identity)(Split.End)(ctx)
  reportInconsistentConnectives(tree.kw, tree.kwLoc)
  val trailing = topmostDefault
  branches ++ trailing
550+
551+
/** Desugar `is` and `and` shorthands.
  *
  * The shorthand is desugared with a default split holding the literal
  * `false` as its fallback.
  */
def apply(tree: InfixApp)(using Ctx): Split =
  val sequelOf = termSplitShorthands(tree, identity)
  val falseSplit = Split.default(Term.Lit(Tree.BoolLit(false)))
  sequelOf(falseSplit)(ctx)
502554
end Desugarer

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,8 @@ extends Importer:
269269
Term.Lam(syms, term(rhs)(using nestCtx))
270270
case InfixApp(lhs, Keyword.`:`, rhs) =>
271271
Term.Asc(term(lhs), term(rhs))
272-
case InfixApp(lhs, Keyword.`is` | Keyword.`and`, rhs) =>
273-
val des = new Desugarer(tl, this).shorthands(tree)(ctx)
272+
case tree @ InfixApp(lhs, Keyword.`is` | Keyword.`and`, rhs) =>
273+
val des = new Desugarer(this)(tree)
274274
val nor = new ucs.Normalization(tl)(des)
275275
Term.IfLike(Keyword.`if`, des)(nor)
276276
case App(Ident("|"), Tree.Tup(lhs :: rhs :: Nil)) =>
@@ -348,8 +348,8 @@ extends Importer:
348348
Term.New(cls(c, inAppPrefix = false), Nil).withLocOf(tree)
349349
// case _ =>
350350
// raise(ErrorReport(msg"Illegal new expression." -> tree.toLoc :: Nil))
351-
case Tree.IfLike(kw, split) =>
352-
val desugared = new Desugarer(tl, this).termSplit(split, identity)(Split.End)(ctx)
351+
case tree @ Tree.IfLike(kw, _, split) =>
352+
val desugared = new Desugarer(this)(tree)
353353
scoped("ucs:desugared"):
354354
log(s"Desugared:\n${Split.display(desugared)}")
355355
val normalized = new ucs.Normalization(tl)(desugared)
@@ -358,10 +358,9 @@ extends Importer:
358358
Term.IfLike(kw, desugared)(normalized)
359359
case Tree.Quoted(body) => Term.Quoted(term(body))
360360
case Tree.Unquoted(body) => Term.Unquoted(term(body))
361-
case Tree.Case(branches) =>
361+
case tree @ Tree.Case(_, branches) =>
362362
val scrut = VarSymbol(Ident("caseScrut"))
363-
val desugarer = new Desugarer(tl, this)
364-
val des = desugarer.patternSplit(branches, scrut)(Split.End)(ctx)
363+
val des = new Desugarer(this)(tree, scrut)
365364
scoped("ucs:desugared"):
366365
log(s"Desugared:\n${Split.display(des)}")
367366
val nor = new ucs.Normalization(tl)(des)

hkmc2/shared/src/main/scala/hkmc2/syntax/Keyword.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,11 @@ object Keyword:
6363

6464
val `if` = Keyword("if", N, nextPrec)
6565
val `while` = Keyword("while", N, curPrec)
66-
val `then` = Keyword("then", nextPrec, curPrec)
66+
67+
val thenPrec = nextPrec
68+
val `then` = Keyword("then", thenPrec, thenPrec)
69+
val `do` = Keyword("do", thenPrec, thenPrec)
70+
6771
val `else` = Keyword("else", nextPrec, curPrec)
6872
val `case` = Keyword("case", N, N)
6973
val `fun` = Keyword("fun", N, N)
@@ -81,7 +85,6 @@ object Keyword:
8185
val `in` = Keyword("in", curPrec, curPrec)
8286
val `out` = Keyword("out", N, curPrec)
8387
val `set` = Keyword("set", N, curPrec)
84-
val `do` = Keyword("do", N, N)
8588
val `declare` = Keyword("declare", N, N)
8689
val `trait` = Keyword("trait", N, N)
8790
val `mixin` = Keyword("mixin", N, N)
@@ -125,7 +128,7 @@ object Keyword:
125128
`abstract`, mut, virtual, `override`, declare, public, `private`)
126129

127130
type Infix = `and`.type | `or`.type | `then`.type | `else`.type | `is`.type | `:`.type | `->`.type |
128-
`=>`.type | `extends`.type | `restricts`.type | `as`.type
131+
`=>`.type | `extends`.type | `restricts`.type | `as`.type | `do`.type
129132

130133
type Ellipsis = `...`.type | `..`.type
131134

hkmc2/shared/src/main/scala/hkmc2/syntax/ParseRule.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,12 @@ class ParseRules(using State):
183183
val items = split match
184184
case Block(stmts) => stmts.appended(clause)
185185
case _ => split :: clause :: Nil
186-
IfLike(kw, Block(items))
187-
case (split, N) => IfLike(kw, split)
186+
IfLike(kw, N/* TODO */, Block(items))
187+
case (split, N) => IfLike(kw, N/* TODO */, split)
188188
,
189189
Blk(
190190
ParseRule(s"'${kw.name}' block")(End(()))
191-
) { case (body, _) => IfLike(kw, body) }
191+
) { case (body, _) => IfLike(kw, N/* TODO */, body) }
192192
)
193193

194194
def typeAliasLike(kw: Keyword, kind: TypeDefKind): Kw[TypeDef] =
@@ -258,7 +258,7 @@ class ParseRules(using State):
258258
,
259259
Kw(`case`):
260260
ParseRule("`case` keyword")(
261-
Blk(ParseRule("`case` branches")(End(())))((body, _: Unit) => Case(body))
261+
Blk(ParseRule("`case` branches")(End(())))((body, _: Unit) => Case(N/* TODO */, body))
262262
)
263263
,
264264
Kw(`region`):
@@ -358,6 +358,7 @@ class ParseRules(using State):
358358
genInfixRule(`:`, (rhs, _: Unit) => lhs => InfixApp(lhs, `:`, rhs)),
359359
genInfixRule(`extends`, (rhs, _: Unit) => lhs => InfixApp(lhs, `extends`, rhs)),
360360
genInfixRule(`restricts`, (rhs, _: Unit) => lhs => InfixApp(lhs, `restricts`, rhs)),
361+
genInfixRule(`do`, (rhs, _: Unit) => lhs => InfixApp(lhs, `do`, rhs)),
361362
)
362363

363364
end ParseRules

hkmc2/shared/src/main/scala/hkmc2/syntax/Parser.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ abstract class Parser(
547547
val ele = simpleExprImpl(prec)
548548
term match
549549
case InfixApp(lhs, Keyword.`then`, rhs) =>
550-
Quoted(IfLike(Keyword.`if`, Block(
550+
Quoted(IfLike(Keyword.`if`, S(l0), Block(
551551
InfixApp(Unquoted(lhs), Keyword.`then`, Unquoted(rhs)) :: Modified(Keyword.`else`, N, Unquoted(ele)) :: Nil
552552
)))
553553
case tk =>
@@ -900,6 +900,12 @@ abstract class Parser(
900900
case (NEWLINE, _) :: (KEYWORD(kw), _) :: _
901901
if kw.canStartInfixOnNewLine && kw.leftPrecOrMin > prec
902902
&& infixRules.kwAlts.contains(kw.name)
903+
&& (kw isnt Keyword.`do`) // This is to avoid the following case:
904+
// ```
905+
// 0 then "null"
906+
// do console.log("non-null")
907+
// ```
908+
// Otherwise, `do` will be parsed as an infix operator
903909
=>
904910
consume
905911
exprCont(acc, prec, allowNewlines = false)

hkmc2/shared/src/main/scala/hkmc2/syntax/Tree.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,10 @@ enum Tree extends AutoLocated:
6666
case Sel(prefix: Tree, name: Ident)
6767
case InfixApp(lhs: Tree, kw: Keyword.Infix, rhs: Tree)
6868
case New(body: Tree)
69-
case IfLike(kw: Keyword.`if`.type | Keyword.`while`.type, split: Tree)
69+
case IfLike(kw: Keyword.`if`.type | Keyword.`while`.type, kwLoc: Opt[Loc], split: Tree)
7070
@deprecated("Use If instead", "hkmc2-ucs")
7171
case IfElse(cond: Tree, alt: Tree)
72-
case Case(branches: Tree)
72+
case Case(kwLoc: Opt[Loc], branches: Tree)
7373
case Region(name: Tree, body: Tree)
7474
case RegRef(reg: Tree, value: Tree)
7575
case Effectful(eff: Tree, body: Tree)
@@ -95,9 +95,9 @@ enum Tree extends AutoLocated:
9595
case InfixApp(lhs, _, rhs) => Ls(lhs, rhs)
9696
case TermDef(k, head, rhs) => head :: rhs.toList
9797
case New(body) => body :: Nil
98-
case IfLike(_, split) => split :: Nil
98+
case IfLike(_, _, split) => split :: Nil
9999
case IfElse(cond, alt) => cond :: alt :: Nil
100-
case Case(bs) => Ls(bs)
100+
case Case(_, bs) => Ls(bs)
101101
case Region(name, body) => name :: body :: Nil
102102
case RegRef(reg, value) => reg :: value :: Nil
103103
case Effectful(eff, body) => eff :: body :: Nil
@@ -133,9 +133,9 @@ enum Tree extends AutoLocated:
133133
case Sel(prefix, name) => "selection"
134134
case InfixApp(lhs, kw, rhs) => "infix operation"
135135
case New(body) => "new"
136-
case IfLike(Keyword.`if`, split) => "if expression"
137-
case IfLike(Keyword.`while`, split) => "while expression"
138-
case Case(branches) => "case"
136+
case IfLike(Keyword.`if`, _, split) => "if expression"
137+
case IfLike(Keyword.`while`, _, split) => "while expression"
138+
case Case(_, branches) => "case"
139139
case Region(name, body) => "region"
140140
case RegRef(reg, value) => "region reference"
141141
case Effectful(eff, body) => "effectful"

hkmc2/shared/src/test/mlscript-compile/Predef.mls

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,14 @@ class MatchResult(captures)
3030
class MatchFailure(errors)
3131

3232
fun checkArgs(functionName, expected, isUB, got) =
33-
if got < expected || isUB && got > expected then
33+
if got < expected || isUB && got > expected do
3434
let name = if functionName.length > 0 then " '" + functionName + "'" else ""
3535
throw globalThis.Error("Function" + name + " expected " + expected + " arguments but got " + got)
3636
// TODO
3737
// throw globalThis.Error("Function" + name + " expected "
3838
// + expected
3939
// + (if isUB then "" else " at least")
4040
// + " arguments but got " + got)
41-
else ()
4241

4342
module TraceLogger with
4443
mut val enabled = false

hkmc2/shared/src/test/mlscript-compile/apps/Accounting.mls

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@ class Line(val name: Str, val proj: Project, val starting_balance: Num, val isMa
3030
fun expense(amt) =
3131
set balance = balance -. amt
3232
fun mustBeEmpty() =
33-
if balance > 10_000 then
33+
if balance > 10_000 do
3434
warnings.push of "> **❗️** Unspent balance of " ~ name ~ ": `" ~ display(balance) ~ "`"
35-
else () // TODO allow omitting else branch
3635

3736
val lines = []
3837

hkmc2/shared/src/test/mlscript/basics/OpBlocks.mls

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ fun f(x) = if x
8686
//│ Ident of "x"
8787
//│ rhs = S of IfLike:
8888
//│ kw = keyword 'if'
89+
//│ kwLoc = N
8990
//│ split = App:
9091
//│ lhs = Ident of "x"
9192
//│ rhs = OpBlock of Ls of
@@ -120,10 +121,10 @@ fun f(x) = if x
120121
> 0 then "a"
121122
is 0 then "b"
122123
//│ ╔══[PARSE ERROR] Expect an operator instead of 'is' keyword
123-
//│ ║ l.121: is 0 then "b"
124+
//│ ║ l.122: is 0 then "b"
124125
//│ ╙── ^^
125126
//│ ╔══[PARSE ERROR] Unexpected 'is' keyword here
126-
//│ ║ l.121: is 0 then "b"
127+
//│ ║ l.122: is 0 then "b"
127128
//│ ╙── ^^
128129
//│ ═══[ERROR] Unrecognized operator branch.
129130

@@ -135,11 +136,11 @@ fun f(x) = if x
135136
foo(A) then a
136137
bar(B) then b
137138
//│ ╔══[ERROR] Unrecognized term split (juxtaposition).
138-
//│ ║ l.134: fun f(x) = if x
139+
//│ ║ l.135: fun f(x) = if x
139140
//│ ║ ^
140-
//│ ║ l.135: foo(A) then a
141+
//│ ║ l.136: foo(A) then a
141142
//│ ║ ^^^^^^^^^^^^^^^
142-
//│ ║ l.136: bar(B) then b
143+
//│ ║ l.137: bar(B) then b
143144
//│ ╙── ^^^^^^^^^^^^^^^
144145

145146

@@ -148,10 +149,10 @@ fun f(x) = if x
148149
is 0 then "a"
149150
is 1 then "b"
150151
//│ ╔══[PARSE ERROR] Expected start of statement in this position; found 'is' keyword instead
151-
//│ ║ l.149: is 1 then "b"
152+
//│ ║ l.150: is 1 then "b"
152153
//│ ╙── ^^
153154
//│ ╔══[PARSE ERROR] Expected end of input; found literal instead
154-
//│ ║ l.149: is 1 then "b"
155+
//│ ║ l.150: is 1 then "b"
155156
//│ ╙── ^
156157
//│ = [Function: f]
157158

0 commit comments

Comments
 (0)