Skip to content

Commit 00f7c33

Browse files
Enforce rules on functions returning modules
1 parent f58ab90 commit 00f7c33

File tree

2 files changed

+61
-7
lines changed

2 files changed

+61
-7
lines changed

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -489,12 +489,41 @@ extends Importer:
489489
case ((pss, ctx), ps) =>
490490
val (qs, newCtx) = params(ps)(using ctx)
491491
(pss :+ ParamList(ParamListFlags.empty, qs), newCtx)
492+
// * Elaborate signature
493+
val s = td.signature.orElse(newSignatureTrees.get(id.name)).map(term)
492494
val b = rhs.map(term(_)(using newCtx))
493495
val r = FlowSymbol(s"‹result of ${sym}", nextUid)
494-
val tdf = TermDefinition(owner, k, sym, pss,
495-
td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r,
496-
TermDefFlags(mod))
496+
val tdf = TermDefinition(owner, k, sym, pss, s, b, r,
497+
TermDefFlags.empty.copy(mod = isModMember))
497498
sym.defn = S(tdf)
499+
500+
// the return type of the function
501+
val result = td.head match
502+
case InfixApp(_, Keyword.`:`, rhs) => S(term(rhs)(using newCtx))
503+
case _ => N
504+
// indicates if the function really returns a module
505+
val em = b.fold(false)(ModuleChecker.evalsToModule)
506+
507+
// checks rules regarding module methods
508+
result match
509+
case N if em => raise:
510+
ErrorReport:
511+
msg"Function returning module values must have explicit return types." ->
512+
td.head.toLoc :: Nil
513+
case S(ret) if em && ModuleChecker.isTypeParam(ret) => raise:
514+
ErrorReport:
515+
msg"Function returning module values must have concrete return types." ->
516+
td.head.toLoc :: Nil
517+
case S(ret) if em && !ret.isInstanceOf[Term.Mod] => raise:
518+
ErrorReport:
519+
msg"The return type of functions returning module values must be prefixed with module keyword." ->
520+
td.head.toLoc :: Nil
521+
case S(Term.Mod(_)) if !isModMember => raise:
522+
ErrorReport:
523+
msg"Only module methods may return module values." ->
524+
td.head.toLoc :: Nil
525+
case _ => ()
526+
498527
tdf
499528
go(sts, tdf :: acc)
500529
case L(d) =>
@@ -697,6 +726,23 @@ extends Importer:
697726
.filter(_.isInstanceOf[VarSymbol])
698727
.flatMap(_.asInstanceOf[VarSymbol].decl)
699728
.fold(false)(_.isInstanceOf[TyParam])
729+
730+
/** Checks if a term evaluates to a module value. */
731+
def evalsToModule(t: Term): Bool =
732+
def returnsModule(t: Tree): Bool = t match
733+
case InfixApp(_, Keyword.`:`, rhs) => rhs match
734+
case TypeDef(Mod, _, N, N) => true
735+
case _ => false
736+
case _ => false
737+
t match
738+
case Term.Blk(_, res) => evalsToModule(res)
739+
case Term.App(lhs, rhs) => evalsToModule(lhs)
740+
case t => t.symbol match
741+
case S(sym: BlockMemberSymbol) => sym match
742+
case _ if sym.modTree.isDefined => true
743+
case _ if sym.trmTree.isDefined => returnsModule(sym.trmTree.get.head)
744+
case _ => false
745+
case _ => false
700746

701747
class VarianceTraverser(var changed: Bool = true) extends Traverser:
702748
override def traverseType(pol: Pol)(trm: Term): Unit = trm match

hkmc2/shared/src/test/mlscript/basics/ModuleMethods.mls

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,25 @@ fun f2[T](module m: T)
2424
module N with {
2525
fun f3(): M = M
2626
}
27-
//│ FAILURE: Unexpected lack of type error
27+
//│ ╔══[ERROR] The return type of functions returning module values must be prefixed with module keyword.
28+
//│ ║ l.25: fun f3(): M = M
29+
//│ ╙── ^^^^^^^
2830

2931
:e
3032
module N with {
3133
fun f4[T](): module T = M
3234
}
33-
//│ FAILURE: Unexpected lack of type error
35+
//│ ╔══[ERROR] Function returning module values must have concrete return types.
36+
//│ ║ l.33: fun f4[T](): module T = M
37+
//│ ╙── ^^^^^^^^^^^^^^^^^
3438

3539
:e
3640
module N with {
3741
fun f5(): M = M
3842
}
39-
//│ FAILURE: Unexpected lack of type error
43+
//│ ╔══[ERROR] The return type of functions returning module values must be prefixed with module keyword.
44+
//│ ║ l.41: fun f5(): M = M
45+
//│ ╙── ^^^^^^^
4046

4147
:e
4248
fun f6(module m: M)
@@ -45,7 +51,9 @@ f6(new C)
4551

4652
:e
4753
fun f7(): module M
48-
//│ FAILURE: Unexpected lack of type error
54+
//│ ╔══[ERROR] Only module methods may return module values.
55+
//│ ║ l.53: fun f7(): module M
56+
//│ ╙── ^^^^^^^^^^^^^^
4957

5058

5159
fun ok1(module m: M)

0 commit comments

Comments
 (0)