Skip to content

Commit 1c4936b

Browse files
address comments
1 parent cf96590 commit 1c4936b

File tree

4 files changed

+48
-414
lines changed

4 files changed

+48
-414
lines changed

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -235,28 +235,31 @@ extends Importer:
235235
val lt = term(lhs)
236236
val rt = term(rhs)
237237

238-
// Check module parameters
238+
// Check if module arguments match module parameters
239239
val args = rt match
240240
case Term.Tup(fields) => S(fields)
241241
case _ => N
242-
val argsModFlags = args
243-
.map(_.map(_.flags.mod))
244-
val paramsModFlags = lt.symbol match
245-
case S(sym: BlockMemberSymbol) => sym.defn match
246-
case S(defn: TermDefinition) => defn.params.lift(0)
247-
.map(_.params.map(_.flags.mod))
248-
case _ => argsModFlags.map(_.map(_ => false))
249-
case _ => argsModFlags.map(_.map(_ => false))
242+
val params = lt.symbol
243+
.filter(_.isInstanceOf[BlockMemberSymbol])
244+
.flatMap(_.asInstanceOf[BlockMemberSymbol].trmTree)
245+
.filter(_.isInstanceOf[TermDef])
246+
.flatMap(_.asInstanceOf[TermDef].paramLists.headOption)
250247
for
251-
(argLists, amfs, pmfs) <- (args lazyZip argsModFlags lazyZip paramsModFlags)
252-
(a, amf, pmf) <- (argLists lazyZip amfs lazyZip pmfs)
253-
if amf && !pmf
248+
(args, params) <- (args zip params)
249+
(arg, param) <- (args zip params.fields)
254250
do
255-
log(s"${a.value}")
256-
raise(ErrorReport(
257-
msg"Module values can only be passed to module parameters." -> a.toLoc
258-
:: Nil,
259-
))
251+
val argMod = arg.flags.mod
252+
val paramMod = param match
253+
case Tree.TypeDef(Mod, _, _, _) => true
254+
case _ => false
255+
if argMod && !paramMod then
256+
raise:
257+
ErrorReport:
258+
msg"Module values can only be passed to module parameters." -> arg.toLoc :: Nil
259+
if !argMod && paramMod then
260+
raise:
261+
ErrorReport:
262+
msg"Module parameters can only receive module values." -> arg.toLoc :: Nil
260263

261264
Term.App(lt, rt)(tree, sym)
262265
case Sel(pre, nme) =>
@@ -354,7 +357,7 @@ extends Importer:
354357
case _ =>
355358
val t = term(tree)
356359
t.symbol.flatMap(_.asMod) match
357-
case S(_) => Fld(FldFlags.module, t, N)
360+
case S(_) => Fld(FldFlags.empty.copy(mod = true), t, N)
358361
case N => Fld(FldFlags.empty, t, N)
359362

360363
def unit: Term.Lit = Term.Lit(UnitLit(true))
@@ -652,24 +655,10 @@ extends Importer:
652655
case id: Ident =>
653656
Param(FldFlags.empty, fieldOrVarSym(ParamBind, id), N) :: Nil
654657
case InfixApp(lhs: Ident, Keyword.`:`, rhs) =>
655-
// return S(moduleParam) if t represents a module parameter
656-
// return N otherwise
657-
def moduleParam(t: Term): Opt[Param] = t match
658-
case s: hkmc2.semantics.Term.Sel => s.symbol
659-
.flatMap(_.asMod)
660-
.map(_ => Param(FldFlags.module, fieldOrVarSym(ParamBind, lhs), S(s)))
661-
case hkmc2.semantics.Term.TyApp(s: hkmc2.semantics.Term.Sel, _) => s.symbol
662-
.flatMap(_.asMod)
663-
.map(_ => Param(FldFlags.module, fieldOrVarSym(ParamBind, lhs), S(s)))
664-
case _ => N
665-
666-
val t = term(rhs)
667-
moduleParam(t) match
668-
case S(p) =>
669-
p :: Nil
670-
case N =>
671-
Param(FldFlags.empty, fieldOrVarSym(ParamBind, lhs), S(term(rhs))) :: Nil
658+
Param(FldFlags.empty, fieldOrVarSym(ParamBind, lhs), S(term(rhs))) :: Nil
672659
case App(Ident(","), list) => params(list)._1
660+
case TypeDef(Mod, inner, _, _) => param(inner)
661+
.map(p => p.copy(flags = p.flags.copy(mod = true)))
673662
case TermDef(ImmutVal, inner, _) => param(inner)
674663

675664
def params(t: Tree): Ctxl[(Ls[Param], Ctx)] = t match

hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package semantics
33

44
import mlscript.utils.*, shorthands.*
55
import syntax.*
6+
import scala.collection.mutable.Buffer
67

78

89
final case class QuantVar(sym: VarSymbol, ub: Opt[Term], lb: Opt[Term])
@@ -187,10 +188,10 @@ final case class LetDecl(sym: LocalSymbol) extends Statement
187188
final case class DefineVar(sym: LocalSymbol, rhs: Term) extends Statement
188189

189190
final case class TermDefFlags(mod: Bool):
190-
def showDbg: Str = (
191-
(if mod then S("module") else N) ::
192-
Nil
193-
).flatten.mkString(" ")
191+
def showDbg: Str =
192+
val flags = Buffer.empty[String]
193+
if mod then flags += "module"
194+
flags.mkString(" ")
194195
override def toString: String = "" + showDbg + ""
195196

196197
object TermDefFlags { val empty: TermDefFlags = TermDefFlags(false) }
@@ -284,7 +285,13 @@ case class TypeDef(
284285

285286
// TODO Store optional source locations for the flags instead of booleans
286287
final case class FldFlags(mut: Bool, spec: Bool, genGetter: Bool, mod: Bool):
287-
def showDbg: Str = (if mut then "mut " else "") + (if spec then "spec " else "") + (if genGetter then "val " else "") + (if mod then "mod " else "")
288+
def showDbg: Str =
289+
val flags = Buffer.empty[String]
290+
if mut then flags += "mut"
291+
if spec then flags += "spec"
292+
if genGetter then flags += "gen"
293+
if mod then flags += "module"
294+
flags.mkString(" ")
288295
override def toString: String = "" + showDbg + ""
289296

290297
final case class Fld(flags: FldFlags, value: Term, asc: Opt[Term]) extends FldImpl
@@ -310,10 +317,7 @@ final case class Param(flags: FldFlags, sym: LocalSymbol & NamedSymbol, sign: Op
310317

311318
object FldFlags {
312319
val empty: FldFlags = FldFlags(false, false, false, false)
313-
314-
// module parameter / module argument, depending on the context
315-
val module: FldFlags = FldFlags(false, false, false, true)
316-
}
320+
}
317321

318322
final case class ParamListFlags(ctx: Bool):
319323
def showDbg: Str = (if ctx then "ctx " else "")

hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ extension (t: Product)
4040
case TermDefFlags(mod) =>
4141
val flags = Buffer.empty[String]
4242
if mod then flags += "module"
43-
if flags.isEmpty then "()" else flags.mkString("(", ", ", ")")
43+
flags.mkString("(", ", ", ")")
4444
case FldFlags(mut, spec, genGetter, mod) =>
4545
val flags = Buffer.empty[String]
4646
if mut then flags += "mut"
4747
if spec then flags += "spec"
4848
if genGetter then flags += "gen"
4949
if mod then flags += "module"
50-
if flags.isEmpty then "()" else flags.mkString("(", ", ", ")")
50+
flags.mkString("(", ", ", ")")
5151
case t: Product => t.showAsTree(inTailPos)
5252
case v => v.toString
5353
val postfix = post(t)

0 commit comments

Comments
 (0)