@@ -489,12 +489,40 @@ extends Importer:
489
489
case ((pss, ctx), ps) =>
490
490
val (qs, newCtx) = params(ps)(using ctx)
491
491
(pss :+ ParamList (ParamListFlags .empty, qs), newCtx)
492
+ // * Elaborate signature
493
+ val s = td.signature.orElse(newSignatureTrees.get(id.name)).map(term)
492
494
val b = rhs.map(term(_)(using newCtx))
493
495
val r = FlowSymbol (s " ‹result of ${sym}› " , nextUid)
494
- val tdf = TermDefinition (owner, k, sym, pss,
495
- td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r,
496
- TermDefFlags (mod))
496
+ val tdf = TermDefinition (owner, k, sym, pss, s, b, r, TermDefFlags (mod))
497
497
sym.defn = S (tdf)
498
+
499
+ // the return type of the function
500
+ val result = td.head match
501
+ case InfixApp (_, Keyword .`:`, rhs) => S (term(rhs)(using newCtx))
502
+ case _ => N
503
+ // indicates if the function really returns a module
504
+ val em = b.fold(false )(ModuleChecker .evalsToModule)
505
+
506
+ // checks rules regarding module methods
507
+ result match
508
+ case N if em => raise :
509
+ ErrorReport :
510
+ msg " Function returning module values must have explicit return types. " ->
511
+ td.head.toLoc :: Nil
512
+ case S (ret) if em && ModuleChecker .isTypeParam(ret) => raise :
513
+ ErrorReport :
514
+ msg " Function returning module values must have concrete return types. " ->
515
+ td.head.toLoc :: Nil
516
+ case S (ret) if em && ! ret.isInstanceOf [Term .Mod ] => raise :
517
+ ErrorReport :
518
+ msg " The return type of functions returning module values must be prefixed with module keyword. " ->
519
+ td.head.toLoc :: Nil
520
+ case S (Term .Mod (_)) if ! mod => raise :
521
+ ErrorReport :
522
+ msg " Only module methods may return module values. " ->
523
+ td.head.toLoc :: Nil
524
+ case _ => ()
525
+
498
526
tdf
499
527
go(sts, tdf :: acc)
500
528
case L (d) =>
@@ -697,6 +725,23 @@ extends Importer:
697
725
.filter(_.isInstanceOf [VarSymbol ])
698
726
.flatMap(_.asInstanceOf [VarSymbol ].decl)
699
727
.fold(false )(_.isInstanceOf [TyParam ])
728
+
729
+ /** Checks if a term evaluates to a module value. */
730
+ def evalsToModule (t : Term ): Bool =
731
+ def returnsModule (t : Tree ): Bool = t match
732
+ case InfixApp (_, Keyword .`:`, rhs) => rhs match
733
+ case TypeDef (Mod , _, _, _) => true
734
+ case _ => false
735
+ case _ => false
736
+ t match
737
+ case Term .Blk (_, res) => evalsToModule(res)
738
+ case Term .App (lhs, rhs) => evalsToModule(lhs)
739
+ case t => t.symbol match
740
+ case S (sym : BlockMemberSymbol ) => sym match
741
+ case _ if sym.modTree.isDefined => true
742
+ case _ if sym.trmTree.isDefined => returnsModule(sym.trmTree.get.head)
743
+ case _ => false
744
+ case _ => false
700
745
701
746
class VarianceTraverser (var changed : Bool = true ) extends Traverser :
702
747
override def traverseType (pol : Pol )(trm : Term ): Unit = trm match
0 commit comments