Skip to content

Commit d5e89d0

Browse files
Enforce rules on module arguments
1 parent 00f7c33 commit d5e89d0

File tree

2 files changed

+54
-13
lines changed

2 files changed

+54
-13
lines changed

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,34 @@ extends Importer:
232232
term(rhs)
233233
case tree @ App(lhs, rhs) =>
234234
val sym = FlowSymbol("‹app-res›", nextUid)
235-
Term.App(term(lhs), term(rhs))(tree, sym)
235+
val lt = term(lhs)
236+
val rt = term(rhs)
237+
238+
// Check if module arguments match module parameters
239+
val args = rt match
240+
case Term.Tup(fields) => S(fields)
241+
case _ => N
242+
val params = lt.symbol
243+
.collect:
244+
case sym: BlockMemberSymbol => sym.trmTree
245+
.flatten
246+
.collect:
247+
case td: TermDef => td.paramLists.headOption
248+
.flatten
249+
for
250+
(args, params) <- (args zip params)
251+
(arg, param) <- (args zip params.fields)
252+
do
253+
val argMod = arg.flags.mod
254+
val paramMod = param match
255+
case Tree.TypeDef(Mod, _, N, N) => true
256+
case _ => false
257+
if argMod && !paramMod then raise:
258+
ErrorReport:
259+
msg"Only module parameters may receive module arguments (values)." ->
260+
arg.toLoc :: Nil
261+
262+
Term.App(lt, rt)(tree, sym)
236263
case Sel(pre, nme) =>
237264
val preTrm = term(pre)
238265
val sym = resolveField(nme, preTrm.symbol, nme)
@@ -329,9 +356,10 @@ extends Importer:
329356
Fld(FldFlags.empty, term(lhs), S(term(rhs)))
330357
case _ =>
331358
val t = term(tree)
332-
t.symbol.flatMap(_.asMod) match
333-
case S(_) => Fld(FldFlags.empty.copy(mod = true), t, N)
334-
case N => Fld(FldFlags.empty, t, N)
359+
val flags = FldFlags.empty
360+
if ModuleChecker.evalsToModule(t)
361+
then Fld(flags.copy(mod = true), t, N)
362+
else Fld(flags, t, N)
335363

336364
def unit: Term.Lit = Term.Lit(UnitLit(true))
337365

hkmc2/shared/src/test/mlscript/basics/ModuleMethods.mls

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,53 +6,64 @@ class C with {
66
module M with {
77
// module method foo
88
fun foo(x) = x
9+
fun self(): module M = M
910
}
1011

1112
:e
1213
fun f1(module m)
1314
//│ ╔══[ERROR] Module parameters must have explicit types.
14-
//│ ║ l.12: fun f1(module m)
15+
//│ ║ l.13: fun f1(module m)
1516
//│ ╙── ^
1617

1718
:e
1819
fun f2[T](module m: T)
1920
//│ ╔══[ERROR] Module parameters must have concrete types.
20-
//│ ║ l.18: fun f2[T](module m: T)
21+
//│ ║ l.19: fun f2[T](module m: T)
2122
//│ ╙── ^^^^
2223

2324
:e
2425
module N with {
2526
fun f3(): M = M
2627
}
2728
//│ ╔══[ERROR] The return type of functions returning module values must be prefixed with module keyword.
28-
//│ ║ l.25: fun f3(): M = M
29+
//│ ║ l.26: fun f3(): M = M
2930
//│ ╙── ^^^^^^^
3031

3132
:e
3233
module N with {
3334
fun f4[T](): module T = M
3435
}
3536
//│ ╔══[ERROR] Function returning module values must have concrete return types.
36-
//│ ║ l.33: fun f4[T](): module T = M
37+
//│ ║ l.34: fun f4[T](): module T = M
3738
//│ ╙── ^^^^^^^^^^^^^^^^^
3839

3940
:e
4041
module N with {
4142
fun f5(): M = M
4243
}
4344
//│ ╔══[ERROR] The return type of functions returning module values must be prefixed with module keyword.
44-
//│ ║ l.41: fun f5(): M = M
45+
//│ ║ l.42: fun f5(): M = M
4546
//│ ╙── ^^^^^^^
4647

48+
49+
fun f6(m: M)
50+
4751
:e
48-
fun f6(module m: M)
49-
f6(new C)
50-
//│ FAILURE: Unexpected lack of type error
52+
f6(M)
53+
//│ ╔══[ERROR] Only module parameters may receive module arguments (values).
54+
//│ ║ l.52: f6(M)
55+
//│ ╙── ^
56+
57+
:e
58+
f6(M.self())
59+
//│ ╔══[ERROR] Only module parameters may receive module arguments (values).
60+
//│ ║ l.58: f6(M.self())
61+
//│ ╙── ^^^^^^^^
5162

5263
:e
5364
fun f7(): module M
5465
//│ ╔══[ERROR] Only module methods may return module values.
55-
//│ ║ l.53: fun f7(): module M
66+
//│ ║ l.64: fun f7(): module M
5667
//│ ╙── ^^^^^^^^^^^^^^
5768

5869

@@ -63,3 +74,5 @@ module N with {
6374
}
6475

6576
ok1(M)
77+
78+
ok1(M.self())

0 commit comments

Comments
 (0)