Skip to content

Commit ce9abc4

Browse files
Improve module method checks
1 parent 26a4a14 commit ce9abc4

File tree

10 files changed

+77
-26
lines changed

10 files changed

+77
-26
lines changed

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,13 +1092,14 @@ extends Importer:
10921092
N
10931093
case N => N
10941094

1095-
def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): LocalSymbol & NamedSymbol =
1095+
def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): TermSymbol | VarSymbol =
10961096
if ctx.outer.isDefined then TermSymbol(k, ctx.outer, id)
10971097
else VarSymbol(id)
10981098

1099-
def param(t: Tree, inUsing: Bool): Ctxl[Opt[Opt[Bool] -> Param]] = t match
1099+
def param(t: Tree, inUsing: Bool): Ctxl[Opt[Opt[Bool] -> Param]] =
1100+
def go(t: Tree, inUsing: Bool, flags: FldFlags): Ctxl[Opt[Opt[Bool] -> Param]] = t match
11001101
case TypeDef(Mod, inner, N, N) =>
1101-
val ps = param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(mod = true))))
1102+
val ps = go(inner, inUsing, flags.copy(mod = true))
11021103
for p <- ps if p._2.flags.mod do p._2.sign match
11031104
case N =>
11041105
raise(ErrorReport(msg"Module parameters must have explicit types." -> t.toLoc :: Nil))
@@ -1107,10 +1108,21 @@ extends Importer:
11071108
case _ => ()
11081109
ps
11091110
case TypeDef(Pat, inner, N, N) =>
1110-
param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(pat = true))))
1111+
go(inner, inUsing, flags.copy(pat = true))
11111112
case _ =>
11121113
t.asParam(inUsing).map: (isSpd, p, t) =>
1113-
isSpd -> Param(FldFlags.empty, fieldOrVarSym(ParamBind, p), t.map(term(_)))
1114+
val sym = fieldOrVarSym(ParamBind, p)
1115+
val sign = t.map(term(_))
1116+
val param = Param(flags, sym, sign)
1117+
sym match
1118+
case sym: TermSymbol =>
1119+
// TODO: How can a TermSymbol accept a Declaration
1120+
// sym.defn = S(TermDefinition(ctx.outer, Fun, sym, Nil, Nil, sign, N, FlowSymbol(s"‹result of ${sym}›"), TermDefFlags.empty, Nil))
1121+
case sym: VarSymbol =>
1122+
sym.decl = S(param)
1123+
isSpd -> param
1124+
go(t, inUsing, FldFlags.empty)
1125+
11141126

11151127
def params(t: Tree): Ctxl[(ParamList, Ctx)] = t match
11161128
case Tup(ps) =>

hkmc2/shared/src/main/scala/hkmc2/semantics/ImplicitResolver.scala

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -282,21 +282,50 @@ object ModuleChecker:
282282
.exists(_.isInstanceOf[TyParam])
283283

284284
/** Checks if a term evaluates to a module value. */
285-
def evalsToModule(t: Term): Bool =
286-
def isModule(t: Tree): Bool = t match
287-
case Tree.TypeDef(Mod, _, _, _) => true
288-
case _ => false
285+
def evalsToModule(t: Term): Bool =
289286
def returnsModule(t: Tree.TermDef): Bool = t.annotatedResultType match
290287
case S(Tree.TypeDef(Mod, _, N, N)) => true
291288
case _ => false
289+
def checkDecl(decl: Declaration): Bool = decl match
290+
// All TypeLikeDef are not modules, except for modules themselves.
291+
// Objects use ModuleDef but is not a module.
292+
case ModuleDef(kind = Mod) =>
293+
true
294+
case _: TypeLikeDef =>
295+
false
296+
// Check Member/Local symbols
297+
case defn: TermDefinition =>
298+
defn.flags.isModTyped
299+
case defn: Param =>
300+
defn.flags.mod
301+
case defn: TyParam =>
302+
defn.flags.mod
303+
def checkSym(sym: Symbol): Bool = sym match
304+
case sym if sym.asMod.nonEmpty => true
305+
case sym if sym.asBlkMember.flatMap(_.trmTree).exists(returnsModule) => true
306+
case _: (BuiltinSymbol | TopLevelSymbol) => false
307+
case sym: BlockLocalSymbol => sym.decl match
308+
case S(decl) => checkDecl(decl)
309+
case N =>
310+
// Most local symbols are let-bindings
311+
// which do not have a definition at this point.
312+
false
313+
case sym: MemberSymbol[?] => sym.defn match
314+
case S(defn) => checkDecl(defn)
315+
case N =>
316+
// At this point all member symbols should have definition,
317+
// except for the class(-like) that are currently being elaborated.
318+
// TODO: We will fix this by deferring the checks to the resolution stage.
319+
false
320+
case sym =>
321+
lastWords(s"Unsupported symbol kind ${sym}")
292322
t match
293323
case Term.Blk(_, res) => evalsToModule(res)
294324
case Term.App(lhs, rhs) => lhs.symbol match
295325
case S(sym: BlockMemberSymbol) => sym.trmTree.exists(returnsModule)
296326
case _ => false
297-
case t => t.symbol match
298-
case S(sym: BlockMemberSymbol) => sym.modTree.exists(isModule)
299-
case _ => false
327+
case t: Term.Ref => checkSym(t.sym)
328+
case t => t.symbol.exists(checkSym)
300329

301330
/**
302331
* An extractor that extracts the (tree) definition of a module method.

hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,14 +263,14 @@ final case class LetDecl(sym: LocalSymbol, annotations: Ls[Annot]) extends State
263263

264264
final case class DefineVar(sym: LocalSymbol, rhs: Term) extends Statement
265265

266-
final case class TermDefFlags(isModMember: Bool):
266+
final case class TermDefFlags(isModMember: Bool, isModTyped: Bool):
267267
def showDbg: Str =
268268
val flags = Buffer.empty[String]
269269
if isModMember then flags += "module"
270270
flags.mkString(" ")
271271
override def toString: String = "" + showDbg + ""
272272

273-
object TermDefFlags { val empty: TermDefFlags = TermDefFlags(false) }
273+
object TermDefFlags { val empty: TermDefFlags = TermDefFlags(false, false) }
274274

275275
final case class TermDefinition(
276276
owner: Opt[InnerSymbol],
@@ -494,7 +494,7 @@ final case class TyParam(flags: FldFlags, vce: Opt[Bool], sym: VarSymbol) extend
494494

495495

496496
final case class Param(flags: FldFlags, sym: LocalSymbol & NamedSymbol, sign: Opt[Term])
497-
extends AutoLocated:
497+
extends Declaration with AutoLocated:
498498
def subTerms: Ls[Term] = sign.toList
499499
override protected def children: List[Located] = subTerms
500500
// def children: Ls[Located] = self.value :: self.asc.toList ::: Nil

hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ extension (s: String)
2222

2323
import hkmc2.semantics.TermDefFlags
2424
import hkmc2.semantics.FldFlags
25+
import hkmc2.semantics.ParamListFlags
2526
import scala.collection.mutable.Buffer
2627
import mlscript.utils.StringOps
2728
import hkmc2.semantics.CtxArg
@@ -42,9 +43,10 @@ extension (t: Product)
4243
case xs: List[_] => "Ls of \n" + xs.iterator.map(aux(_)).mkString("\n").indent(" ")
4344
case xs: Vector[_] => "Vector of \n" + xs.iterator.map(aux(_)).mkString("\n").indent(" ")
4445
case s: String => s.escaped
45-
case TermDefFlags(mod) =>
46+
case TermDefFlags(isModMember, isModTyped) =>
4647
val flags = Buffer.empty[String]
47-
if mod then flags += "module"
48+
if isModMember then flags += "modMember"
49+
if isModMember then flags += "modTyped"
4850
flags.mkString("(", ", ", ")")
4951
case FldFlags(mut, spec, genGetter, mod, pat) =>
5052
val flags = Buffer.empty[String]
@@ -54,6 +56,10 @@ extension (t: Product)
5456
if mod then flags += "module"
5557
if pat then flags += "pat"
5658
flags.mkString("(", ", ", ")")
59+
case ParamListFlags(ctx) =>
60+
val flags = Buffer.empty[String]
61+
if ctx then flags += "ctx"
62+
flags.mkString("(", ", ", ")")
5763
case Loc(start, end, origin) =>
5864
val (sl, _, sc) = origin.fph.getLineColAt(start)
5965
val (el, _, ec) = origin.fph.getLineColAt(end)

hkmc2/shared/src/test/mlscript/basics/ModuleMethods.mls

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,16 @@ fun f7(): module M
6868
//│ ╙── ^^^^^^^^^^^^^^
6969

7070

71-
:todo // should be an error
7271
:e
7372
fun f8(module m: M) = m
73+
//│ ╔══[ERROR] Functions returning module values must have explicit return types.
74+
//│ ║ l.72: fun f8(module m: M) = m
75+
//│ ╙── ^^^^^^^^^^^^^^^
7476

7577
:e
7678
fun f9(module m: M): module M = m
7779
//│ ╔══[ERROR] Only module methods may return module values.
78-
//│ ║ l.76: fun f9(module m: M): module M = m
80+
//│ ║ l.78: fun f9(module m: M): module M = m
7981
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^^
8082

8183
module Test with

hkmc2/shared/src/test/mlscript/codegen/FieldSymbols.mls

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ case
6060
//│ stats = Nil
6161
//│ res = Lam:
6262
//│ params = ParamList:
63-
//│ flags = ParamListFlags of false
63+
//│ flags = ()
6464
//│ params = Ls of
6565
//│ Param:
6666
//│ flags = ()
@@ -92,7 +92,7 @@ case
9292
//│ lhs = $block$res
9393
//│ rhs = Lam:
9494
//│ params = ParamList:
95-
//│ flags = ParamListFlags of false
95+
//│ flags = ()
9696
//│ params = Ls of
9797
//│ Param:
9898
//│ flags = ()

hkmc2/shared/src/test/mlscript/codegen/Hygiene.mls

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,10 @@ foo()
156156

157157
:sjs
158158
module Whoops with
159-
val v = this
159+
val v: module Whoops = this
160160
fun f() = "Hello"
161161
module Whoops with
162-
val w = this
162+
val w: module Whoops = this
163163
fun g() = f()
164164
//│ JS (unsanitized):
165165
//│ let Whoops2;

hkmc2/shared/src/test/mlscript/codegen/SelfReferences.mls

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ Foo.self
2020
//│ = Foo
2121

2222

23-
:todo
2423
:e
2524
module Foo with
2625
val self = this
26+
//│ ╔══[ERROR] Functions returning module values must have explicit return types.
27+
//│ ║ l.25: val self = this
28+
//│ ╙── ^^^^
2729

2830
Foo.self
2931
//│ = Foo

hkmc2/shared/src/test/mlscript/ucs/papers/OperatorSplit.mls

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ fun example(args) =
7575
//│ sym = member:example
7676
//│ params = Ls of
7777
//│ ParamList:
78-
//│ flags = ParamListFlags of false
78+
//│ flags = ()
7979
//│ params = Ls of
8080
//│ Param:
8181
//│ flags = ()

hkmc2/shared/src/test/mlscript/ucs/syntax/NestedOpSplits.mls

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ fun f(x) =
1616
//│ sym = member:f
1717
//│ params = Ls of
1818
//│ ParamList:
19-
//│ flags = ParamListFlags of false
19+
//│ flags = ()
2020
//│ params = Ls of
2121
//│ Param:
2222
//│ flags = ()

0 commit comments

Comments
 (0)