Skip to content

Commit 189e6eb

Browse files
Enforce rules on functions returning modules
1 parent 5502030 commit 189e6eb

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -489,12 +489,40 @@ extends Importer:
489489
case ((pss, ctx), ps) =>
490490
val (qs, newCtx) = params(ps)(using ctx)
491491
(pss :+ ParamList(ParamListFlags.empty, qs), newCtx)
492+
// * Elaborate signature
493+
val s = td.signature.orElse(newSignatureTrees.get(id.name)).map(term)
492494
val b = rhs.map(term(_)(using newCtx))
493495
val r = FlowSymbol(s"‹result of ${sym}", nextUid)
494-
val tdf = TermDefinition(owner, k, sym, pss,
495-
td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r,
496-
TermDefFlags(mod))
496+
val tdf = TermDefinition(owner, k, sym, pss, s, b, r, TermDefFlags(mod))
497497
sym.defn = S(tdf)
498+
499+
// the return type of the function
500+
val result = td.head match
501+
case InfixApp(_, Keyword.`:`, rhs) => S(term(rhs)(using newCtx))
502+
case _ => N
503+
// indicates if the function really returns a module
504+
val em = b.fold(false)(ModuleChecker.evalsToModule)
505+
506+
// checks rules regarding module methods
507+
result match
508+
case N if em => raise:
509+
ErrorReport:
510+
msg"Function returning module values must have explicit return types." ->
511+
td.head.toLoc :: Nil
512+
case S(ret) if em && ModuleChecker.isTypeParam(ret) => raise:
513+
ErrorReport:
514+
msg"Function returning module values must have concrete return types." ->
515+
td.head.toLoc :: Nil
516+
case S(ret) if em && !ret.isInstanceOf[Term.Mod] => raise:
517+
ErrorReport:
518+
msg"The return type of functions returning module values must be prefixed with module keyword." ->
519+
td.head.toLoc :: Nil
520+
case S(Term.Mod(_)) if !mod => raise:
521+
ErrorReport:
522+
msg"Only module methods may return module values." ->
523+
td.head.toLoc :: Nil
524+
case _ => ()
525+
498526
tdf
499527
go(sts, tdf :: acc)
500528
case L(d) =>
@@ -697,6 +725,23 @@ extends Importer:
697725
.filter(_.isInstanceOf[VarSymbol])
698726
.flatMap(_.asInstanceOf[VarSymbol].decl)
699727
.fold(false)(_.isInstanceOf[TyParam])
728+
729+
/** Checks if a term evaluates to a module value. */
730+
def evalsToModule(t: Term): Bool =
731+
def returnsModule(t: Tree): Bool = t match
732+
case InfixApp(_, Keyword.`:`, rhs) => rhs match
733+
case TypeDef(Mod, _, _, _) => true
734+
case _ => false
735+
case _ => false
736+
t match
737+
case Term.Blk(_, res) => evalsToModule(res)
738+
case Term.App(lhs, rhs) => evalsToModule(lhs)
739+
case t => t.symbol match
740+
case S(sym: BlockMemberSymbol) => sym match
741+
case _ if sym.modTree.isDefined => true
742+
case _ if sym.trmTree.isDefined => returnsModule(sym.trmTree.get.head)
743+
case _ => false
744+
case _ => false
700745

701746
class VarianceTraverser(var changed: Bool = true) extends Traverser:
702747
override def traverseType(pol: Pol)(trm: Term): Unit = trm match

hkmc2/shared/src/test/mlscript/basics/ModuleMethods.mls

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,25 @@ fun f2[T](module m: T)
2424
module N with {
2525
fun f3(): M = M
2626
}
27-
//│ FAILURE: Unexpected lack of type error
27+
//│ ╔══[ERROR] The return type of functions returning module values must be prefixed with module keyword.
28+
//│ ║ l.25: fun f3(): M = M
29+
//│ ╙── ^^^^^^^
2830

2931
:e
3032
module N with {
3133
fun f4[T](): module T = M
3234
}
33-
//│ FAILURE: Unexpected lack of type error
35+
//│ ╔══[ERROR] Function returning module values must have concrete return types.
36+
//│ ║ l.33: fun f4[T](): module T = M
37+
//│ ╙── ^^^^^^^^^^^^^^^^^
3438

3539
:e
3640
module N with {
3741
fun f5(): M = M
3842
}
39-
//│ FAILURE: Unexpected lack of type error
43+
//│ ╔══[ERROR] The return type of functions returning module values must be prefixed with module keyword.
44+
//│ ║ l.41: fun f5(): M = M
45+
//│ ╙── ^^^^^^^
4046

4147
:e
4248
fun f6(module m: M)
@@ -45,7 +51,9 @@ f6(new C)
4551

4652
:e
4753
fun f7(): module M
48-
//│ FAILURE: Unexpected lack of type error
54+
//│ ╔══[ERROR] Only module methods may return module values.
55+
//│ ║ l.53: fun f7(): module M
56+
//│ ╙── ^^^^^^^^^^^^^^
4957

5058

5159
fun ok1(module m: M)

0 commit comments

Comments
 (0)