@@ -489,12 +489,41 @@ extends Importer:
489
489
case ((pss, ctx), ps) =>
490
490
val (qs, newCtx) = params(ps)(using ctx)
491
491
(pss :+ ParamList (ParamListFlags .empty, qs), newCtx)
492
+ // * Elaborate signature
493
+ val s = td.signature.orElse(newSignatureTrees.get(id.name)).map(term)
492
494
val b = rhs.map(term(_)(using newCtx))
493
495
val r = FlowSymbol (s " ‹result of ${sym}› " , nextUid)
494
- val tdf = TermDefinition (owner, k, sym, pss,
495
- td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r,
496
- TermDefFlags (mod))
496
+ val tdf = TermDefinition (owner, k, sym, pss, s, b, r,
497
+ TermDefFlags .empty.copy(mod = isModMember))
497
498
sym.defn = S (tdf)
499
+
500
+ // the return type of the function
501
+ val result = td.head match
502
+ case InfixApp (_, Keyword .`:`, rhs) => S (term(rhs)(using newCtx))
503
+ case _ => N
504
+ // indicates if the function really returns a module
505
+ val em = b.fold(false )(ModuleChecker .evalsToModule)
506
+
507
+ // checks rules regarding module methods
508
+ result match
509
+ case N if em => raise :
510
+ ErrorReport :
511
+ msg " Function returning module values must have explicit return types. " ->
512
+ td.head.toLoc :: Nil
513
+ case S (ret) if em && ModuleChecker .isTypeParam(ret) => raise :
514
+ ErrorReport :
515
+ msg " Function returning module values must have concrete return types. " ->
516
+ td.head.toLoc :: Nil
517
+ case S (ret) if em && ! ret.isInstanceOf [Term .Mod ] => raise :
518
+ ErrorReport :
519
+ msg " The return type of functions returning module values must be prefixed with module keyword. " ->
520
+ td.head.toLoc :: Nil
521
+ case S (Term .Mod (_)) if ! isModMember => raise :
522
+ ErrorReport :
523
+ msg " Only module methods may return module values. " ->
524
+ td.head.toLoc :: Nil
525
+ case _ => ()
526
+
498
527
tdf
499
528
go(sts, tdf :: acc)
500
529
case L (d) =>
@@ -697,6 +726,23 @@ extends Importer:
697
726
.filter(_.isInstanceOf [VarSymbol ])
698
727
.flatMap(_.asInstanceOf [VarSymbol ].decl)
699
728
.fold(false )(_.isInstanceOf [TyParam ])
729
+
730
+ /** Checks if a term evaluates to a module value. */
731
+ def evalsToModule (t : Term ): Bool =
732
+ def returnsModule (t : Tree ): Bool = t match
733
+ case InfixApp (_, Keyword .`:`, rhs) => rhs match
734
+ case TypeDef (Mod , _, N , N ) => true
735
+ case _ => false
736
+ case _ => false
737
+ t match
738
+ case Term .Blk (_, res) => evalsToModule(res)
739
+ case Term .App (lhs, rhs) => evalsToModule(lhs)
740
+ case t => t.symbol match
741
+ case S (sym : BlockMemberSymbol ) => sym match
742
+ case _ if sym.modTree.isDefined => true
743
+ case _ if sym.trmTree.isDefined => returnsModule(sym.trmTree.get.head)
744
+ case _ => false
745
+ case _ => false
700
746
701
747
class VarianceTraverser (var changed : Bool = true ) extends Traverser :
702
748
override def traverseType (pol : Pol )(trm : Term ): Unit = trm match
0 commit comments