@@ -489,12 +489,47 @@ extends Importer:
489
489
case ((pss, ctx), ps) =>
490
490
val (qs, newCtx) = params(ps)(using ctx)
491
491
(pss :+ ParamList (ParamListFlags .empty, qs), newCtx)
492
+ // * Elaborate signature
493
+ val st = td.signature.orElse(newSignatureTrees.get(id.name))
494
+ val s = st.map(term(_)(using newCtx))
492
495
val b = rhs.map(term(_)(using newCtx))
493
496
val r = FlowSymbol (s " ‹result of ${sym}› " , nextUid)
494
- val tdf = TermDefinition (owner, k, sym, pss,
495
- td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r,
496
- TermDefFlags (isModMember))
497
+ val tdf = TermDefinition (owner, k, sym, pss, s, b, r,
498
+ TermDefFlags .empty.copy(mod = isModMember))
497
499
sym.defn = S (tdf)
500
+
501
+ // the return type of the function
502
+ val result = td.head match
503
+ case InfixApp (_, Keyword .`:`, rhs) => S (term(rhs)(using newCtx))
504
+ case _ => N
505
+
506
+ // indicates if the function really returns a module
507
+ val em = b.fold(false )(ModuleChecker .evalsToModule)
508
+ // indicates if the function marks its result as "module"
509
+ val mm = st match
510
+ case Some (TypeDef (Mod , _, N , N )) => true
511
+ case _ => false
512
+
513
+ // checks rules regarding module methods
514
+ s match
515
+ case N if em => raise :
516
+ ErrorReport :
517
+ msg " Function returning module values must have explicit return types. " ->
518
+ td.head.toLoc :: Nil
519
+ case S (t) if em && ModuleChecker .isTypeParam(t) => raise :
520
+ ErrorReport :
521
+ msg " Function returning module values must have concrete return types. " ->
522
+ td.head.toLoc :: Nil
523
+ case S (_) if em && ! mm => raise :
524
+ ErrorReport :
525
+ msg " The return type of functions returning module values must be prefixed with module keyword. " ->
526
+ td.head.toLoc :: Nil
527
+ case S (_) if mm && ! isModMember => raise :
528
+ ErrorReport :
529
+ msg " Only module methods may return module values. " ->
530
+ td.head.toLoc :: Nil
531
+ case _ => ()
532
+
498
533
tdf
499
534
go(sts, tdf :: acc)
500
535
case L (d) =>
@@ -697,6 +732,21 @@ extends Importer:
697
732
.filter(_.isInstanceOf [VarSymbol ])
698
733
.flatMap(_.asInstanceOf [VarSymbol ].decl)
699
734
.fold(false )(_.isInstanceOf [TyParam ])
735
+
736
+ /** Checks if a term evaluates to a module value. */
737
+ def evalsToModule (t : Term ): Bool =
738
+ def returnsModule (t : TermDef ): Bool = t.signature match
739
+ case S (TypeDef (Mod , _, N , N )) => true
740
+ case _ => false
741
+ t match
742
+ case Term .Blk (_, res) => evalsToModule(res)
743
+ case Term .App (lhs, rhs) => evalsToModule(lhs)
744
+ case t => t.symbol match
745
+ case S (sym : BlockMemberSymbol ) => sym match
746
+ case _ if sym.modTree.isDefined => true
747
+ case _ if sym.trmTree.isDefined => returnsModule(sym.trmTree.get)
748
+ case _ => false
749
+ case _ => false
700
750
701
751
class VarianceTraverser (var changed : Bool = true ) extends Traverser :
702
752
override def traverseType (pol : Pol )(trm : Term ): Unit = trm match
0 commit comments