Skip to content

Commit 8927be7

Browse files
authored
Merge pull request #194 from chengluyu/pretyper
New UCS desugarer and rudimentary `PreTyper`
2 parents a1de9b6 + f8cd23f commit 8927be7

File tree

168 files changed

+13318
-5850
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

168 files changed

+13318
-5850
lines changed

compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,14 +227,14 @@ class ClassLifter(logDebugMsg: Boolean = false) {
227227
}
228228

229229
private def liftCaseBranch(brn: CaseBranches)(using ctx: LocalContext, cache: ClassCache, globFuncs: Map[Var, (Var, LocalContext)], outer: Option[ClassInfoCache]): (CaseBranches, LocalContext) = brn match{
230-
case Case(v: Var, body, rest) =>
230+
case k @ Case(v: Var, body, rest) =>
231231
val nTrm = liftTerm(body)(using ctx.addV(v))
232232
val nRest = liftCaseBranch(rest)
233-
(Case(v, nTrm._1, nRest._1), nTrm._2 ++ nRest._2)
234-
case Case(pat, body, rest) =>
233+
(Case(v, nTrm._1, nRest._1)(k.refined), nTrm._2 ++ nRest._2)
234+
case k @ Case(pat, body, rest) =>
235235
val nTrm = liftTerm(body)
236236
val nRest = liftCaseBranch(rest)
237-
(Case(pat, nTrm._1, nRest._1), nTrm._2 ++ nRest._2)
237+
(Case(pat, nTrm._1, nRest._1)(k.refined), nTrm._2 ++ nRest._2)
238238
case Wildcard(body) =>
239239
val nTrm = liftTerm(body)
240240
(Wildcard(nTrm._1), nTrm._2)

compiler/shared/test/diff/Defunctionalize/Lambdas.mls

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
//│ Code(List(main$$1()))
4242
//│ }
4343
//│ class Lambda1$1$1()
44-
//│ fun apply$Lambda1$1$1: (anything, Object) -> Bool
44+
//│ fun apply$Lambda1$1$1: (anything, Bool) -> Bool
4545
//│ fun main$$1: () -> Bool
4646
//│ Bool
4747
//│ res

compiler/shared/test/diff/Defunctionalize/Modules.mls

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ x.y.f
2020
//│ fun main$$2 = () => let obj = let obj = x$2 in if obj is ‹(x$2) then y$x$2(obj,); else error› in if obj is ‹(Foo$1) then f$Foo$1(obj,); else error›
2121
//│ Code(List(main$$2()))
2222
//│ }
23-
//│ ╔══[WARNING] Found a redundant else branch
24-
//│ ╙──
23+
//│ ╔══[WARNING] the outer binding `x$2`
24+
//│ ╙── is shadowed by name pattern `x$2`
25+
//│ ╔══[WARNING] this case is unreachable
26+
//│ ╙── because it is subsumed by the branch
2527
//│ module x$2
2628
//│ class Foo$1()
2729
//│ let y$x$2: anything -> Foo$1

compiler/shared/test/diff/Defunctionalize/OldMonoList.mls

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
:NewDefs
22
:AllowRuntimeErrors
33

4+
// FIXME
45
:mono
56
class List(e: Int, tail: List | Nil) {
67
fun map: (Int -> Int) -> List
@@ -42,19 +43,7 @@ fun add2(x) = x+2
4243
//│ fun apply$Lambda1$2$3 = (this, x,) => +(x, 1,)
4344
//│ Code(List(main$$5()))
4445
//│ }
45-
//│ class Lambda1$3$4()
46-
//│ class Nil$2()
47-
//│ class List$1(e: Int, tail: List$1 | Nil$2)
48-
//│ class Lambda1$2$3()
49-
//│ fun map$List$1: (Object, Object) -> List$1
50-
//│ fun add2$1: Int -> Int
51-
//│ fun main$$5: () -> List$1
52-
//│ fun apply$Lambda1$3$4: (anything, Int) -> Int
53-
//│ fun map$Nil$2: forall 'a. ('a & (List$1 | Nil$2), anything) -> (Nil$2 | 'a)
54-
//│ fun apply$Lambda1$2$3: (anything, Int) -> Int
55-
//│ List$1
56-
//│ res
57-
//│ = List$1 {}
46+
//│ /!!!\ Uncaught error: java.lang.Exception: Internal Error: the `if` expression has already been desugared, please make sure that the objects are copied
5847

5948
:mono
6049
class List(e: Int, tail: List | Nil) {

compiler/shared/test/diff/Defunctionalize/SimpleFunc.mls

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
:NewDefs
22

33
:mono
4-
fun f(x: Int) = if x then 42 else 1337
4+
fun f(x: Bool) = if x then 42 else 1337
55
//│ Lifted:
66
//│ TypingUnit {
7-
//│ fun f$1 = (x: Int,) => if (x) then 42 else 1337
7+
//│ fun f$1 = (x: Bool,) => if (x) then 42 else 1337
88
//│ }
99
//│ Mono:
1010
//│ TypingUnit {
11-
//│ fun f$1 = (x: Int,) => if (x) then 42 else 1337
11+
//│ fun f$1 = (x: Bool,) => if (x) then 42 else 1337
1212
//│ }
13-
//│ fun f$1: (x: Int) -> (1337 | 42)
13+
//│ fun f$1: (x: Bool) -> (1337 | 42)
1414

1515
:mono
1616
fun foo() = 42

js/src/main/scala/Main.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ object Main {
159159
|""".stripMargin
160160

161161
val backend = new JSWebBackend()
162-
val (lines, resNames) = backend(pgrm, newDefs = true)
162+
val (lines, resNames) = backend(pgrm)
163163
val code = lines.mkString("\n")
164164

165165
// TODO: add a toggle button to show js code

shared/src/main/scala/mlscript/Diagnostic.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ object Diagnostic {
1919
sealed abstract class Source
2020
case object Lexing extends Source
2121
case object Parsing extends Source
22+
case object PreTyping extends Source
23+
case object Desugaring extends Source
2224
case object Typing extends Source
2325
case object Compilation extends Source
2426
case object Runtime extends Source

shared/src/main/scala/mlscript/JSBackend.scala

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,13 @@ abstract class JSBackend {
194194
private def desugarQuotedBranch(branch: CaseBranches)(
195195
implicit scope: Scope, isQuoted: Bool, freeVars: FreeVars
196196
): Either[Term, CaseBranches] = branch match {
197-
case Case(pat, body, rest) =>
197+
case cse @ Case(pat, body, rest) =>
198198
val dp = desugarQuote(pat)
199199
val db = desugarQuote(body)
200200
desugarQuotedBranch(rest) match {
201201
case L(t) => L(createASTCall("Case", dp :: db :: t :: Nil))
202202
case R(b) => dp match {
203-
case dp: SimpleTerm => R(Case(dp, db, b))
203+
case dp: SimpleTerm => R(Case(dp, db, b)(cse.refined))
204204
case _ => die
205205
}
206206
}
@@ -553,6 +553,14 @@ abstract class JSBackend {
553553
pat match {
554554
case Var("int") =>
555555
JSInvoke(JSField(JSIdent("Number"), "isInteger"), scrut :: Nil)
556+
case Var("Int") if !oldDefs =>
557+
JSInvoke(JSField(JSIdent("Number"), "isInteger"), scrut :: Nil)
558+
case Var("Num") if !oldDefs =>
559+
JSBinary("===", scrut.typeof(), JSLit(JSLit.makeStringLiteral("number")))
560+
case Var("Bool") if !oldDefs =>
561+
JSBinary("===", scrut.typeof(), JSLit(JSLit.makeStringLiteral("boolean")))
562+
case Var("Str") if !oldDefs =>
563+
JSBinary("===", scrut.typeof(), JSLit(JSLit.makeStringLiteral("string")))
556564
case Var("bool") =>
557565
JSBinary("===", scrut.member("constructor"), JSLit("Boolean"))
558566
case Var(s @ ("true" | "false")) =>
@@ -572,8 +580,7 @@ abstract class JSBackend {
572580
case _ => throw new CodeGenError(s"unknown match case: $name")
573581
}
574582
}
575-
case lit: Lit =>
576-
JSBinary("===", scrut, translateTerm(lit))
583+
case lit: Lit => JSBinary("===", scrut, translateTerm(lit))
577584
},
578585
_,
579586
_
@@ -1342,7 +1349,7 @@ abstract class JSBackend {
13421349
}
13431350

13441351
class JSWebBackend extends JSBackend {
1345-
def oldDefs = false
1352+
override def oldDefs: Bool = false
13461353

13471354
// Name of the array that contains execution results
13481355
val resultsName: Str = topLevelScope declareRuntimeSymbol "results"
@@ -1468,8 +1475,8 @@ class JSWebBackend extends JSBackend {
14681475
(JSImmEvalFn(N, Nil, R(polyfill.emit() ::: stmts ::: epilogue), Nil).toSourceCode.toLines, resultNames.toList)
14691476
}
14701477

1471-
def apply(pgrm: Pgrm, newDefs: Bool): (Ls[Str], Ls[Str]) =
1472-
if (newDefs) generateNewDef(pgrm) else generate(pgrm)
1478+
def apply(pgrm: Pgrm): (Ls[Str], Ls[Str]) =
1479+
if (!oldDefs) generateNewDef(pgrm) else generate(pgrm)
14731480
}
14741481

14751482
abstract class JSTestBackend extends JSBackend {

shared/src/main/scala/mlscript/MLParser.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class MLParser(origin: Origin, indent: Int = 0, recordLocations: Bool = true) {
140140
def matchArms[p: P](sep: Str): P[CaseBranches] = P(
141141
( ("_" ~ "->" ~ term).map(Wildcard)
142142
| ((lit | variable) ~ "->" ~ term ~ matchArms2(sep))
143-
.map { case (t, b, rest) => Case(t, b, rest) }
143+
.map { case (t, b, rest) => Case(t, b, rest)(refined = false) }
144144
).?.map {
145145
case None => NoCases
146146
case Some(b) => b

shared/src/main/scala/mlscript/NewParser.scala

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,10 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], newDefs: Bo
794794
Bra(false, elt)
795795
case (Round, _) =>
796796
yeetSpaces match {
797+
case (KEYWORD(opStr @ "=>"), l1) :: (NEWLINE, l2) :: _ /* if opPrec(opStr)._1 > prec */ =>
798+
consume
799+
val rhs = Blk(typingUnit.entities)
800+
Lam(Tup(res), rhs)
797801
case (KEYWORD("=>"), l1) :: _ =>
798802
consume
799803
val e = expr(NewParser.opPrec("=>")._2)
@@ -1245,14 +1249,26 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], newDefs: Bo
12451249
R(res)
12461250
}
12471251
case L(rhs) =>
1248-
L(IfOpsApp(acc, opIfBlock(opv -> rhs :: Nil)))
1252+
val (opsRhss, els) = opIfBlock(opv -> rhs :: Nil)
1253+
val opsApp = IfOpsApp(acc, opsRhss)
1254+
L(els.fold[IfBody](opsApp)(trm => IfBlock(L(opsApp) :: L(IfElse(trm)) :: Nil)))
12491255
}
12501256
}
1251-
final def opIfBlock(acc: Ls[Var -> IfBody])(implicit et: ExpectThen, fe: FoundErr): Ls[Var -> IfBody] = wrap(acc) { l =>
1257+
final def opIfBlock(acc: Ls[Var -> IfBody])(implicit et: ExpectThen, fe: FoundErr): (Ls[Var -> IfBody], Opt[Term]) = wrap(acc) { l =>
12521258
cur match {
12531259
case (NEWLINE, _) :: c => // TODO allow let bindings...
12541260
consume
12551261
c match {
1262+
case (IDENT("_", false), wcLoc) :: _ =>
1263+
exprOrIf(0) match {
1264+
case R(rhs) =>
1265+
err(msg"expect an operator branch" -> S(wcLoc) :: Nil)
1266+
(acc.reverse, N)
1267+
case L(IfThen(_, els)) => (acc.reverse, S(els))
1268+
case L(rhs) =>
1269+
err(msg"expect 'then' after the wildcard" -> rhs.toLoc :: Nil)
1270+
(acc.reverse, N)
1271+
}
12561272
case (IDENT(opStr2, true), opLoc2) :: _ =>
12571273
consume
12581274
val rhs = exprOrIf(0)
@@ -1262,12 +1278,23 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], newDefs: Bo
12621278
case L(rhs) =>
12631279
opIfBlock(Var(opStr2).withLoc(S(opLoc2)) -> rhs :: acc)
12641280
}
1265-
case _ =>
1281+
case (KEYWORD("else"), elseLoc) :: tail =>
1282+
consume
1283+
exprOrIf(0) match {
1284+
case R(rhs) => (acc.reverse, S(rhs))
1285+
case L(rhs) =>
1286+
err(msg"expect a term" -> rhs.toLoc :: Nil)
1287+
(acc.reverse, N)
1288+
}
1289+
case (_, headLoc) :: _ =>
12661290
// printDbg(c)
1267-
???
1291+
err(msg"expect an operator" -> S(headLoc) :: Nil)
1292+
(acc.reverse, N)
1293+
case Nil =>
1294+
(acc.reverse, N)
12681295
}
12691296
case _ =>
1270-
acc.reverse
1297+
(acc.reverse, N)
12711298
}
12721299
}
12731300

shared/src/main/scala/mlscript/Typer.scala

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import mlscript.Message._
1515
* In order to turn the resulting CompactType into a mlscript.Type, we use `expandCompactType`.
1616
*/
1717
class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val newDefs: Bool)
18-
extends ucs.Desugarer with TypeSimplifier {
18+
extends TypeDefs with TypeSimplifier {
1919

2020
def funkyTuples: Bool = false
2121
def doFactorize: Bool = false
@@ -385,6 +385,9 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
385385
"sub" -> intBinOpTy,
386386
"mul" -> intBinOpTy,
387387
"div" -> intBinOpTy,
388+
"numAdd" -> numberBinOpTy,
389+
"numSub" -> numberBinOpTy,
390+
"numMul" -> numberBinOpTy,
388391
"sqrt" -> fun(singleTup(IntType), IntType)(noProv),
389392
"lt" -> numberBinPred,
390393
"le" -> numberBinPred,
@@ -418,6 +421,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
418421
"*." -> numberBinOpTy,
419422
"%" -> intBinOpTy,
420423
"/" -> numberBinOpTy,
424+
"**" -> numberBinOpTy,
421425
"<" -> numberBinPred,
422426
">" -> numberBinPred,
423427
"<=" -> numberBinPred,
@@ -802,7 +806,8 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
802806
}
803807

804808
// TODO also prevent rebinding of "not"
805-
val reservedVarNames: Set[Str] = Set("|", "&", "~", "neg", "and", "or", "is")
809+
val reservedVarNames: Set[Str] =
810+
Set("|", "&", "~", "neg", "and", "or", "is", "refined")
806811

807812
object ValidVar {
808813
def unapply(v: Var)(implicit raise: Raise): S[Str] = S {
@@ -1180,21 +1185,29 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
11801185
val argProv = tp(args.toLoc, "argument list")
11811186
con(new_ty, FunctionType(typeTerm(args).withProv(argProv), res)(noProv), res)
11821187
case App(App(Var("is"), _), _) => // * Old-style operators
1183-
val desug = If(IfThen(term, Var("true")), S(Var("false")))
1184-
term.desugaredTerm = S(desug)
1185-
typeTerm(desug)
1188+
typeTerm(term.desugaredTerm.getOrElse {
1189+
val desug = If(IfThen(term, Var("true")), S(Var("false")))
1190+
term.desugaredTerm = S(desug)
1191+
desug
1192+
})
11861193
case App(Var("is"), _) =>
1187-
val desug = If(IfThen(term, Var("true")), S(Var("false")))
1188-
term.desugaredTerm = S(desug)
1189-
typeTerm(desug)
1194+
typeTerm(term.desugaredTerm.getOrElse {
1195+
val desug = If(IfThen(term, Var("true")), S(Var("false")))
1196+
term.desugaredTerm = S(desug)
1197+
desug
1198+
})
11901199
case App(App(Var("and"), PlainTup(lhs)), PlainTup(rhs)) => // * Old-style operators
1191-
val desug = If(IfThen(lhs, rhs), S(Var("false")))
1192-
term.desugaredTerm = S(desug)
1193-
typeTerm(desug)
1200+
typeTerm(term.desugaredTerm.getOrElse {
1201+
val desug = If(IfThen(lhs, rhs), S(Var("false")))
1202+
term.desugaredTerm = S(desug)
1203+
desug
1204+
})
11941205
case App(Var("and"), PlainTup(lhs, rhs)) =>
1195-
val desug = If(IfThen(lhs, rhs), S(Var("false")))
1196-
term.desugaredTerm = S(desug)
1197-
typeTerm(desug)
1206+
typeTerm(term.desugaredTerm.getOrElse {
1207+
val desug = If(IfThen(lhs, rhs), S(Var("false")))
1208+
term.desugaredTerm = S(desug)
1209+
desug
1210+
})
11981211
case App(f: Term, a @ Tup(fields)) if (fields.exists(x => x._1.isDefined)) =>
11991212
def getLowerBoundFunctionType(t: SimpleType): List[FunctionType] = t.unwrapProvs match {
12001213
case PolymorphicType(_, AliasOf(fun_ty @ FunctionType(_, _))) =>
@@ -1395,8 +1408,9 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
13951408
con(s_ty, req, cs_ty)
13961409
}
13971410
case elf: If =>
1398-
try typeTerm(desugarIf(elf)) catch {
1399-
case e: ucs.DesugaringException => err(e.messages)
1411+
elf.desugaredTerm match {
1412+
case S(desugared) => typeTerm(desugared)
1413+
case N => err(msg"not desugared UCS term found", elf.toLoc)
14001414
}
14011415
case AdtMatchWith(cond, arms) =>
14021416
println(s"typed condition term ${cond}")
@@ -1612,7 +1626,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
16121626
}
16131627
solveQuoteContext(ctx, newCtx)
16141628
res
1615-
case Case(pat, bod, rest) =>
1629+
case cse @ Case(pat, bod, rest) =>
16161630
val (tagTy, patTy) : (ST, ST) = pat match {
16171631
case lit: Lit =>
16181632
val t = ClassTag(lit,
@@ -1621,7 +1635,13 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
16211635
case v @ Var(nme) =>
16221636
val tpr = tp(pat.toLoc, "type pattern")
16231637
ctx.tyDefs.get(nme) match {
1624-
case None =>
1638+
case Some(td) if !newDefs =>
1639+
td.kind match {
1640+
case Als | Mod | Mxn => val t = err(msg"can only match on classes and traits", pat.toLoc)(raise); t -> t
1641+
case Cls => val t = clsNameToNomTag(td)(tp(pat.toLoc, "class pattern"), ctx); t -> t
1642+
case Trt => val t = trtNameToNomTag(td)(tp(pat.toLoc, "trait pattern"), ctx); t -> t
1643+
}
1644+
case _ =>
16251645
val bail = () => {
16261646
val e = ClassTag(ErrTypeId, Set.empty)(tpr)
16271647
return ((e -> e) :: Nil) -> e
@@ -1636,7 +1656,7 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
16361656
lti match {
16371657
case dti: DelayedTypeInfo =>
16381658
val tag = clsNameToNomTag(dti.decl match { case decl: NuTypeDef => decl; case _ => die })(prov, ctx)
1639-
val ty =
1659+
val ty = // TODO update as below for refined
16401660
RecordType.mk(dti.tparams.map {
16411661
case (tn, tv, vi) =>
16421662
val nv = freshVar(tv.prov, S(tv), tv.nameHint)
@@ -1647,7 +1667,8 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
16471667
tag -> ty
16481668
case CompletedTypeInfo(cls: TypedNuCls) =>
16491669
val tag = clsNameToNomTag(cls.td)(prov, ctx)
1650-
val ty =
1670+
println(s"CASE $tag ${cse.refined}")
1671+
val ty = if (cse.refined) freshVar(tp(v.toLoc, "refined scrutinee"), N) else
16511672
RecordType.mk(cls.tparams.map {
16521673
case (tn, tv, vi) =>
16531674
val nv = freshVar(tv.prov, S(tv), tv.nameHint)
@@ -1663,12 +1684,6 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
16631684
err("type identifier not found: " + nme, pat.toLoc)(raise)
16641685
bail()
16651686
}
1666-
case Some(td) =>
1667-
td.kind match {
1668-
case Als | Mod | Mxn => val t = err(msg"can only match on classes and traits", pat.toLoc)(raise); t -> t
1669-
case Cls => val t = clsNameToNomTag(td)(tp(pat.toLoc, "class pattern"), ctx); t -> t
1670-
case Trt => val t = trtNameToNomTag(td)(tp(pat.toLoc, "trait pattern"), ctx); t -> t
1671-
}
16721687
}
16731688
}
16741689
val newCtx = if (ctx.inQuote) ctx.enterQuotedScope else ctx.nest

0 commit comments

Comments
 (0)