Skip to content

Improve module method checks #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ final case class ClsLikeDefn(

final case class Handler(
sym: BlockMemberSymbol,
resumeSym: LocalSymbol & NamedSymbol,
resumeSym: VarSymbol,
params: Ls[ParamList],
body: Block,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,12 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx):
Assign(freshTmp(), PureCall(
Value.Ref(State.builtinOpsMap("super")), // refers to Predef.__Cont which is pure
Value.Lit(Tree.UnitLit(true)) :: Value.Lit(Tree.UnitLit(true)) :: Nil), End()),
End()))
AssignField(
clsSym.asPath,
pcVar.id,
Value.Ref(pcVar),
End()
)(S(pcSymbol))))

private def genNormalBody(b: Block, clsSym: BlockMemberSymbol)(using HandlerCtx): Block =
val transform = new BlockTransformerShallow(SymbolSubst()):
Expand Down
12 changes: 7 additions & 5 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
// * Note: `_pubFlds` is not used because in JS, fields are not declared
val clsParams = paramsOpt.fold(Nil)(_.paramSyms)
val ctorParams = clsParams.map(p => p -> scope.allocateName(p))
val ctorFields = ctorParams.filter: p =>
p._1.decl match
case S(Param(flags = FldFlags(value = true))) => true
case _ => false
val isModule = kind is syntax.Mod
val mtdPrefix = if isModule then "static " else ""
val privs =
Expand All @@ -196,9 +200,7 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
val nme = scp.allocateName(fld)
doc" # $mtdPrefix#$nme;"
.mkDocument(doc"")
val preCtorCode = ctorParams.foldLeft(body(preCtor, endSemi = true)):
case (acc, (sym, nme)) =>
doc"$acc # this.${sym.name} = $nme;"
val preCtorCode = body(preCtor, endSemi = true)
val ctorCode = doc"$preCtorCode${body(ctor, endSemi = false)}"
val ctorOrStatic = if isModule
then doc"static"
Expand Down Expand Up @@ -239,9 +241,9 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
else doc""" # ${mtdPrefix}toString() { return "${sym.nme}${
if paramsOpt.isEmpty then doc"""""""
else doc"""(" + ${
ctorParams.headOption.fold("\"\"")("globalThis.Predef.render(this." + _._1.name + ")")
ctorFields.headOption.fold("\"\"")("globalThis.Predef.render(this." + _._1.name + ")")
}${
ctorParams.tailOption.fold("")(_.map(
ctorFields.tailOption.fold("")(_.map(
""" + ", " + globalThis.Predef.render(this.""" + _._1.name + ")").mkString)
} + ")""""
}; }"""
Expand Down
87 changes: 72 additions & 15 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ extends Importer:
case _ => trm

def annot(tree: Tree): Ctxl[Opt[Annot]] = tree match
case Keywrd(kw @ (Keyword.`abstract` | Keyword.`declare`)) => S(Annot.Modifier(kw))
case Keywrd(kw @ (Keyword.`abstract` | Keyword.`declare` | Keyword.`data`)) => S(Annot.Modifier(kw))
case _ => term(tree) match
case Term.Error => N
case trm =>
Expand Down Expand Up @@ -375,7 +375,7 @@ extends Importer:
Term.FunTy(term(lhs), term(rhs), N)
case InfixApp(lhs, Keyword.`=>`, rhs) =>
ctx.nest(N).givenIn:
val (syms, nestCtx) = params(lhs)
val (syms, nestCtx) = params(lhs, false)
Term.Lam(syms, term(rhs)(using nestCtx))
case InfixApp(lhs, Keyword.`:`, rhs) =>
Term.Asc(term(lhs), term(rhs))
Expand Down Expand Up @@ -894,7 +894,7 @@ extends Importer:
// * Add parameters to context
var newCtx = newCtx1
val pss = td.paramLists.map: ps =>
val (res, newCtx2) = params(ps)(using newCtx)
val (res, newCtx2) = params(ps, false)(using newCtx)
newCtx = newCtx2
res
// * Elaborate signature
Expand Down Expand Up @@ -972,6 +972,9 @@ extends Importer:
res :: Nil
case N => Nil
newCtx ++= tps.map(tp => tp.sym.name -> tp.sym) // TODO: correct ++?
val isDataClass = annotations.exists:
case Annot.Modifier(Keyword.`data`) => true
case _ => false
val ps =
td.paramLists.match
case Nil => N
Expand All @@ -985,9 +988,52 @@ extends Importer:
.map: ps =>
val (res, newCtx2) =
given Ctx = newCtx
params(ps)
params(ps, isDataClass)
newCtx = newCtx2
res
def withFields(using Ctx)(fn: (Ctx) ?=> (Term.Blk, Ctx)): (Term.Blk, Ctx) =
val fields: Opt[List[TermDefinition | LetDecl | DefineVar]] = ps.map: ps =>
ps.params.flatMap: p =>
// For class-like types, "desugar" the parameters into additional class fields.
val owner = td.symbol match
// Any MemberSymbol should be an InnerSymbol, except for TypeAliasSymbol,
// but type aliases should not call this function.
case s: InnerSymbol => S(s)
case _: TypeAliasSymbol => die

if p.flags.value || isDataClass then
val fsym = BlockMemberSymbol(p.sym.nme, Nil)
val fdef = TermDefinition(
owner,
ImmutVal,
fsym,
Nil, N, N,
S(Term.Ref(p.sym)(p.sym.id, 666)), // FIXME: 666 is a dummy value
FlowSymbol("‹class-param-res›"),
TermDefFlags.empty.copy(isModMember = k is Mod),
Nil
)
sym.defn = S(fdef)
fdef :: Nil
else
val psym = TermSymbol(LetBind, owner, p.sym.id)
val decl = LetDecl(psym, Nil)
val defn = DefineVar(psym, Term.Ref(p.sym)(p.sym.id, 666)) // FIXME: 666 is a dummy value
decl :: defn :: Nil

val ctxWithFields = ctx
.withMembers(
fields.fold(Nil)(_.collect:
case f: TermDefinition => f.sym.nme -> f.sym // class fields
),
ctx.outer
) ++ fields.fold(Nil)(_.collect:
case d: LetDecl => d.sym.nme -> d.sym // class params
)
val ctxWithLets = ctx
val (blk, c) = fn(using ctxWithFields)
val blkWithFields = fields.fold[Term.Blk](blk)(fs => blk.copy(stats = fs ::: blk.stats))
(blkWithFields, c)
val defn = k match
case Als =>
val alsSym = td.symbol.asInstanceOf[TypeAliasSymbol] // TODO improve `asInstanceOf`
Expand All @@ -1010,8 +1056,8 @@ extends Importer:
case S(tree) =>
val (patternParams, extractionParams) = ps match // Filter out pattern parameters.
case S(ParamList(_, params, _)) => params.partition:
case param @ Param(FldFlags(false, false, false, false, true), _, _) => true
case param @ Param(FldFlags(_, _, _, _, false), _, _) => false
case param @ Param(FldFlags(false, false, false, false, true, false), _, _) => true
case param @ Param(FldFlags(_, _, _, _, false, _), _, _) => false
case N => (Nil, Nil)
// TODO: Implement extraction parameters.
if extractionParams.nonEmpty then
Expand Down Expand Up @@ -1039,7 +1085,8 @@ extends Importer:
newCtx.nest(S(clsSym)).givenIn:
log(s"Processing type definition $nme")
val cd =
val (bod, c) = body match
val (bod, c) = withFields:
body match
case S(b: Tree.Block) => block(b, hasResult = false)
// case S(t) => block(t :: Nil)
case S(t) => ???
Expand All @@ -1053,7 +1100,8 @@ extends Importer:
newCtx.nest(S(clsSym)).givenIn:
log(s"Processing type definition $nme")
val cd =
val (bod, c) = body match
val (bod, c) = withFields:
body match
case S(b: Tree.Block) => block(b, hasResult = false)
// case S(t) => block(t :: Nil)
case S(t) => ???
Expand Down Expand Up @@ -1092,13 +1140,14 @@ extends Importer:
N
case N => N

def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): LocalSymbol & NamedSymbol =
def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): TermSymbol | VarSymbol =
if ctx.outer.isDefined then TermSymbol(k, ctx.outer, id)
else VarSymbol(id)

def param(t: Tree, inUsing: Bool): Ctxl[Opt[Opt[Bool] -> Param]] = t match
def param(t: Tree, inUsing: Bool, inDataClass: Bool): Ctxl[Opt[Opt[Bool] -> Param]] =
def go(t: Tree, inUsing: Bool, flags: FldFlags): Ctxl[Opt[Opt[Bool] -> Param]] = t match
case TypeDef(Mod, inner, N, N) =>
val ps = param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(mod = true))))
val ps = go(inner, inUsing, flags.copy(mod = true))
for p <- ps if p._2.flags.mod do p._2.sign match
case N =>
raise(ErrorReport(msg"Module parameters must have explicit types." -> t.toLoc :: Nil))
Expand All @@ -1107,18 +1156,26 @@ extends Importer:
case _ => ()
ps
case TypeDef(Pat, inner, N, N) =>
param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(pat = true))))
go(inner, inUsing, flags.copy(pat = true))
case TermDef(ImmutVal, inner, _) =>
go(inner, inUsing, flags.copy(value = true))
case _ =>
t.asParam(inUsing).map: (isSpd, p, t) =>
isSpd -> Param(FldFlags.empty, fieldOrVarSym(ParamBind, p), t.map(term(_)))
val sym = VarSymbol(p)
val sign = t.map(term(_))
val param = Param(flags, sym, sign)
sym.decl = S(param)
isSpd -> param
go(t, inUsing, if inDataClass then FldFlags.empty.copy(value = true) else FldFlags.empty)


def params(t: Tree): Ctxl[(ParamList, Ctx)] = t match
def params(t: Tree, inDataClass: Bool): Ctxl[(ParamList, Ctx)] = t match
case Tup(ps) =>
def go(ps: Ls[Tree], acc: Ls[Param], ctx: Ctx, flags: ParamListFlags): (ParamList, Ctx) =
ps match
case Nil => (ParamList(flags, acc.reverse, N), ctx)
case hd :: tl =>
param(hd, flags.ctx)(using ctx) match
param(hd, flags.ctx, inDataClass)(using ctx) match
case S((isSpd, p)) =>
val isCtx = hd match
case Modified(Keyword.`using`, _, _) => true
Expand Down
52 changes: 39 additions & 13 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/ImplicitResolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import syntax.{Fun, Ins, Mod}
import semantics.Term
import semantics.Elaborator.State
import ImplicitResolver.ICtx.Type
import ImplicitResolver.TyParamSymbol

import Message.MessageContext

Expand All @@ -25,23 +24,21 @@ class CtxArgImpl extends CtxArg:

object ImplicitResolver:

type TyParamSymbol = LocalSymbol & NamedSymbol

/*
* An "implicit" or "instance" context, as opposed to the one in Elaborator.
*/
case class ICtx(
parent: Opt[ICtx],
iEnv: Map[Type.Sym, Ls[(Type, ICtx.Instance)]],
tEnv: Map[TyParamSymbol, Type]
tEnv: Map[VarSymbol, Type]
):

def +(typ: Type.Concrete, sym: Symbol): ICtx =
val newLs = (typ -> ICtx.Instance(sym)) :: iEnv.getOrElse(typ.toSym, Nil)
val newEnv = iEnv + (typ.toSym -> newLs)
copy(iEnv = newEnv)

def withTypeArg(param: TyParamSymbol, arg: Type): ICtx =
def withTypeArg(param: VarSymbol, arg: Type): ICtx =
copy(tEnv = tEnv + (param -> arg))

def get(query: Type.Concrete): Opt[ICtx.Instance] =
Expand Down Expand Up @@ -131,7 +128,7 @@ class ImplicitResolver(tl: TraceLogger)
msg"got ${targs.length.toString()}" -> base.toLoc :: Nil))
(tparams zip targs).foldLeft(ictx):
case (ictx, (tparam, targ)) => (tparam.sym, resolveType(targ)) match
case (sym: TyParamSymbol, S(typ)) =>
case (sym: VarSymbol, S(typ)) =>
log(s"Resolving App with type arg ${sym} = $typ")
ictx.withTypeArg(sym, typ)
case _ => ictx
Expand Down Expand Up @@ -282,21 +279,50 @@ object ModuleChecker:
.exists(_.isInstanceOf[TyParam])

/** Checks if a term evaluates to a module value. */
def evalsToModule(t: Term): Bool =
def isModule(t: Tree): Bool = t match
case Tree.TypeDef(Mod, _, _, _) => true
case _ => false
def evalsToModule(t: Term): Bool =
def returnsModule(t: Tree.TermDef): Bool = t.annotatedResultType match
case S(Tree.TypeDef(Mod, _, N, N)) => true
case _ => false
def checkDecl(decl: Declaration): Bool = decl match
// All TypeLikeDef are not modules, except for modules themselves.
// Objects use ModuleDef but is not a module.
case ModuleDef(kind = Mod) =>
true
case _: TypeLikeDef =>
false
// Check Member/Local symbols
case defn: TermDefinition =>
defn.flags.isModTyped
case defn: Param =>
defn.flags.mod
case defn: TyParam =>
defn.flags.mod
def checkSym(sym: Symbol): Bool = sym match
case sym if sym.asMod.nonEmpty => true
case sym if sym.asBlkMember.flatMap(_.trmTree).exists(returnsModule) => true
case _: (BuiltinSymbol | TopLevelSymbol) => false
case sym: BlockLocalSymbol => sym.decl match
case S(decl) => checkDecl(decl)
case N =>
// Most local symbols are let-bindings
// which do not have a definition at this point.
false
case sym: MemberSymbol[?] => sym.defn match
case S(defn) => checkDecl(defn)
case N =>
// At this point all member symbols should have definition,
// except for the class(-like) that are currently being elaborated.
// TODO: We will fix this by deferring the checks to the resolution stage.
false
case sym =>
lastWords(s"Unsupported symbol kind ${sym}")
t match
case Term.Blk(_, res) => evalsToModule(res)
case Term.App(lhs, rhs) => lhs.symbol match
case S(sym: BlockMemberSymbol) => sym.trmTree.exists(returnsModule)
case _ => false
case t => t.symbol match
case S(sym: BlockMemberSymbol) => sym.modTree.exists(isModule)
case _ => false
case t: Term.Ref => checkSym(t.sym)
case t => t.symbol.exists(checkSym)

/**
* An extractor that extracts the (tree) definition of a module method.
Expand Down
1 change: 1 addition & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class InstSymbol(val origin: Symbol)(using State) extends LocalSymbol:

class VarSymbol(val id: Ident)(using State) extends BlockLocalSymbol(id.name) with NamedSymbol with LocalSymbol:
val name: Str = id.name
override def toLoc: Opt[Loc] = id.toLoc
// override def toString: Str = s"$name@$uid"
override def subst(using s: SymbolSubst): VarSymbol = s.mapVarSym(this)

Expand Down
Loading