Skip to content

Commit b97a8a6

Browse files
committed
where keyword for pattern guards
1 parent b2eda2f commit b97a8a6

File tree

5 files changed

+46
-9
lines changed

5 files changed

+46
-9
lines changed

hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Desugarer.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ import mlscript.utils.*, shorthands.*
77
import Message.MessageContext
88
import utils.TraceLogger
99
import syntax.Literal
10-
import Keyword.{as, and, `do`, `else`, is, let, `then`}
10+
import Keyword.{as, and, `do`, `else`, is, let, `then`, where}
1111
import collection.mutable.{HashMap, SortedSet}
1212
import Elaborator.{ctx, Ctxl}
1313
import scala.annotation.targetName
1414
import hkmc2.semantics.ClassDef.Parameterized
15+
import hkmc2.codegen.Case.Lit
1516

1617
object Desugarer:
1718
extension (op: Keyword.Infix)
@@ -608,6 +609,11 @@ class Desugarer(val elaborator: Elaborator)
608609
case pattern and consequent => fallback => ctx =>
609610
val innerSplit = termSplit(consequent, identity)(Split.End)
610611
expandMatch(scrutSymbol, pattern, innerSplit)(fallback)(ctx)
612+
case pattern where condition => fallback => ctx =>
613+
val sym = TempSymbol(N, "conditionTemp")
614+
val newSequel = expandMatch(sym, Tree.BoolLit(true), sequel)(fallback)
615+
val newNewSequel = (ctx: Ctx) => Split.Let(sym, term(condition)(using ctx), newSequel(ctx))
616+
expandMatch(scrutSymbol, pattern, newNewSequel)(fallback)(ctx)
611617
case Jux(Ident(".."), Ident(_)) => fallback => _ =>
612618
raise(ErrorReport(msg"Illegal rest pattern." -> pattern.toLoc :: Nil))
613619
fallback

hkmc2/shared/src/main/scala/hkmc2/syntax/Keyword.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ object Keyword:
103103
val `object` = Keyword("object", N, curPrec)
104104
val `open` = Keyword("open", N, curPrec)
105105
val `type` = Keyword("type", N, N)
106-
val `where` = Keyword("where", N, N)
106+
val `where` = Keyword("where", curPrec, curPrec)
107107
val `forall` = Keyword("forall", N, N)
108108
val `exists` = Keyword("exists", N, N)
109109
val `null` = Keyword("null", N, N)
@@ -139,7 +139,7 @@ object Keyword:
139139
`abstract`, mut, virtual, `override`, declare, public, `private`)
140140

141141
type Infix = `and`.type | `or`.type | `then`.type | `else`.type | `is`.type | `:`.type | `->`.type |
142-
`=>`.type | `extends`.type | `restricts`.type | `as`.type | `do`.type
142+
`=>`.type | `extends`.type | `restricts`.type | `as`.type | `do`.type | `where`.type
143143

144144
type Ellipsis = `...`.type | `..`.type
145145

hkmc2/shared/src/main/scala/hkmc2/syntax/ParseRule.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,7 @@ class ParseRules(using State):
396396
genInfixRule(`extends`, (rhs, _: Unit) => lhs => InfixApp(lhs, `extends`, rhs)),
397397
genInfixRule(`restricts`, (rhs, _: Unit) => lhs => InfixApp(lhs, `restricts`, rhs)),
398398
genInfixRule(`do`, (rhs, _: Unit) => lhs => InfixApp(lhs, `do`, rhs)),
399+
genInfixRule(`where`, (rhs, _: Unit) => lhs => InfixApp(lhs, `where`, rhs)),
399400
)
400401

401402
end ParseRules
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
:js
2+
3+
data class Pair[A, B](fst: A, snd: B)
4+
5+
6+
fun orderedPair(p) = p is
7+
Pair(a, b) where a <= b
8+
9+
orderedPair(Pair(4, 5))
10+
//│ = true
11+
12+
orderedPair(4)
13+
//│ = false
14+
15+
orderedPair(Pair(1, 1))
16+
//│ = true
17+
18+
orderedPair(Pair(2, 1))
19+
//│ = false
20+
21+
fun foo(p) = p is
22+
Pair(a, b) where
23+
let c = a
24+
a == 0
25+
26+
foo(Pair(5, 5))
27+
//│ = false
28+
29+
foo(Pair(0, 0))
30+
//│ = true
31+
32+
foo(Pair(0, 4))
33+
//│ = true
34+
35+
foo(Pair(4, 0))
36+
//│ = false

hkmc2/shared/src/test/mlscript/ucs/syntax/PlainConditionals.mls

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,7 @@ fun foo(x) = x is
5353

5454
fun foo(x) = x is Pair(a, b) | Int
5555

56-
:todo
5756
fun foo(x) = x is (Pair(a, b) where a > b) | Int
58-
//│ ╔══[PARSE ERROR] Unexpected 'where' keyword here
59-
//│ ║ l.57: fun foo(x) = x is (Pair(a, b) where a > b) | Int
60-
//│ ╙── ^^^^^
61-
62-
6357

6458
data class A[T](arg: T)
6559

0 commit comments

Comments
 (0)