Skip to content

Improve module method checks #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ final case class ClsLikeDefn(

final case class Handler(
sym: BlockMemberSymbol,
resumeSym: LocalSymbol & NamedSymbol,
resumeSym: VarSymbol,
params: Ls[ParamList],
body: Block,
):
Expand Down
13 changes: 11 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,11 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
clsSym,
BlockMemberSymbol(clsSym.nme, Nil),
syntax.Cls,
S(PlainParamList(Param(FldFlags.empty, pcVar, N) :: Nil)),
S(PlainParamList({
val p = Param(FldFlags.empty.copy(value = true), pcVar, N)
pcVar.decl = S(p)
p
} :: Nil)),
Nil,
S(paths.contClsPath),
resumeFnDef :: Nil,
Expand All @@ -527,7 +531,12 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
Assign(freshTmp(), PureCall(
Value.Ref(State.builtinOpsMap("super")), // refers to runtime.FunctionContFrame which is pure
Value.Lit(Tree.UnitLit(true)) :: Nil), End()),
End()))
AssignField(
clsSym.asPath,
pcVar.id,
Value.Ref(pcVar),
End()
)(S(pcSymbol))))

private def genNormalBody(b: Block, clsSym: BlockMemberSymbol)(using HandlerCtx): Block =
val transform = new BlockTransformerShallow(SymbolSubst()):
Expand Down
82 changes: 61 additions & 21 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
* @param ignoredDefns The definitions which must not be lifted.
* @param inScopeDefns Definitions which are in scope to another definition (excluding itself and its nested definitions).
* @param modLocals A map from the modules and objects to the local to which it is instantiated after lifting.
* @param localCaptureSyms The symbols in a capture corresponding to a particular local
* @param localCaptureSyms The symbols in a capture corresponding to a particular local.
* The `VarSymbol` is the parameter in the capture class, and the `BlockMemberSymbol` is the field in the class.
* @param prevFnLocals Locals belonging to function definitions that have already been traversed
* @param prevClsDefns Class definitions that have already been traversed, excluding modules
* @param curModules Modules that that we are currently nested in (cleared if we are lifted out)
Expand All @@ -152,7 +153,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
val ignoredDefns: Set[BlockMemberSymbol] = Set.empty,
val inScopeDefns: Map[BlockMemberSymbol, Set[BlockMemberSymbol]] = Map.empty,
val modLocals: Map[BlockMemberSymbol, Local] = Map.empty,
val localCaptureSyms: Map[Local, LocalSymbol & NamedSymbol] = Map.empty,
val localCaptureSyms: Map[Local, (VarSymbol, BlockMemberSymbol)] = Map.empty,
val prevFnLocals: FreeVars = FreeVars.empty,
val prevClsDefns: List[ClsLikeDefn] = Nil,
val curModules: List[ClsLikeDefn] = Nil,
Expand Down Expand Up @@ -186,7 +187,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
def withInScopes(mp: Map[BlockMemberSymbol, Set[BlockMemberSymbol]]) = copy(inScopeDefns = mp)
def addFnLocals(f: FreeVars) = copy(prevFnLocals = prevFnLocals ++ f)
def addClsDefn(c: ClsLikeDefn) = copy(prevClsDefns = c :: prevClsDefns)
def addLocalCaptureSyms(m: Map[Local, LocalSymbol & NamedSymbol]) = copy(localCaptureSyms = localCaptureSyms ++ m)
def addLocalCaptureSyms(m: Map[Local, (VarSymbol, BlockMemberSymbol)]) = copy(localCaptureSyms = localCaptureSyms ++ m)
def getBmsReqdInfo(sym: BlockMemberSymbol) = bmsReqdInfo.get(sym)
def replCapturePaths(paths: Map[BlockMemberSymbol, Path]) = copy(capturePaths = paths)
def addCapturePath(src: BlockMemberSymbol, path: Path) = copy(capturePaths = capturePaths + (src -> path))
Expand Down Expand Up @@ -218,7 +219,8 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
* @param f The function to create the capture class for.
* @param ctx The lifter context. Determines which variables will be captured.
* @return The triple (defn, varsMap, varsList), where `defn` is the capture class's definition,
* `varsMap` maps the function's locals to the correpsonding `VarSymbol` in the class, and
* `varsMap` maps the function's locals to the correpsonding `VarSymbol` (for the class parameters)
* and `BlockLocalSymbol` (for the class fields) in the class, and
* `varsList` specifies the order of these variables in the class's constructor.
*/
def createCaptureCls(f: FunDefn, ctx: LifterCtx) =
Expand All @@ -233,21 +235,54 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):

val fresh = FreshInt()

val varsMap: Map[Local, TermSymbol] = cap.map: s =>
val varsMap = cap.map: s =>
val id = fresh.make
s -> TermSymbol(syntax.ParamBind, S(clsSym), Tree.Ident(s.nme + id + "$"))
val nme = s.nme + id + "$"
val varSym = VarSymbol(Tree.Ident(nme))
val fldSym = BlockMemberSymbol(nme, Nil)
val fldDef = TermDefinition(
S(clsSym),
syntax.ImmutVal,
fldSym,
Nil, N, N,
S(Term.Ref(s)(Tree.Ident(s.nme), 666)), // FIXME: 666 is a dummy value
FlowSymbol("‹class-param-res›"),
TermDefFlags.empty,
Nil
)
fldSym.defn = S(fldDef)
s -> (
varSym,
fldDef,
)
.toMap

val varsList = cap.toList

val defn = ClsLikeDefn(
None, clsSym, BlockMemberSymbol(nme, Nil),
syntax.Cls,
S(PlainParamList(varsList.map(s => Param(FldFlags.empty, varsMap(s), None)))),
Nil, None, Nil, Nil, Nil, End(), End()
S(PlainParamList(varsList.map: s =>
val sym = varsMap(s)._1
val p = Param(FldFlags.empty.copy(value = true), sym, None)
sym.decl = S(p)
p
)),
Nil, None, Nil, Nil,
varsList.map(varsMap(_)._2),
varsList.map(varsMap(_)).foldLeft[Block](End()):
case (acc, (varSym, fldDef)) =>
AssignField(
clsSym.asPath,
Tree.Ident(fldDef.sym.nme),
Value.Ref(varSym),
acc
)(S(fldDef.sym))
,
End()
)

(defn, varsMap, varsList)
(defn, varsMap.view.mapValues(_.mapSecond(_.sym)).toMap, varsList)

private val innerSymCache: MutMap[Local, Set[Local]] = MutMap.empty

Expand Down Expand Up @@ -597,7 +632,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
case _ => super.applyBlock(rewritten)

case Assign(lhs, rhs, rest) => ctx.getLocalCaptureSym(lhs) match
case Some(captureSym) =>
case Some((captureSym, _)) =>
AssignField(ctx.getLocalClosPath(lhs).get, captureSym.id, applyResult(rhs), applyBlock(rest))(N)
case None => ctx.getLocalPath(lhs) match
case None => super.applyBlock(rewritten)
Expand Down Expand Up @@ -655,7 +690,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
// This rewrites naked references to locals. If a function is in a capture, then we select that value
// from the capture; otherwise, we see if that local is passed directly as a parameter to this defn.
case Value.Ref(l) => ctx.getLocalCaptureSym(l) match
case Some(captureSym) =>
case Some((captureSym, _)) =>
Select(ctx.getLocalClosPath(l).get, captureSym.id)(N)
case None => ctx.getLocalPath(l) match
case Some(value) => Value.Ref(value)
Expand Down Expand Up @@ -690,8 +725,13 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
val fresh = FreshInt()
(nme: String) =>
val id = fresh.make
TermSymbol(syntax.ParamBind, S(d.isym), Tree.Ident(nme + "$" + id))
case _ => ((nme: String) => VarSymbol(Tree.Ident(nme)))
(
VarSymbol(Tree.Ident(nme + "$" + id)),
TermSymbol(syntax.ParamBind, S(d.isym), Tree.Ident(nme + "$" + id))
)
case _ => (nme: String) =>
val vsym = VarSymbol(Tree.Ident(nme))
(vsym, vsym)

val capturesSymbols = includedCaptures.map: sym =>
(sym, createSym(sym.nme + "$capture"))
Expand All @@ -706,27 +746,27 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
(sym, createSym(sym.nme + "$member"))

val extraParamsCaptures = capturesSymbols.map: // parameter list
case (d, sym) => Param(FldFlags.empty, sym, None)
case (d, (sym, _)) => Param(FldFlags.empty, sym, None)
val newCapturePaths = capturesSymbols.map: // mapping from sym to param symbol
case (d, sym) => d -> sym.asPath
case (d, (_, sym)) => d -> sym.asPath
.toMap

val extraParamsLocals = localsSymbols.map: // parameter list
case (d, sym) => Param(FldFlags.empty, sym, None)
case (d, (sym, _)) => Param(FldFlags.empty, sym, None)
val newLocalsPaths = localsSymbols.map: // mapping from sym to param symbol
case (d, sym) => d -> sym
case (d, (_, sym)) => d -> sym
.toMap

val extraParamsIsyms = isymSymbols.map: // parameter list
case (d, sym) => Param(FldFlags.empty, sym, None)
case (d, (sym, _)) => Param(FldFlags.empty, sym, None)
val newIsymPaths = isymSymbols.map: // mapping from sym to param symbol
case (d, sym) => d -> sym
case (d, (_, sym)) => d -> sym
.toMap

val extraParamsBms = bmsSymbols.map: // parameter list
case (d, sym) => Param(FldFlags.empty, sym, None)
case (d, (sym, _)) => Param(FldFlags.empty, sym, None)
val newBmsPaths = bmsSymbols.map: // mapping from sym to param symbol
case (d, sym) => d -> sym.asPath
case (d, (_, sym)) => d -> sym.asPath
.toMap

val extraParams = extraParamsBms ++ extraParamsIsyms ++ extraParamsLocals ++ extraParamsCaptures
Expand Down
10 changes: 7 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,10 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
// * Note: `_pubFlds` is not used because in JS, fields are not declared
val clsParams = paramsOpt.fold(Nil)(_.paramSyms)
val ctorParams = clsParams.map(p => p -> scope.allocateName(p))
val ctorFields = ctorParams.filter: p =>
p._1.decl match
case S(Param(flags = FldFlags(value = true))) => true
case _ => false
val ctorAuxParams = auxParams.map(ps => ps.params.map(p => p.sym -> scope.allocateName(p.sym)))

val isModule = kind is syntax.Mod
Expand All @@ -212,7 +216,7 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
val nme = scp.allocateName(fld)
doc" # $mtdPrefix#$nme;"
.mkDocument(doc"")
val preCtorCode = (ctorParams ++ ctorAuxParams.flatMap(ps => ps)).foldLeft(body(preCtor, endSemi = true)):
val preCtorCode = ctorAuxParams.flatMap(ps => ps).foldLeft(body(preCtor, endSemi = true)):
case (acc, (sym, nme)) =>
doc"$acc # this.${sym.name} = $nme;"
val ctorCode = doc"$preCtorCode${body(ctor, endSemi = false)}"
Expand Down Expand Up @@ -267,9 +271,9 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
else doc""" # ${mtdPrefix}toString() { return "${sym.nme}${
if paramsOpt.isEmpty then doc"""""""
else doc"""(" + ${
ctorParams.headOption.fold("\"\"")("globalThis.Predef.render(this." + _._1.name + ")")
ctorFields.headOption.fold("\"\"")("globalThis.Predef.render(this." + _._1.name + ")")
}${
ctorParams.tailOption.fold("")(_.map(
ctorFields.tailOption.fold("")(_.map(
""" + ", " + globalThis.Predef.render(this.""" + _._1.name + ")").mkString)
} + ")""""
}; }"""
Expand Down
Loading
Loading