Skip to content

Commit 70bf285

Browse files
Enforce rules on functions returning modules
1 parent 76170a9 commit 70bf285

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -489,12 +489,47 @@ extends Importer:
489489
case ((pss, ctx), ps) =>
490490
val (qs, newCtx) = params(ps)(using ctx)
491491
(pss :+ ParamList(ParamListFlags.empty, qs), newCtx)
492+
// * Elaborate signature
493+
val st = td.signature.orElse(newSignatureTrees.get(id.name))
494+
val s = st.map(term(_)(using newCtx))
492495
val b = rhs.map(term(_)(using newCtx))
493496
val r = FlowSymbol(s"‹result of ${sym}", nextUid)
494-
val tdf = TermDefinition(owner, k, sym, pss,
495-
td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r,
496-
TermDefFlags(isModMember))
497+
val tdf = TermDefinition(owner, k, sym, pss, s, b, r,
498+
TermDefFlags.empty.copy(mod = isModMember))
497499
sym.defn = S(tdf)
500+
501+
// the return type of the function
502+
val result = td.head match
503+
case InfixApp(_, Keyword.`:`, rhs) => S(term(rhs)(using newCtx))
504+
case _ => N
505+
506+
// indicates if the function really returns a module
507+
val em = b.fold(false)(ModuleChecker.evalsToModule)
508+
// indicates if the function marks its result as "module"
509+
val mm = st match
510+
case Some(TypeDef(Mod, _, N, N)) => true
511+
case _ => false
512+
513+
// checks rules regarding module methods
514+
s match
515+
case N if em => raise:
516+
ErrorReport:
517+
msg"Function returning module values must have explicit return types." ->
518+
td.head.toLoc :: Nil
519+
case S(t) if em && ModuleChecker.isTypeParam(t) => raise:
520+
ErrorReport:
521+
msg"Function returning module values must have concrete return types." ->
522+
td.head.toLoc :: Nil
523+
case S(_) if em && !mm => raise:
524+
ErrorReport:
525+
msg"The return type of functions returning module values must be prefixed with module keyword." ->
526+
td.head.toLoc :: Nil
527+
case S(_) if mm && !isModMember => raise:
528+
ErrorReport:
529+
msg"Only module methods may return module values." ->
530+
td.head.toLoc :: Nil
531+
case _ => ()
532+
498533
tdf
499534
go(sts, tdf :: acc)
500535
case L(d) =>
@@ -697,6 +732,21 @@ extends Importer:
697732
.filter(_.isInstanceOf[VarSymbol])
698733
.flatMap(_.asInstanceOf[VarSymbol].decl)
699734
.fold(false)(_.isInstanceOf[TyParam])
735+
736+
/** Checks if a term evaluates to a module value. */
737+
def evalsToModule(t: Term): Bool =
738+
def returnsModule(t: TermDef): Bool = t.signature match
739+
case S(TypeDef(Mod, _, N, N)) => true
740+
case _ => false
741+
t match
742+
case Term.Blk(_, res) => evalsToModule(res)
743+
case Term.App(lhs, rhs) => evalsToModule(lhs)
744+
case t => t.symbol match
745+
case S(sym: BlockMemberSymbol) => sym match
746+
case _ if sym.modTree.isDefined => true
747+
case _ if sym.trmTree.isDefined => returnsModule(sym.trmTree.get)
748+
case _ => false
749+
case _ => false
700750

701751
class VarianceTraverser(var changed: Bool = true) extends Traverser:
702752
override def traverseType(pol: Pol)(trm: Term): Unit = trm match

hkmc2/shared/src/test/mlscript/basics/ModuleMethods.mls

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,25 @@ fun f2[T](module m: T)
2626
module N with {
2727
fun f3(): M = M
2828
}
29-
//│ FAILURE: Unexpected lack of type error
29+
//│ ╔══[ERROR] The return type of functions returning module values must be prefixed with module keyword.
30+
//│ ║ l.27: fun f3(): M = M
31+
//│ ╙── ^^^^^^^
3032

3133
:e
3234
module N with {
3335
fun f4[T](): module T = M
3436
}
35-
//│ FAILURE: Unexpected lack of type error
37+
//│ ╔══[ERROR] Function returning module values must have explicit return types.
38+
//│ ║ l.35: fun f4[T](): module T = M
39+
//│ ╙── ^^^^^^^^^^^^^^^^^
3640

3741
:e
3842
module N with {
3943
fun f5(): M = M
4044
}
41-
//│ FAILURE: Unexpected lack of type error
45+
//│ ╔══[ERROR] The return type of functions returning module values must be prefixed with module keyword.
46+
//│ ║ l.43: fun f5(): M = M
47+
//│ ╙── ^^^^^^^
4248

4349

4450
fun f6(m: M)
@@ -53,7 +59,9 @@ f6(M.self())
5359

5460
:e
5561
fun f7(): module M
56-
//│ FAILURE: Unexpected lack of type error
62+
//│ ╔══[ERROR] Only module methods may return module values.
63+
//│ ║ l.61: fun f7(): module M
64+
//│ ╙── ^^^^^^^^^^^^^^
5765

5866

5967
fun ok1(module m: M)

0 commit comments

Comments
 (0)