Skip to content

Improve module method checks #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ final case class ClsLikeDefn(

final case class Handler(
sym: BlockMemberSymbol,
resumeSym: LocalSymbol & NamedSymbol,
resumeSym: VarSymbol,
params: Ls[ParamList],
body: Block,
):
Expand Down
13 changes: 11 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,11 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
clsSym,
BlockMemberSymbol(clsSym.nme, Nil),
syntax.Cls,
S(PlainParamList(Param(FldFlags.empty, pcVar, N) :: Nil)),
S(PlainParamList({
val p = Param(FldFlags.empty.copy(value = true), pcVar, N)
pcVar.decl = S(p)
p
} :: Nil)),
Nil,
S(paths.contClsPath),
resumeFnDef :: Nil,
Expand All @@ -553,7 +557,12 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
Assign(freshTmp(), PureCall(
Value.Ref(State.builtinOpsMap("super")), // refers to runtime.FunctionContFrame which is pure
Value.Lit(Tree.UnitLit(true)) :: Nil), End()),
End()))
AssignField(
clsSym.asPath,
pcVar.id,
Value.Ref(pcVar),
End()
)(S(pcSymbol))))

private def genNormalBody(b: Block, clsSym: BlockMemberSymbol)(using HandlerCtx): Block =
val transform = new BlockTransformerShallow(SymbolSubst()):
Expand Down
74 changes: 56 additions & 18 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
val ignoredDefns: Set[BlockMemberSymbol] = Set.empty,
val inScopeDefns: Map[BlockMemberSymbol, Set[BlockMemberSymbol]] = Map.empty,
val modLocals: Map[BlockMemberSymbol, Local] = Map.empty,
val localCaptureSyms: Map[Local, LocalSymbol & NamedSymbol] = Map.empty,
val localCaptureSyms: Map[Local, (VarSymbol, TermDefinition)] = Map.empty,
val prevFnLocals: FreeVars = FreeVars.empty,
val prevClsDefns: List[ClsLikeDefn] = Nil,
val curModules: List[ClsLikeDefn] = Nil,
Expand Down Expand Up @@ -186,7 +186,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
def withInScopes(mp: Map[BlockMemberSymbol, Set[BlockMemberSymbol]]) = copy(inScopeDefns = mp)
def addFnLocals(f: FreeVars) = copy(prevFnLocals = prevFnLocals ++ f)
def addClsDefn(c: ClsLikeDefn) = copy(prevClsDefns = c :: prevClsDefns)
def addLocalCaptureSyms(m: Map[Local, LocalSymbol & NamedSymbol]) = copy(localCaptureSyms = localCaptureSyms ++ m)
def addLocalCaptureSyms(m: Map[Local, (VarSymbol, TermDefinition)]) = copy(localCaptureSyms = localCaptureSyms ++ m)
def getBmsReqdInfo(sym: BlockMemberSymbol) = bmsReqdInfo.get(sym)
def replCapturePaths(paths: Map[BlockMemberSymbol, Path]) = copy(capturePaths = paths)
def addCapturePath(src: BlockMemberSymbol, path: Path) = copy(capturePaths = capturePaths + (src -> path))
Expand Down Expand Up @@ -233,18 +233,51 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):

val fresh = FreshInt()

val varsMap: Map[Local, TermSymbol] = cap.map: s =>
val varsMap = cap.map: s =>
val id = fresh.make
s -> TermSymbol(syntax.ParamBind, S(clsSym), Tree.Ident(s.nme + id + "$"))
val nme = s.nme + id + "$"
val varSym = VarSymbol(Tree.Ident(nme))
val fldSym = BlockMemberSymbol(nme, Nil)
val fldDef = TermDefinition(
S(clsSym),
syntax.ImmutVal,
fldSym,
Nil, N, N,
S(Term.Ref(s)(Tree.Ident(s.nme), 666)), // FIXME: 666 is a dummy value
FlowSymbol("‹class-param-res›"),
TermDefFlags.empty,
Nil
)
fldSym.defn = S(fldDef)
s -> (
varSym,
fldDef,
)
.toMap

val varsList = cap.toList

val defn = ClsLikeDefn(
None, clsSym, BlockMemberSymbol(nme, Nil),
syntax.Cls,
S(PlainParamList(varsList.map(s => Param(FldFlags.empty, varsMap(s), None)))),
Nil, None, Nil, Nil, Nil, End(), End()
S(PlainParamList(varsList.map: s =>
val sym = varsMap(s)._1
val p = Param(FldFlags.empty.copy(value = true), sym, None)
sym.decl = S(p)
p
)),
Nil, None, Nil, Nil,
varsList.map(varsMap(_)._2),
varsList.map(varsMap(_)).foldLeft[Block](End()):
case (acc, (varSym, fldDef)) =>
AssignField(
clsSym.asPath,
Tree.Ident(fldDef.sym.nme),
Value.Ref(varSym),
acc
)(S(fldDef.sym))
,
End()
)

(defn, varsMap, varsList)
Expand Down Expand Up @@ -597,7 +630,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
case _ => super.applyBlock(rewritten)

case Assign(lhs, rhs, rest) => ctx.getLocalCaptureSym(lhs) match
case Some(captureSym) =>
case Some((captureSym, _)) =>
AssignField(ctx.getLocalClosPath(lhs).get, captureSym.id, applyResult(rhs), applyBlock(rest))(N)
case None => ctx.getLocalPath(lhs) match
case None => super.applyBlock(rewritten)
Expand Down Expand Up @@ -655,7 +688,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
// This rewrites naked references to locals. If a function is in a capture, then we select that value
// from the capture; otherwise, we see if that local is passed directly as a parameter to this defn.
case Value.Ref(l) => ctx.getLocalCaptureSym(l) match
case Some(captureSym) =>
case Some((captureSym, _)) =>
Select(ctx.getLocalClosPath(l).get, captureSym.id)(N)
case None => ctx.getLocalPath(l) match
case Some(value) => Value.Ref(value)
Expand Down Expand Up @@ -690,8 +723,13 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
val fresh = FreshInt()
(nme: String) =>
val id = fresh.make
TermSymbol(syntax.ParamBind, S(d.isym), Tree.Ident(nme + "$" + id))
case _ => ((nme: String) => VarSymbol(Tree.Ident(nme)))
(
VarSymbol(Tree.Ident(nme + "$" + id)),
TermSymbol(syntax.ParamBind, S(d.isym), Tree.Ident(nme + "$" + id))
)
case _ => (nme: String) =>
val vsym = VarSymbol(Tree.Ident(nme))
(vsym, vsym)

val capturesSymbols = includedCaptures.map: sym =>
(sym, createSym(sym.nme + "$capture"))
Expand All @@ -706,27 +744,27 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
(sym, createSym(sym.nme + "$member"))

val extraParamsCaptures = capturesSymbols.map: // parameter list
case (d, sym) => Param(FldFlags.empty, sym, None)
case (d, (sym, _)) => Param(FldFlags.empty, sym, None)
val newCapturePaths = capturesSymbols.map: // mapping from sym to param symbol
case (d, sym) => d -> sym.asPath
case (d, (_, sym)) => d -> sym.asPath
.toMap

val extraParamsLocals = localsSymbols.map: // parameter list
case (d, sym) => Param(FldFlags.empty, sym, None)
case (d, (sym, _)) => Param(FldFlags.empty, sym, None)
val newLocalsPaths = localsSymbols.map: // mapping from sym to param symbol
case (d, sym) => d -> sym
case (d, (_, sym)) => d -> sym
.toMap

val extraParamsIsyms = isymSymbols.map: // parameter list
case (d, sym) => Param(FldFlags.empty, sym, None)
case (d, (sym, _)) => Param(FldFlags.empty, sym, None)
val newIsymPaths = isymSymbols.map: // mapping from sym to param symbol
case (d, sym) => d -> sym
case (d, (_, sym)) => d -> sym
.toMap

val extraParamsBms = bmsSymbols.map: // parameter list
case (d, sym) => Param(FldFlags.empty, sym, None)
case (d, (sym, _)) => Param(FldFlags.empty, sym, None)
val newBmsPaths = bmsSymbols.map: // mapping from sym to param symbol
case (d, sym) => d -> sym.asPath
case (d, (_, sym)) => d -> sym.asPath
.toMap

val extraParams = extraParamsBms ++ extraParamsIsyms ++ extraParamsLocals ++ extraParamsCaptures
Expand Down
10 changes: 7 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
// * Note: `_pubFlds` is not used because in JS, fields are not declared
val clsParams = paramsOpt.fold(Nil)(_.paramSyms)
val ctorParams = clsParams.map(p => p -> scope.allocateName(p))
val ctorFields = ctorParams.filter: p =>
p._1.decl match
case S(Param(flags = FldFlags(value = true))) => true
case _ => false
val ctorAuxParams = auxParams.map(ps => ps.params.map(p => p.sym -> scope.allocateName(p.sym)))

val isModule = kind is syntax.Mod
Expand All @@ -212,7 +216,7 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
val nme = scp.allocateName(fld)
doc" # $mtdPrefix#$nme;"
.mkDocument(doc"")
val preCtorCode = (ctorParams ++ ctorAuxParams.flatMap(ps => ps)).foldLeft(body(preCtor, endSemi = true)):
val preCtorCode = ctorAuxParams.flatMap(ps => ps).foldLeft(body(preCtor, endSemi = true)):
case (acc, (sym, nme)) =>
doc"$acc # this.${sym.name} = $nme;"
val ctorCode = doc"$preCtorCode${body(ctor, endSemi = false)}"
Expand Down Expand Up @@ -267,9 +271,9 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
else doc""" # ${mtdPrefix}toString() { return "${sym.nme}${
if paramsOpt.isEmpty then doc"""""""
else doc"""(" + ${
ctorParams.headOption.fold("\"\"")("globalThis.Predef.render(this." + _._1.name + ")")
ctorFields.headOption.fold("\"\"")("globalThis.Predef.render(this." + _._1.name + ")")
}${
ctorParams.tailOption.fold("")(_.map(
ctorFields.tailOption.fold("")(_.map(
""" + ", " + globalThis.Predef.render(this.""" + _._1.name + ")").mkString)
} + ")""""
}; }"""
Expand Down
87 changes: 72 additions & 15 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ extends Importer:
case _ => trm

def annot(tree: Tree): Ctxl[Opt[Annot]] = tree match
case Keywrd(kw @ (Keyword.`abstract` | Keyword.`declare`)) => S(Annot.Modifier(kw))
case Keywrd(kw @ (Keyword.`abstract` | Keyword.`declare` | Keyword.`data`)) => S(Annot.Modifier(kw))
case _ => term(tree) match
case Term.Error => N
case trm =>
Expand Down Expand Up @@ -396,7 +396,7 @@ extends Importer:
Term.FunTy(term(lhs), term(rhs), N)
case InfixApp(lhs, Keyword.`=>`, rhs) =>
ctx.nest(N).givenIn:
val (syms, nestCtx) = params(lhs)
val (syms, nestCtx) = params(lhs, false)
Term.Lam(syms, term(rhs)(using nestCtx))
case InfixApp(lhs, Keyword.`as`, rhs) =>
Term.Asc(term(lhs), term(rhs))
Expand Down Expand Up @@ -956,7 +956,7 @@ extends Importer:
// * Add parameters to context
var newCtx = newCtx1
val pss = td.paramLists.map: ps =>
val (res, newCtx2) = params(ps)(using newCtx)
val (res, newCtx2) = params(ps, false)(using newCtx)
newCtx = newCtx2
res
// * Elaborate signature
Expand Down Expand Up @@ -1034,6 +1034,9 @@ extends Importer:
res :: Nil
case N => Nil
newCtx ++= tps.map(tp => tp.sym.name -> tp.sym) // TODO: correct ++?
val isDataClass = annotations.exists:
case Annot.Modifier(Keyword.`data`) => true
case _ => false
val ps =
td.paramLists.match
case Nil => N
Expand All @@ -1047,9 +1050,52 @@ extends Importer:
.map: ps =>
val (res, newCtx2) =
given Ctx = newCtx
params(ps)
params(ps, isDataClass)
newCtx = newCtx2
res
def withFields(using Ctx)(fn: (Ctx) ?=> (Term.Blk, Ctx)): (Term.Blk, Ctx) =
val fields: Opt[List[TermDefinition | LetDecl | DefineVar]] = ps.map: ps =>
ps.params.flatMap: p =>
// For class-like types, "desugar" the parameters into additional class fields.
val owner = td.symbol match
// Any MemberSymbol should be an InnerSymbol, except for TypeAliasSymbol,
// but type aliases should not call this function.
case s: InnerSymbol => S(s)
case _: TypeAliasSymbol => die

if p.flags.value || isDataClass then
val fsym = BlockMemberSymbol(p.sym.nme, Nil)
val fdef = TermDefinition(
owner,
ImmutVal,
fsym,
Nil, N, N,
S(Term.Ref(p.sym)(p.sym.id, 666)), // FIXME: 666 is a dummy value
FlowSymbol("‹class-param-res›"),
TermDefFlags.empty.copy(isModMember = k is Mod),
Nil
)
sym.defn = S(fdef)
fdef :: Nil
else
val psym = TermSymbol(LetBind, owner, p.sym.id)
val decl = LetDecl(psym, Nil)
val defn = DefineVar(psym, Term.Ref(p.sym)(p.sym.id, 666)) // FIXME: 666 is a dummy value
decl :: defn :: Nil

val ctxWithFields = ctx
.withMembers(
fields.fold(Nil)(_.collect:
case f: TermDefinition => f.sym.nme -> f.sym // class fields
),
ctx.outer
) ++ fields.fold(Nil)(_.collect:
case d: LetDecl => d.sym.nme -> d.sym // class params
)
val ctxWithLets = ctx
val (blk, c) = fn(using ctxWithFields)
val blkWithFields = fields.fold[Term.Blk](blk)(fs => blk.copy(stats = fs ::: blk.stats))
(blkWithFields, c)
val defn = k match
case Als =>
val alsSym = td.symbol.asInstanceOf[TypeAliasSymbol] // TODO improve `asInstanceOf`
Expand All @@ -1072,8 +1118,8 @@ extends Importer:
case S(tree) =>
val (patternParams, extractionParams) = ps match // Filter out pattern parameters.
case S(ParamList(_, params, _)) => params.partition:
case param @ Param(FldFlags(false, false, false, false, true), _, _) => true
case param @ Param(FldFlags(_, _, _, _, false), _, _) => false
case param @ Param(FldFlags(false, false, false, false, true, false), _, _) => true
case param @ Param(FldFlags(_, _, _, _, false, _), _, _) => false
case N => (Nil, Nil)
// TODO: Implement extraction parameters.
if extractionParams.nonEmpty then
Expand Down Expand Up @@ -1101,7 +1147,8 @@ extends Importer:
newCtx.nest(S(clsSym)).givenIn:
log(s"Processing type definition $nme")
val cd =
val (bod, c) = body match
val (bod, c) = withFields:
body match
case S(b: Tree.Block) => block(b, hasResult = false)
// case S(t) => block(t :: Nil)
case S(t) => ???
Expand All @@ -1115,7 +1162,8 @@ extends Importer:
newCtx.nest(S(clsSym)).givenIn:
log(s"Processing type definition $nme")
val cd =
val (bod, c) = body match
val (bod, c) = withFields:
body match
case S(b: Tree.Block) => block(b, hasResult = false)
// case S(t) => block(t :: Nil)
case S(t) => ???
Expand Down Expand Up @@ -1166,13 +1214,14 @@ extends Importer:
N
case N => N

def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): LocalSymbol & NamedSymbol =
def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): TermSymbol | VarSymbol =
if ctx.outer.isDefined then TermSymbol(k, ctx.outer, id)
else VarSymbol(id)

def param(t: Tree, inUsing: Bool): Ctxl[Opt[Opt[Bool] -> Param]] = t match
def param(t: Tree, inUsing: Bool, inDataClass: Bool): Ctxl[Opt[Opt[Bool] -> Param]] =
def go(t: Tree, inUsing: Bool, flags: FldFlags): Ctxl[Opt[Opt[Bool] -> Param]] = t match
case TypeDef(Mod, inner, N, N) =>
val ps = param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(mod = true))))
val ps = go(inner, inUsing, flags.copy(mod = true))
for p <- ps if p._2.flags.mod do p._2.sign match
case N =>
raise(ErrorReport(msg"Module parameters must have explicit types." -> t.toLoc :: Nil))
Expand All @@ -1181,18 +1230,26 @@ extends Importer:
case _ => ()
ps
case TypeDef(Pat, inner, N, N) =>
param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(pat = true))))
go(inner, inUsing, flags.copy(pat = true))
case TermDef(ImmutVal, inner, _) =>
go(inner, inUsing, flags.copy(value = true))
case _ =>
t.asParam(inUsing).map: (isSpd, p, t) =>
isSpd -> Param(FldFlags.empty, fieldOrVarSym(ParamBind, p), t.map(term(_)))
val sym = VarSymbol(p)
val sign = t.map(term(_))
val param = Param(flags, sym, sign)
sym.decl = S(param)
isSpd -> param
go(t, inUsing, if inDataClass then FldFlags.empty.copy(value = true) else FldFlags.empty)


def params(t: Tree): Ctxl[(ParamList, Ctx)] = t match
def params(t: Tree, inDataClass: Bool): Ctxl[(ParamList, Ctx)] = t match
case Tup(ps) =>
def go(ps: Ls[Tree], acc: Ls[Param], ctx: Ctx, flags: ParamListFlags): (ParamList, Ctx) =
ps match
case Nil => (ParamList(flags, acc.reverse, N), ctx)
case hd :: tl =>
param(hd, flags.ctx)(using ctx) match
param(hd, flags.ctx, inDataClass)(using ctx) match
case S((isSpd, p)) =>
val isCtx = hd match
case Modified(Keyword.`using`, _, _) => true
Expand Down
Loading
Loading