Skip to content

Commit 8437765

Browse files
committed
Support fixed-length tuple patterns
# Conflicts: # hkmc2/shared/src/test/mlscript/ucs/patterns/SimpleTuple.mls
1 parent 5faaf2d commit 8437765

File tree

10 files changed

+121
-75
lines changed

10 files changed

+121
-75
lines changed

hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ case class End(msg: Str = "") extends BlockTail with ProductWithTail
100100
enum Case:
101101
case Lit(lit: Literal)
102102
case Cls(cls: ClassSymbol)
103+
case Tup(len: Int)
103104

104105
sealed abstract class Result
105106

hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,11 @@ class Lowering(using TL, Raise, Elaborator.State):
245245
val (cse, blk) = mkArgs(args)
246246
(cse, Assign(arg, Select(sr, param.sym.id/*FIXME incorrect Ident?*/), blk))
247247
mkArgs(clsParams.zip(args))
248+
case Pattern.Tuple(args) =>
249+
val cse = Case.Tup(args.length) -> go(tail, topLevel = false)
250+
val blk = args.iterator.zipWithIndex.foldRight[Block](cse._2):
251+
case ((f, i), acc) => Assign(f, Select(sr, Tree.Ident(s"$i")), acc)
252+
(cse._1, blk)
248253
Match(sr, cse :: Nil,
249254
S(go(restSplit, topLevel = true)),
250255
End()

hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,16 @@ class JSBuilder extends CodeBuilder:
193193
doc" else { #{ ${ returningTerm(el) } #} # }"
194194
case N => doc""
195195
t :: e :: returningTerm(rest)
196+
case Match(scrut, Case.Tup(len) -> trm :: Nil, els, rest) =>
197+
val test = doc"Array.isArray(${ result(scrut) }) && ${ result(scrut) }.length === ${len}"
198+
val t = doc" # if (${ test }) { #{ ${
199+
returningTerm(trm)
200+
} #} # }"
201+
val e = els match
202+
case S(el) =>
203+
doc" else { #{ ${ returningTerm(el) } #} # }"
204+
case N => doc""
205+
t :: e :: returningTerm(rest)
196206

197207
case Begin(sub, thn) =>
198208
doc"${returningTerm(sub)} # ${returningTerm(thn).stripBreaks}"

hkmc2/shared/src/main/scala/hkmc2/semantics/Desugarer.scala

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import Message.MessageContext
77
import utils.TraceLogger
88
import hkmc2.syntax.Literal
99
import Keyword.{as, and, `else`, is, let, `then`}
10+
import collection.mutable.HashMap
1011

1112
object Desugarer:
1213
extension (op: Keyword.Infix)
@@ -20,6 +21,10 @@ object Desugarer:
2021
case lhs and rhs => S((lhs, L(rhs)))
2122
case lhs `then` rhs => S((lhs, R(rhs)))
2223
case _ => N
24+
25+
class ScrutineeData:
26+
val classes: HashMap[ClassSymbol, List[BlockLocalSymbol]] = HashMap.empty
27+
val tuples: HashMap[Int, List[BlockLocalSymbol]] = HashMap.empty
2328
end Desugarer
2429

2530
class Desugarer(tl: TraceLogger, elaborator: Elaborator)(using raise: Raise, state: Elaborator.State):
@@ -57,16 +62,18 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)(using raise: Raise, sta
5762
case Split.Let(name, term, tail) => Split.Let(name, term, tail ++ fallback)
5863
case Split.Else(_) /* impossible */ | Split.End => fallback)
5964

60-
import collection.mutable.HashMap
61-
62-
private val subScrutineeMap = HashMap.empty[BlockLocalSymbol, HashMap[ClassSymbol, List[BlockLocalSymbol]]]
65+
private val subScrutineeMap = HashMap.empty[BlockLocalSymbol, ScrutineeData]
6366

6467
extension (symbol: BlockLocalSymbol)
6568
def getSubScrutinees(cls: ClassSymbol): List[BlockLocalSymbol] =
66-
subScrutineeMap.getOrElseUpdate(symbol, HashMap.empty).getOrElseUpdate(cls, {
69+
subScrutineeMap.getOrElseUpdate(symbol, new ScrutineeData).classes.getOrElseUpdate(cls, {
6770
val arity = cls.defn.flatMap(_.paramsOpt.map(_.length)).getOrElse(0)
6871
(0 until arity).map(i => TempSymbol(nextUid, N, s"param$i")).toList
6972
})
73+
def getSubScrutinees(arity: Int): List[BlockLocalSymbol] =
74+
subScrutineeMap.getOrElseUpdate(symbol, new ScrutineeData).tuples.getOrElseUpdate(arity, {
75+
(0 until arity).map(i => TempSymbol(nextUid, N, s"elem$i")).toList
76+
})
7077

7178
def default: Split => Sequel = split => _ => split
7279

@@ -388,6 +395,16 @@ class Desugarer(tl: TraceLogger, elaborator: Elaborator)(using raise: Raise, sta
388395
// Raise an error and discard `sequel`. Use `fallback` instead.
389396
raise(ErrorReport(msg"Unknown symbol `${ctor.name}`." -> ctor.toLoc :: Nil))
390397
fallback
398+
case Tree.Tup(args) => fallback => ctx => trace(
399+
pre = s"expandMatch <<< ${args.mkString(", ")}",
400+
post = (r: Split) => s"expandMatch >>> ${r.showDbg}"
401+
):
402+
val params = scrutSymbol.getSubScrutinees(args.length)
403+
Branch(
404+
ref,
405+
Pattern.Tuple(params),
406+
subMatches(params zip args, sequel)(Split.End)(ctx)
407+
) ~: fallback
391408
// A single constructor pattern.
392409
case pat @ App(ctor: Ident, Tup(args)) => fallback => ctx => trace(
393410
pre = s"expandMatch <<< ${ctor.name}(${args.iterator.map(_.showDbg).mkString(", ")})",

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -591,14 +591,14 @@ extends Importer:
591591
case id @ Ident(name) =>
592592
val sym = boundVars.getOrElseUpdate(name, VarSymbol(id, nextUid))
593593
Pattern.Var(sym)
594-
case Tup(fields) =>
595-
val pats = fields.map(
596-
f => pattern(f) match
597-
case (pat, vars) =>
598-
boundVars ++= vars
599-
pat
600-
)
601-
Pattern.Tuple(pats)
594+
// case Tup(fields) =>
595+
// val pats = fields.map(
596+
// f => pattern(f) match
597+
// case (pat, vars) =>
598+
// boundVars ++= vars
599+
// pat
600+
// )
601+
// Pattern.Tuple(pats)
602602
case _ =>
603603
???
604604
(go(t), boundVars.toList)

hkmc2/shared/src/main/scala/hkmc2/semantics/Pattern.scala

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,36 +2,40 @@ package hkmc2
22
package semantics
33

44
import mlscript.utils.*, shorthands.*
5-
import syntax.*
5+
import syntax.*, Tree.Ident
66

77

8-
enum Pattern extends Located:
8+
enum Pattern extends AutoLocated:
99
case Lit(literal: Literal)
10-
case Var(nme: BlockLocalSymbol)
11-
case Class(nme: ClassSymbol, parameters: Opt[List[BlockLocalSymbol]], var refined: Bool)(val ident: Tree.Ident)
12-
case Tuple(fields: List[Pattern])
13-
case Record(entries: List[(VarSymbol -> Pattern)])
10+
case Var(sym: BlockLocalSymbol)
11+
case Class(sym: ClassSymbol, params: Opt[List[BlockLocalSymbol]], var refined: Bool)(val ident: Ident)
12+
case Tuple(fields: List[BlockLocalSymbol])
13+
case Record(entries: List[(Ident -> BlockLocalSymbol)])
1414

15-
def toLoc: Opt[Loc] = this match
16-
case Lit(literal) => literal.toLoc
17-
case pat @ Class(_, _, _) => pat.ident.toLoc
15+
protected def children: List[Located] =
16+
this match
17+
case Lit(literal) => literal :: Nil
18+
case Var(sym) => sym :: Nil
19+
case Class(sym, params, _) => sym :: params.getOrElse(Nil)
20+
case Tuple(fields) => fields
21+
case Record(entries) => entries.flatMap:
22+
(key, pat) => key :: pat :: Nil
1823

1924
def subTerms: Ls[Term] = this match
20-
case Lit(literal) => Nil
21-
case Var(nme) => Nil
22-
case Class(_, parameters, _) => Nil
23-
case Tuple(fields) => fields.flatMap(_.subTerms)
24-
case Record(entries) => entries.flatMap(_._2.subTerms)
25+
case Lit(_) => Nil
26+
case Var(_) => Nil
27+
case Class(_, _, _) => Nil
28+
case Tuple(_) => Nil
29+
case Record(_) => Nil
2530

2631
def showDbg: Str = this match
2732
case Lit(literal) => literal.idStr
28-
case Var(nme) => nme.toString
29-
case Class(sym, ps, rfd) => (if rfd then "refined " else "") + (ps match {
30-
case N => sym.nme
31-
case S(parameters) => parameters.mkString(s"${sym.nme}(", ", ", ")")
32-
})
33-
case Tuple(fields) => fields.mkString("(", ", ", ")")
33+
case Var(sym) => sym.nme
34+
case Class(sym, ps, rfd) => (if rfd then "refined " else "") +
35+
sym.nme + ps.fold("")(_.mkString("(", ", ", ")"))
36+
case Tuple(fields) => fields.mkString("[", ", ", "]")
3437
case Record(Nil) => "{}"
35-
case Record(entries) => entries.iterator.map { case (nme, als) => s"$nme: $als" }.mkString("{ ", ", ", " }")
38+
case Record(entries) =>
39+
entries.iterator.map(_.name + ": " + _).mkString("{ ", ", ", " }")
3640

3741

hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Normalization.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class Normalization(tl: TraceLogger)(using raise: Raise):
4747
def =:=(rhs: Pattern): Bool = (lhs, rhs) match
4848
case (Pattern.Class(s1, _, _), Pattern.Class(s2, _, _)) => s1 === s2
4949
case (Pattern.Lit(l1), Pattern.Lit(l2)) => l1 === l2
50+
case (Pattern.Tuple(fs1), Pattern.Tuple(fs2)) => fs1.length === fs2.length
5051
case (_, _) => false
5152
/** Checks if `self` can be subsumed under `rhs`. */
5253
def <:<(rhs: Pattern): Bool =
@@ -89,7 +90,7 @@ class Normalization(tl: TraceLogger)(using raise: Raise):
8990
case Pattern.Var(vs) =>
9091
log(s"ALIAS: $scrutinee is $vs")
9192
Split.Let(vs, scrutinee, rec(consequent ++ alternative))
92-
case pattern @ (Pattern.Lit(_) | Pattern.Class(_, _, _)) =>
93+
case pattern @ (Pattern.Lit(_) | Pattern.Class(_, _, _) | Pattern.Tuple(_)) =>
9394
log(s"MATCH: $scrutinee is $pattern")
9495
val whenTrue = normalize(specialize(consequent ++ alternative, +, scrutinee, pattern))
9596
val whenFalse = rec(specialize(alternative, -, scrutinee, pattern).clearFallback)
@@ -145,7 +146,7 @@ class Normalization(tl: TraceLogger)(using raise: Raise):
145146
log(s"Case 1.1.1: $pattern =:= $thatPattern")
146147
thatPattern reportInconsistentRefinedWith pattern
147148
aliasBindings(pattern, thatPattern)(rec(continuation) ++ rec(tail))
148-
else if (thatPattern <:< pattern) then
149+
else if thatPattern <:< pattern then
149150
log(s"Case 1.1.2: $pattern <:< $thatPattern")
150151
pattern.markAsRefined; split
151152
else if split.isFallback then
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
:js
2+
:sjs
3+
4+
let x = 0
5+
//│ JS:
6+
//│ this.x = 0; undefined
7+
//│ x = 0
8+
9+
:fixme
10+
let x' = 0
11+
//│ JS:
12+
//│ this.x' = 0; undefined
13+
//│ ═══[COMPILATION ERROR] [Uncaught SyntaxError] Unexpected string
14+
//│ ═══[COMPILATION ERROR] [Uncaught SyntaxError] Unexpected string

hkmc2/shared/src/test/mlscript/ucs/general/DualOptions.mls

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -103,25 +103,12 @@ add_5(None, Some(9))
103103
add_5(Some(5), Some(9))
104104

105105

106-
:fixme
107106
fun add_6(x, y) =
108107
if [x, y] is
109108
[Some(xv), Some(yv)] then xv + yv
110109
[Some(xv), None] then xv
111110
[None, Some(yv)] then yv
112111
[None, None] then 0
113-
//│ ╔══[ERROR] Unrecognized pattern.
114-
//│ ║ l.112: [None, None] then 0
115-
//│ ╙── ^^^^^^^^^^^^^^^^
116-
//│ ╔══[ERROR] Unrecognized pattern.
117-
//│ ║ l.111: [None, Some(yv)] then yv
118-
//│ ╙── ^^^^^^^^^^^^^^^^^^^^
119-
//│ ╔══[ERROR] Unrecognized pattern.
120-
//│ ║ l.110: [Some(xv), None] then xv
121-
//│ ╙── ^^^^^^^^^^^^^^^^
122-
//│ ╔══[ERROR] Unrecognized pattern.
123-
//│ ║ l.109: [Some(xv), Some(yv)] then xv + yv
124-
//│ ╙── ^^^^^^^^^^^^^^^^^^^^
125112

126113
add_6(None, None)
127114
add_6(Some(5), None)

hkmc2/shared/src/test/mlscript/ucs/patterns/SimpleTuple.mls

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,72 +5,79 @@ fun sum(x, y) = x + y
55
sum(1, 2)
66

77
:fixme
8-
fun sum'([x, y]) = x + y
9-
sum'([1, 2])
8+
fun sum([x, y]) = x + y
9+
sum([1, 2])
1010
//│ /!!!\ Uncaught error: scala.MatchError: Tup(List(Ident(x), Ident(y))) (of class hkmc2.syntax.Tree$Tup)
1111

12-
:todo
12+
:ucs desugared
1313
fun sum''(pair) =
1414
if pair is [x, y] then x + y
1515
sum''([1, 2])
16-
//│ ╔══[ERROR] Unrecognized pattern.
17-
//│ ║ l.14: if pair is [x, y] then x + y
18-
//│ ╙── ^^^^^^
19-
20-
// We need native support for tuple patterns in MLscript syntax.
21-
// Otherwise the following cases work.
16+
//│ Desugared:
17+
//│ > if pair@31 is [$elem0@32, $elem1@33] and
18+
//│ > let x@34 = $elem0@32#0
19+
//│ > let y@35 = $elem1@33#0
20+
//│ > else .+#1(x@34#0, y@35#0)
2221

23-
:todo
22+
:ucs desugared
2423
fun test(thing) =
2524
if thing is [] then 0
2625
test("")
2726
test(12)
28-
//│ ╔══[ERROR] Unrecognized pattern.
29-
//│ ║ l.25: if thing is [] then 0
30-
//│ ╙── ^^
27+
//│ Desugared:
28+
//│ > if thing@40 is [] then 0
3129

32-
:todo
33-
// Since pattern destruction is desugared to let bindings, matching with other
34-
// classes after the tuple pattern will not work.
30+
:ucs desugared
3531
class Point(x: Int, y: Int)
3632
fun discarded_cases(thing) =
3733
if thing is
3834
[x, y] then x + y
3935
Point(x, y) then x + y
40-
//│ ╔══[ERROR] Unrecognized pattern.
41-
//│ ║ l.38: [x, y] then x + y
42-
//│ ╙── ^^^^^^
36+
//│ Desugared:
37+
//│ > if
38+
//│ > thing@45 is [$elem0@51, $elem1@52] and
39+
//│ > let x@53 = $elem0@51#0
40+
//│ > let y@54 = $elem1@52#0
41+
//│ > else .+#3(x@53#0, y@54#0)
42+
//│ > thing@45 is Point($param0@46, $param1@47) and
43+
//│ > let x@48 = $param0@46#0
44+
//│ > let y@49 = $param1@47#0
45+
//│ > else .+#2(x@48#0, y@49#0)
4346

4447
:e
4548
:todo
4649
discarded_cases(Point(0, 0))
4750

4851
// A workaround is to move the tuple pattern to the last case.
49-
:todo
52+
:ucs desugared
5053
fun working_cases(thing) =
5154
if thing is
5255
Point(x, y) then x + y
5356
[x, y] then x + y
54-
//│ ╔══[ERROR] Unrecognized pattern.
55-
//│ ║ l.53: [x, y] then x + y
56-
//│ ╙── ^^^^^^
57+
//│ Desugared:
58+
//│ > if
59+
//│ > thing@61 is Point($param0@67, $param1@68) and
60+
//│ > let x@69 = $param0@67#0
61+
//│ > let y@70 = $param1@68#0
62+
//│ > else .+#5(x@69#0, y@70#0)
63+
//│ > thing@61 is [$elem0@62, $elem1@63] and
64+
//│ > let x@64 = $elem0@62#0
65+
//│ > let y@65 = $elem1@63#0
66+
//│ > else .+#4(x@64#0, y@65#0)
5767

5868
working_cases(Point(0, 0))
5969

6070
// However, the `Object` type forbids tuples to be used.
6171
:todo
6272
working_cases([0, 0])
6373

64-
:todo
74+
6575
fun not_working(x) =
6676
if x is
6777
[a, b, c] then
6878
a + b + c
6979
else
7080
0
71-
//│ ╔══[ERROR] Unrecognized pattern.
72-
//│ ║ l.67: [a, b, c] then
73-
//│ ╙── ^^^^^^^^^
7481

7582
not_working([1, 2, 3])
7683

0 commit comments

Comments
 (0)