Skip to content

Commit c254ddc

Browse files
Adapt Elaborator and JSBuilder for multiple parameter lists
1 parent 254eb7f commit c254ddc

File tree

7 files changed

+104
-35
lines changed

7 files changed

+104
-35
lines changed

hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -448,11 +448,11 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
448448
effBuff += eff
449449
nestCtx += sym -> rhsTy
450450
goStats(stats)
451-
case TermDefinition(Fun, sym, params, sig, Some(body), _) :: stats =>
452-
typeFunDef(sym, params match {
453-
case S(params) => Term.Lam(params, body)
454-
case _ => body // * may be a case expressions
455-
}, sig, ctx)
451+
case TermDefinition(Fun, sym, ParamList(_, ps) :: Nil, sig, Some(body), _) :: stats =>
452+
typeFunDef(sym, Term.Lam(ps, body), sig, ctx)
453+
goStats(stats)
454+
case TermDefinition(Fun, sym, Nil, sig, Some(body), _) :: stats =>
455+
typeFunDef(sym, body, sig, ctx) // * may be a case expressions
456456
goStats(stats)
457457
case TermDefinition(Fun, sym, _, S(sig), None, _) :: stats =>
458458
ctx += sym -> typeType(sig)

hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ sealed abstract class Defn:
7474
final case class TermDefn(
7575
k: syntax.TermDefKind,
7676
sym: TermSymbol,
77-
params: Opt[Ls[Param]],
77+
params: Ls[ParamList],
7878
body: Block,
7979
) extends Defn
8080

hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import hkmc2.semantics.Elaborator
1212
import hkmc2.syntax.Tree
1313
import hkmc2.semantics.TopLevelSymbol
1414
import hkmc2.semantics.MemberSymbol
15+
import hkmc2.semantics.ParamList
16+
import hkmc2.codegen.Value.Lam
1517

1618

1719
// TODO factor some logic for other codegen backends
@@ -99,11 +101,14 @@ class JSBuilder extends CodeBuilder:
99101
result(Value.This(sym))
100102
val (thisProxy, res) = scope.nestRebindThis(defn.sym):
101103
val defnJS = defn match
102-
case TermDefn(syntax.Fun, sym, N, body) =>
104+
case TermDefn(syntax.Fun, sym, Nil, body) =>
103105
TODO("getters")
104-
case TermDefn(syntax.Fun, sym, S(ps), bod) =>
105-
val vars = ps.map(p => scope.allocateName(p.sym)).mkDocument(", ")
106-
doc"function ${sym.nme}($vars) { #{ # ${body(bod)} #} # }"
106+
case TermDefn(syntax.Fun, sym, ParamList(_, ps) :: pss, bod) =>
107+
val paramList = ps.map(p => scope.allocateName(p.sym)).mkDocument(", ")
108+
val result = pss.foldRight(bod):
109+
case (ParamList(_, ps), block) =>
110+
Return(Lam(ps, block), false)
111+
doc"function ${sym.nme}(${paramList}) { #{ # ${body(result)} #} # }"
107112
case ClsDefn(sym, syntax.Cls, mtds, flds, ctor) =>
108113
val clsDefn = sym.defn.getOrElse(die)
109114
val clsParams = clsDefn.paramsOpt.getOrElse(Nil)
@@ -118,11 +123,12 @@ class JSBuilder extends CodeBuilder:
118123
}) { #{ # ${
119124
ctorCode.stripBreaks
120125
} #} # }${
121-
mtds.map: td =>
122-
val vars = td.params.getOrElse(Nil).map(p => scope.allocateName(p.sym)).mkDocument(", ")
123-
doc" # ${td.sym.nme}($vars) { #{ # ${
124-
body(td.body)
125-
} #} # }"
126+
mtds.map:
127+
case td @ TermDefn(_, _, ParamList(_, ps) :: Nil, _) =>
128+
val vars = ps.map(p => scope.allocateName(p.sym)).mkDocument(", ")
129+
doc" # ${td.sym.nme}($vars) { #{ # ${
130+
body(td.body)
131+
} #} # }"
126132
.mkDocument(" ")
127133
}${
128134
if mtds.exists(_.sym.nme == "toString")

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,12 +428,14 @@ extends Importer:
428428
case S(t) => typeParams(t)
429429
case N => (N, ctx)
430430
// Add parameters to context
431-
val (ps, newCtx) = td.paramLists.foldLeft((Ls[Param](), newCtx1)):
432-
case ((ps, ctx), t) => params(t)(using ctx).mapFirst(ps ++ _)
433-
.mapFirst(some)
431+
val (pss, newCtx) =
432+
td.paramLists.foldLeft(Ls[ParamList](), newCtx1):
433+
case ((pss, ctx), ps) =>
434+
val (qs, newCtx) = params(ps)(using ctx)
435+
(pss :+ ParamList(ParamListFlags.empty, qs), newCtx)
434436
val b = rhs.map(term(_)(using newCtx))
435437
val r = FlowSymbol(s"‹result of ${sym}", nextUid)
436-
val tdf = TermDefinition(k, sym, ps,
438+
val tdf = TermDefinition(k, sym, pss,
437439
td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r)
438440
sym.defn = S(tdf)
439441
tdf
@@ -592,8 +594,8 @@ extends Importer:
592594
def computeVariances(s: Statement): Unit =
593595
val trav = VarianceTraverser()
594596
def go(s: Statement): Unit = s match
595-
case TermDefinition(k, sym, ps, sign, body, r) =>
596-
ps.foreach(_.foreach(trav.traverseType(S(false))))
597+
case TermDefinition(k, sym, pss, sign, body, r) =>
598+
pss.foreach(ps => ps.params.foreach(trav.traverseType(S(false))))
597599
sign.foreach(trav.traverseType(S(true)))
598600
body match
599601
case S(b) =>

hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ sealed trait Statement extends AutoLocated:
113113
case Assgn(lhs, rhs) => lhs :: rhs :: Nil
114114
case Deref(term) => term :: Nil
115115
case TermDefinition(k, _, ps, sign, body, res) =>
116-
ps.toList.flatMap(_.flatMap(_.subTerms)) ::: sign.toList ::: body.toList
116+
ps.toList.flatMap(_.subTerms) ::: sign.toList ::: body.toList
117117
case cls: ClassDef =>
118118
cls.paramsOpt.toList.flatMap(_.flatMap(_.subTerms)) ::: cls.body.blk :: Nil
119119
case td: TypeDef =>
@@ -175,7 +175,7 @@ sealed trait Statement extends AutoLocated:
175175
case Error => "<error>"
176176
case Tup(fields) => fields.map(_.showDbg).mkString("[", ", ", "]")
177177
case TermDefinition(k, sym, ps, sign, body, res) => s"${k.str} ${sym}${
178-
ps.fold("")(_.map(_.showDbg).mkString("(", ", ", ")"))
178+
ps.map(_.showDbg).mkString("")
179179
}${sign.fold("")(": "+_.showDbg)}${
180180
body match
181181
case S(x) => " = " + x.showDbg
@@ -194,7 +194,7 @@ final case class DefineVar(sym: LocalSymbol, rhs: Term) extends Statement
194194
final case class TermDefinition(
195195
k: TermDefKind,
196196
sym: TermSymbol,
197-
params: Opt[Ls[Param]],
197+
params: Ls[ParamList],
198198
sign: Opt[Term],
199199
body: Opt[Term],
200200
resSym: FlowSymbol,
@@ -272,6 +272,17 @@ final case class Param(flags: FldFlags, sym: LocalSymbol & NamedSymbol, sign: Op
272272

273273
object FldFlags { val empty: FldFlags = FldFlags(false, false, false) }
274274

275+
final case class ParamListFlags(ctx: Bool):
276+
def showDbg: Str = (if ctx then "ctx " else "")
277+
override def toString: String = "" + showDbg + ""
278+
279+
object ParamListFlags:
280+
val empty = ParamListFlags(false)
281+
282+
final case class ParamList(flags: ParamListFlags, params: Ls[Param]):
283+
def subTerms: Ls[Term] = params.flatMap(_.subTerms)
284+
def showDbg: Str = flags.showDbg + params.mkString("(", ", ", ")")
285+
275286
trait FldImpl extends AutoLocated:
276287
self: Fld =>
277288
def children: Ls[Located] = self.value :: self.asc.toList ::: Nil

hkmc2/shared/src/main/scala/hkmc2/typing/TypeChecker.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class TypeChecker(using raise: Raise):
3030
ts.defn match
3131
case S(td: TermDefinition) =>
3232
td.params match
33-
case N => P.Flow(td.resSym)
33+
case Nil => P.Flow(td.resSym)
3434
case Blk(stats, res) =>
3535
// val p1 = stats.map(typeStat)
3636
// val p2 = typeProd(res)
@@ -40,7 +40,7 @@ class TypeChecker(using raise: Raise):
4040
stats.foreach:
4141
case t: TermDefinition =>
4242
t.sign.map(typeProd)
43-
t.params.map(typeParams)
43+
t.params.map(_.params).map(typeParams)
4444
t.body.map(typeProd)
4545
P.Ctor(LitSymbol(Tree.UnitLit(true)), Nil)
4646
case t: Term =>
@@ -57,10 +57,12 @@ class TypeChecker(using raise: Raise):
5757
ts.defn match
5858
case S(td: TermDefinition) =>
5959
td.params match
60-
case N =>
60+
case Nil =>
6161
val f = typeProd(r)
6262
constrain(P.exitIf(f, ts, r.refNum, rc), C.Fun(typeProd(tup), C.Flow(app.resSym)))
63-
case S(ps) =>
63+
case ParamList(_, ps) :: Nil =>
64+
// App applies to the leftmost parameter list
65+
// TODO: how to recursively check the subsequent Apps (if any)?
6466
if ps.size != args.size then
6567
raise(ErrorReport(
6668
msg"Expected ${ps.size.toString} arguments, but got ${
Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,64 @@
11
:js
2+
:sjs
23

4+
fun f(n1: Int): Int = n1
5+
//│ JS:
6+
//│ function f(n1) { return n1 }; undefined
37

4-
// FIXME elbaoration is currently wrong
8+
f(42)
9+
//│ JS:
10+
//│ this.f(42)
11+
//│ = 42
512

6-
:sjs
7-
fun foo(x)(y) = x * y
13+
fun f(n1: Int)(n2: Int): Int = (10 * n1 + n2)
814
//│ JS:
9-
//│ function foo(x, y) { return x * y }; undefined
15+
//│ function f(n1) { return (n2) => { let tmp; tmp = 10 * n1; return tmp + n2 } }; undefined
1016

11-
:sjs
12-
fun foo(x)(y)(z) = x * y + z
17+
f(4)(2)
1318
//│ JS:
14-
//│ function foo(x, y, z) { let tmp; tmp = x * y; return tmp + z }; undefined
19+
//│ let tmp; tmp = this.f(4); tmp(2)
20+
//│ = 42
1521

22+
fun f(n1: Int)(n2: Int)(n3: Int): Int = 10 * (10 * n1 + n2) + n3
23+
//│ JS:
24+
//│ function f(n1) {
25+
//│ return (n2) => {
26+
//│ return (n3) => {
27+
//│ let tmp, tmp1, tmp2;
28+
//│ tmp = 10 * n1;
29+
//│ tmp1 = tmp + n2;
30+
//│ tmp2 = 10 * tmp1;
31+
//│ return tmp2 + n3
32+
//│ }
33+
//│ }
34+
//│ };
35+
//│ undefined
1636

37+
f(4)(2)(0)
38+
//│ JS:
39+
//│ let tmp, tmp1; tmp = this.f(4); tmp1 = tmp(2); tmp1(0)
40+
//│ = 420
41+
42+
fun f(n1: Int)(n2: Int)(n3: Int)(n4: Int): Int = 10 * (10 * (10 * n1 + n2) + n3) + n4
43+
//│ JS:
44+
//│ function f(n1) {
45+
//│ return (n2) => {
46+
//│ return (n3) => {
47+
//│ return (n4) => {
48+
//│ let tmp, tmp1, tmp2, tmp3, tmp4;
49+
//│ tmp = 10 * n1;
50+
//│ tmp1 = tmp + n2;
51+
//│ tmp2 = 10 * tmp1;
52+
//│ tmp3 = tmp2 + n3;
53+
//│ tmp4 = 10 * tmp3;
54+
//│ return tmp4 + n4
55+
//│ }
56+
//│ }
57+
//│ }
58+
//│ };
59+
//│ undefined
60+
61+
f(3)(0)(3)(1)
62+
//│ JS:
63+
//│ let tmp, tmp1, tmp2; tmp = this.f(3); tmp1 = tmp(0); tmp2 = tmp1(3); tmp2(1)
64+
//│ = 3031

0 commit comments

Comments
 (0)