Skip to content

Commit 933af7c

Browse files
Separate module parameters from regular parameters
Module parameters are function parameters that have module identifiers as their type annotations. Module parameters must have an explicit and concrete type.
1 parent 45ab1f1 commit 933af7c

File tree

4 files changed

+66
-25
lines changed

4 files changed

+66
-25
lines changed

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,23 @@ extends Importer:
605605
case id: Ident =>
606606
Param(FldFlags.empty, fieldOrVarSym(ParamBind, id), N) :: Nil
607607
case InfixApp(lhs: Ident, Keyword.`:`, rhs) =>
608-
Param(FldFlags.empty, fieldOrVarSym(ParamBind, lhs), S(term(rhs))) :: Nil
608+
// return S(moduleParam) if t represents a module parameter
609+
// return N otherwise
610+
def moduleParam(t: Term): Opt[Param] = t match
611+
case s: hkmc2.semantics.Term.Sel => s.symbol
612+
.flatMap(_.asMod)
613+
.map(_ => Param(FldFlags.module, fieldOrVarSym(ParamBind, lhs), S(s)))
614+
case hkmc2.semantics.Term.TyApp(s: hkmc2.semantics.Term.Sel, _) => s.symbol
615+
.flatMap(_.asMod)
616+
.map(_ => Param(FldFlags.module, fieldOrVarSym(ParamBind, lhs), S(s)))
617+
case _ => N
618+
619+
val t = term(rhs)
620+
moduleParam(t) match
621+
case S(p) =>
622+
p :: Nil
623+
case N =>
624+
Param(FldFlags.empty, fieldOrVarSym(ParamBind, lhs), S(term(rhs))) :: Nil
609625
case App(Ident(","), list) => params(list)._1
610626
case TermDef(ImmutVal, inner, _) => param(inner)
611627

@@ -678,7 +694,7 @@ extends Importer:
678694
class VarianceTraverser(var changed: Bool = true) extends Traverser:
679695
override def traverseType(pol: Pol)(trm: Term): Unit = trm match
680696
case Term.TyApp(lhs, targs) =>
681-
lhs.symbol.flatMap(_.asTpe) match
697+
lhs.symbol.flatMap(sym => sym.asTpe orElse sym.asMod) match
682698
case S(sym: ClassSymbol) =>
683699
sym.defn match
684700
case S(td: ClassDef) =>
@@ -690,6 +706,17 @@ extends Importer:
690706
if !tp.isCovariant then traverseType(pol.!)(targ)
691707
case N =>
692708
TODO(sym->sym.uid)
709+
case S(sym: ModuleSymbol) =>
710+
sym.defn match
711+
case S(td: ModuleDef) =>
712+
if td.tparams.sizeCompare(targs) =/= 0 then
713+
raise(ErrorReport(msg"Wrong number of type arguments" -> trm.toLoc :: Nil)) // TODO BE
714+
td.tparams.zip(targs).foreach:
715+
case (tp, targ) =>
716+
if !tp.isContravariant then traverseType(pol)(targ)
717+
if !tp.isCovariant then traverseType(pol.!)(targ)
718+
case N =>
719+
TODO(sym->sym.uid)
693720
case S(sym: TypeAliasSymbol) =>
694721
// TODO dedup with above...
695722
sym.defn match

hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,8 @@ case class TypeDef(
283283

284284

285285
// TODO Store optional source locations for the flags instead of booleans
286-
final case class FldFlags(mut: Bool, spec: Bool, genGetter: Bool):
287-
def showDbg: Str = (if mut then "mut " else "") + (if spec then "spec " else "") + (if genGetter then "val " else "")
286+
final case class FldFlags(mut: Bool, spec: Bool, genGetter: Bool, mod: Bool):
287+
def showDbg: Str = (if mut then "mut " else "") + (if spec then "spec " else "") + (if genGetter then "val " else "") + (if mod then "mod " else "")
288288
override def toString: String = "" + showDbg + ""
289289

290290
final case class Fld(flags: FldFlags, value: Term, asc: Opt[Term]) extends FldImpl
@@ -308,7 +308,12 @@ final case class Param(flags: FldFlags, sym: LocalSymbol & NamedSymbol, sign: Op
308308
// def showDbg: Str = flags.showDbg + sym.name + ": " + sign.showDbg
309309
def showDbg: Str = flags.showDbg + sym + sign.fold("")(": " + _.showDbg)
310310

311-
object FldFlags { val empty: FldFlags = FldFlags(false, false, false) }
311+
object FldFlags {
312+
val empty: FldFlags = FldFlags(false, false, false, false)
313+
314+
// module parameter
315+
val module: FldFlags = FldFlags(false, false, false, true)
316+
}
312317

313318
final case class ParamListFlags(ctx: Bool):
314319
def showDbg: Str = (if ctx then "ctx " else "")

hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,12 @@ extension (t: Product)
4141
val flags = Buffer.empty[String]
4242
if mod then flags += "module"
4343
if flags.isEmpty then "()" else flags.mkString("(", ", ", ")")
44-
case FldFlags(mut, spec, genGetter) =>
44+
case FldFlags(mut, spec, genGetter, mod) =>
4545
val flags = Buffer.empty[String]
4646
if mut then flags += "mut"
4747
if spec then flags += "spec"
4848
if genGetter then flags += "gen"
49+
if mod then flags += "module"
4950
if flags.isEmpty then "()" else flags.mkString("(", ", ", ")")
5051
case t: Product => t.showAsTree(inTailPos)
5152
case v => v.toString

hkmc2/shared/src/test/mlscript/basics/ModuleMethods.mls

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ fun f2(m: M): Int = m.foo(42)
155155
//│ flags = ParamListFlags of false
156156
//│ params = Ls of
157157
//│ Param:
158-
//│ flags = ()
158+
//│ flags = (module)
159159
//│ sym = m@42
160160
//│ sign = S of Sel:
161161
//│ prefix = Ref of globalThis:block#1
@@ -296,30 +296,38 @@ id(IntMT)
296296
// Good: module parameters must have an explicit and concrete type.
297297
// Mod[T] is a concrete type and Mod is a module, therefore t is a module parameter.
298298
fun idMod[T](t: MT[T]): MT[T] = t
299-
//│ FAILURE: Unexpected exception
300-
//│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing
301-
//│ at: scala.Predef$.$qmark$qmark$qmark(Predef.scala:344)
302-
//│ at: hkmc2.semantics.Elaborator$VarianceTraverser.traverseType(Elaborator.scala:708)
303-
//│ at: hkmc2.semantics.Elaborator$Traverser.traverseType$$anonfun$10(Elaborator.scala:750)
304-
//│ at: scala.runtime.function.JProcedure1.apply(JProcedure1.java:15)
305-
//│ at: scala.runtime.function.JProcedure1.apply(JProcedure1.java:10)
306-
//│ at: scala.Option.foreach(Option.scala:437)
307-
//│ at: hkmc2.semantics.Elaborator$Traverser.traverseType(Elaborator.scala:750)
308-
//│ at: hkmc2.semantics.Elaborator.go$4$$anonfun$1$$anonfun$1(Elaborator.scala:660)
309-
//│ at: scala.runtime.function.JProcedure1.apply(JProcedure1.java:15)
310-
//│ at: scala.runtime.function.JProcedure1.apply(JProcedure1.java:10)
299+
//│ Elaborated tree:
300+
//│ Blk:
301+
//│ stats = Ls of
302+
//│ TermDefinition:
303+
//│ owner = S of globalThis:block#12
304+
//│ k = Fun
305+
//│ sym = member:idMod
306+
//│ params = Ls of
307+
//│ ParamList:
308+
//│ flags = ParamListFlags of false
309+
//│ params = Ls of
310+
//│ Param:
311+
//│ flags = (module)
312+
//│ sym = t@60
313+
//│ sign = S of Sel:
314+
//│ prefix = Ref of globalThis:block#8
315+
//│ nme = Ident of "MT"
316+
//│ sign = N
317+
//│ body = S of Ref of t@60
318+
//│ resSym = ‹result of member:idMod›@61
319+
//│ flags = ()
320+
//│ res = Lit of UnitLit of true
311321

312322
// OK
313323
idMod(IntMT)
314-
//│ FAILURE: Unexpected type error
315-
//│ ╔══[ERROR] Name not found: idMod
316-
//│ ║ l.313: idMod(IntMT)
317-
//│ ╙── ^^^^^
318324
//│ Elaborated tree:
319325
//│ Blk:
320326
//│ stats = Nil
321327
//│ res = App:
322-
//│ lhs = Error
328+
//│ lhs = Sel:
329+
//│ prefix = Ref of globalThis:block#12
330+
//│ nme = Ident of "idMod"
323331
//│ rhs = Tup of Ls of
324332
//│ Fld:
325333
//│ flags = ()
@@ -344,7 +352,7 @@ fun f3(x: M) = x
344352
//│ flags = ParamListFlags of false
345353
//│ params = Ls of
346354
//│ Param:
347-
//│ flags = ()
355+
//│ flags = (module)
348356
//│ sym = x@65
349357
//│ sign = S of Sel:
350358
//│ prefix = Ref of globalThis:block#1

0 commit comments

Comments
 (0)