Skip to content

Improve module method checks #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Mar 10, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ final case class ClsLikeDefn(

final case class Handler(
sym: BlockMemberSymbol,
resumeSym: LocalSymbol & NamedSymbol,
resumeSym: VarSymbol,
params: Ls[ParamList],
body: Block,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,12 @@ class HandlerLowering(using TL, Raise, Elaborator.State, Elaborator.Ctx):
Assign(freshTmp(), PureCall(
Value.Ref(State.builtinOpsMap("super")), // refers to Predef.__Cont which is pure
Value.Lit(Tree.UnitLit(true)) :: Value.Lit(Tree.UnitLit(true)) :: Nil), End()),
End()))
AssignField(
clsSym.asPath,
pcVar.id,
Value.Ref(pcVar),
End()
)(S(pcSymbol))))

private def genNormalBody(b: Block, clsSym: BlockMemberSymbol)(using HandlerCtx): Block =
val transform = new BlockTransformerShallow(SymbolSubst()):
Expand Down
4 changes: 1 addition & 3 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
val nme = scp.allocateName(fld)
doc" # $mtdPrefix#$nme;"
.mkDocument(doc"")
val preCtorCode = ctorParams.foldLeft(body(preCtor, endSemi = true)):
case (acc, (sym, nme)) =>
doc"$acc # this.${sym.name} = $nme;"
val preCtorCode = body(preCtor, endSemi = true)
val ctorCode = doc"$preCtorCode${body(ctor, endSemi = false)}"
val ctorOrStatic = if isModule
then doc"static"
Expand Down
52 changes: 45 additions & 7 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,35 @@ extends Importer:
params(ps)
newCtx = newCtx2
res
def withFields(using Ctx)(fn: (Ctx) ?=> (Term.Blk, Ctx)): (Term.Blk, Ctx) =
val fields = ps.map: ps =>
ps.params.map: p =>
// For class-like types, "desugar" the parameters into additional class fields.
val owner = td.symbol match
// Any MemberSymbol should be an InnerSymbol, except for TypeAliasSymbol,
// but type aliases should not call this function.
case s: InnerSymbol => S(s)
case _: TypeAliasSymbol => die
val fsym = BlockMemberSymbol(p.sym.nme, Nil)
val fdef = TermDefinition(
owner,
ImmutVal,
fsym,
Nil, N, p.sign,
S(Term.Ref(p.sym)(p.sym.id, 666)), // FIXME: 666 is a dummy value
FlowSymbol("‹class-param-res›"),
TermDefFlags.empty.copy(isModMember = k is Mod),
Nil
)
sym.defn = S(fdef)
fdef
val ctxWithFields = ctx.withMembers(
fields.fold(Nil)(_.map(f => f.sym.nme -> f.sym)),
ctx.outer
)
val (blk, c) = fn(using ctxWithFields)
val blkWithFields = fields.fold[Term.Blk](blk)(fs => blk.copy(stats = fs ::: blk.stats))
(blkWithFields, c)
val defn = k match
case Als =>
val alsSym = td.symbol.asInstanceOf[TypeAliasSymbol] // TODO improve `asInstanceOf`
Expand Down Expand Up @@ -1039,7 +1068,8 @@ extends Importer:
newCtx.nest(S(clsSym)).givenIn:
log(s"Processing type definition $nme")
val cd =
val (bod, c) = body match
val (bod, c) = withFields:
body match
case S(b: Tree.Block) => block(b, hasResult = false)
// case S(t) => block(t :: Nil)
case S(t) => ???
Expand All @@ -1053,7 +1083,8 @@ extends Importer:
newCtx.nest(S(clsSym)).givenIn:
log(s"Processing type definition $nme")
val cd =
val (bod, c) = body match
val (bod, c) = withFields:
body match
case S(b: Tree.Block) => block(b, hasResult = false)
// case S(t) => block(t :: Nil)
case S(t) => ???
Expand Down Expand Up @@ -1092,13 +1123,14 @@ extends Importer:
N
case N => N

def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): LocalSymbol & NamedSymbol =
def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): TermSymbol | VarSymbol =
if ctx.outer.isDefined then TermSymbol(k, ctx.outer, id)
else VarSymbol(id)

def param(t: Tree, inUsing: Bool): Ctxl[Opt[Opt[Bool] -> Param]] = t match
def param(t: Tree, inUsing: Bool): Ctxl[Opt[Opt[Bool] -> Param]] =
def go(t: Tree, inUsing: Bool, flags: FldFlags): Ctxl[Opt[Opt[Bool] -> Param]] = t match
case TypeDef(Mod, inner, N, N) =>
val ps = param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(mod = true))))
val ps = go(inner, inUsing, flags.copy(mod = true))
for p <- ps if p._2.flags.mod do p._2.sign match
case N =>
raise(ErrorReport(msg"Module parameters must have explicit types." -> t.toLoc :: Nil))
Expand All @@ -1107,10 +1139,16 @@ extends Importer:
case _ => ()
ps
case TypeDef(Pat, inner, N, N) =>
param(inner, inUsing).map(_.mapSecond(p => p.copy(flags = p.flags.copy(pat = true))))
go(inner, inUsing, flags.copy(pat = true))
case _ =>
t.asParam(inUsing).map: (isSpd, p, t) =>
isSpd -> Param(FldFlags.empty, fieldOrVarSym(ParamBind, p), t.map(term(_)))
val sym = VarSymbol(p)
val sign = t.map(term(_))
val param = Param(flags, sym, sign)
sym.decl = S(param)
isSpd -> param
go(t, inUsing, FldFlags.empty)


def params(t: Tree): Ctxl[(ParamList, Ctx)] = t match
case Tup(ps) =>
Expand Down
52 changes: 39 additions & 13 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/ImplicitResolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import syntax.{Fun, Ins, Mod}
import semantics.Term
import semantics.Elaborator.State
import ImplicitResolver.ICtx.Type
import ImplicitResolver.TyParamSymbol

import Message.MessageContext

Expand All @@ -25,23 +24,21 @@ class CtxArgImpl extends CtxArg:

object ImplicitResolver:

type TyParamSymbol = LocalSymbol & NamedSymbol

/*
* An "implicit" or "instance" context, as opposed to the one in Elaborator.
*/
case class ICtx(
parent: Opt[ICtx],
iEnv: Map[Type.Sym, Ls[(Type, ICtx.Instance)]],
tEnv: Map[TyParamSymbol, Type]
tEnv: Map[VarSymbol, Type]
):

def +(typ: Type.Concrete, sym: Symbol): ICtx =
val newLs = (typ -> ICtx.Instance(sym)) :: iEnv.getOrElse(typ.toSym, Nil)
val newEnv = iEnv + (typ.toSym -> newLs)
copy(iEnv = newEnv)

def withTypeArg(param: TyParamSymbol, arg: Type): ICtx =
def withTypeArg(param: VarSymbol, arg: Type): ICtx =
copy(tEnv = tEnv + (param -> arg))

def get(query: Type.Concrete): Opt[ICtx.Instance] =
Expand Down Expand Up @@ -131,7 +128,7 @@ class ImplicitResolver(tl: TraceLogger)
msg"got ${targs.length.toString()}" -> base.toLoc :: Nil))
(tparams zip targs).foldLeft(ictx):
case (ictx, (tparam, targ)) => (tparam.sym, resolveType(targ)) match
case (sym: TyParamSymbol, S(typ)) =>
case (sym: VarSymbol, S(typ)) =>
log(s"Resolving App with type arg ${sym} = $typ")
ictx.withTypeArg(sym, typ)
case _ => ictx
Expand Down Expand Up @@ -282,21 +279,50 @@ object ModuleChecker:
.exists(_.isInstanceOf[TyParam])

/** Checks if a term evaluates to a module value. */
def evalsToModule(t: Term): Bool =
def isModule(t: Tree): Bool = t match
case Tree.TypeDef(Mod, _, _, _) => true
case _ => false
def evalsToModule(t: Term): Bool =
def returnsModule(t: Tree.TermDef): Bool = t.annotatedResultType match
case S(Tree.TypeDef(Mod, _, N, N)) => true
case _ => false
def checkDecl(decl: Declaration): Bool = decl match
// All TypeLikeDef are not modules, except for modules themselves.
// Objects use ModuleDef but is not a module.
case ModuleDef(kind = Mod) =>
true
case _: TypeLikeDef =>
false
// Check Member/Local symbols
case defn: TermDefinition =>
defn.flags.isModTyped
case defn: Param =>
defn.flags.mod
case defn: TyParam =>
defn.flags.mod
def checkSym(sym: Symbol): Bool = sym match
case sym if sym.asMod.nonEmpty => true
case sym if sym.asBlkMember.flatMap(_.trmTree).exists(returnsModule) => true
case _: (BuiltinSymbol | TopLevelSymbol) => false
case sym: BlockLocalSymbol => sym.decl match
case S(decl) => checkDecl(decl)
case N =>
// Most local symbols are let-bindings
// which do not have a definition at this point.
false
case sym: MemberSymbol[?] => sym.defn match
case S(defn) => checkDecl(defn)
case N =>
// At this point all member symbols should have definition,
// except for the class(-like) that are currently being elaborated.
// TODO: We will fix this by deferring the checks to the resolution stage.
false
case sym =>
lastWords(s"Unsupported symbol kind ${sym}")
t match
case Term.Blk(_, res) => evalsToModule(res)
case Term.App(lhs, rhs) => lhs.symbol match
case S(sym: BlockMemberSymbol) => sym.trmTree.exists(returnsModule)
case _ => false
case t => t.symbol match
case S(sym: BlockMemberSymbol) => sym.modTree.exists(isModule)
case _ => false
case t: Term.Ref => checkSym(t.sym)
case t => t.symbol.exists(checkSym)

/**
* An extractor that extracts the (tree) definition of a module method.
Expand Down
1 change: 1 addition & 0 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ class InstSymbol(val origin: Symbol)(using State) extends LocalSymbol:

class VarSymbol(val id: Ident)(using State) extends BlockLocalSymbol(id.name) with NamedSymbol with LocalSymbol:
val name: Str = id.name
override def toLoc: Option[Loc] = id.toLoc
// override def toString: Str = s"$name@$uid"
override def subst(using s: SymbolSubst): VarSymbol = s.mapVarSym(this)

Expand Down
10 changes: 5 additions & 5 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,14 @@ final case class LetDecl(sym: LocalSymbol, annotations: Ls[Annot]) extends State

final case class DefineVar(sym: LocalSymbol, rhs: Term) extends Statement

final case class TermDefFlags(isModMember: Bool):
final case class TermDefFlags(isModMember: Bool, isModTyped: Bool):
def showDbg: Str =
val flags = Buffer.empty[String]
if isModMember then flags += "module"
flags.mkString(" ")
override def toString: String = "‹" + showDbg + "›"

object TermDefFlags { val empty: TermDefFlags = TermDefFlags(false) }
object TermDefFlags { val empty: TermDefFlags = TermDefFlags(false, false) }

final case class TermDefinition(
owner: Opt[InnerSymbol],
Expand All @@ -289,7 +289,7 @@ final case class TermDefinition(
case _ => true

final case class HandlerTermDefinition(
resumeSym: LocalSymbol & NamedSymbol,
resumeSym: VarSymbol,
td: TermDefinition
)

Expand Down Expand Up @@ -493,8 +493,8 @@ final case class TyParam(flags: FldFlags, vce: Opt[Bool], sym: VarSymbol) extend
flags.showDbg + sym


final case class Param(flags: FldFlags, sym: LocalSymbol & NamedSymbol, sign: Opt[Term])
extends AutoLocated:
final case class Param(flags: FldFlags, sym: VarSymbol, sign: Opt[Term])
extends Declaration with AutoLocated:
def subTerms: Ls[Term] = sign.toList
override protected def children: List[Located] = subTerms
// def children: Ls[Located] = self.value :: self.asc.toList ::: Nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ extension (split: DeBrujinSplit)
Split.End // TODO: report mismatched arity

/** To instantiate the body of a pattern synonym. */
def instantiate(context: Map[LocalSymbol & NamedSymbol, DeBrujinSplit])(using tl: TraceLogger): DeBrujinSplit =
def instantiate(context: Map[VarSymbol, DeBrujinSplit])(using tl: TraceLogger): DeBrujinSplit =
import DeBrujinSplit.*, PatternStub.*, ConstructorLike.*, tl.*
def go(split: DeBrujinSplit): DeBrujinSplit = split.traceChange("instantiate"):
split match
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class Normalization(elaborator: Elaborator)(using raise: Raise, ctx: Ctx):
// the number of free variables, bind them, and substitute them
// with the new indices.
val paramSymbols = (1 to split.arity).map: i =>
TermSymbol(ParamBind, N, Ident(s"param$i"))
VarSymbol(Ident(s"param$i"))
.toVector
val paramList = PlainParamList:
paramSymbols.iterator.map(Param(FldFlags.empty, _, N)).toList
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ enum ConstructorLike:
/** This case represents pattern parameters, which only make sense within the
* body of the pattern declaration.
*/
case Parameter(symbol: LocalSymbol & NamedSymbol)
case Parameter(symbol: VarSymbol)
/**
* This case represents a nested split, where the arity of the split must be 1
* (there is exactly one `Binder` at the top level). The split must not have
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ class Translator(val elaborator: Elaborator)
private def errorSplit: Split = Split.Else(Term.Error)

/** Create a function definition from the given UCS splits. */
private def makeMatcher(name: Str, scrut: TermSymbol, topmost: Split)(using Raise): TermDefinition =
private def makeMatcher(name: Str, scrut: VarSymbol, topmost: Split)(using Raise): TermDefinition =
val normalize = new Normalization(elaborator)
val sym = BlockMemberSymbol(name, Nil)
val ps = PlainParamList(Param(FldFlags.empty, scrut, N) :: Nil)
Expand All @@ -221,15 +221,15 @@ class Translator(val elaborator: Elaborator)
post = (blk: Ls[TermDefinition]) => s"Translator >>> $blk"
):
val unapply = scoped("ucs:cp"):
val scrutSym = TermSymbol(ParamBind, N, Ident("scrut"))
val scrutSym = VarSymbol(Ident("scrut"))
val topmost = full(() => scrutSym.ref(), body, success(params))(using patternParams, raise) ~~: failure
log(s"Translated `unapply`: ${display(topmost)}")
makeMatcher("unapply", scrutSym, topmost)
val unapplyStringPrefix = scoped("ucs:cp"):
// We don't report errors here because they are already reported in the
// translation of `unapply` function.
given Raise = Function.const(())
val scrutSym = TermSymbol(ParamBind, N, Ident("topic"))
val scrutSym = VarSymbol(Ident("topic"))
stringPrefix(() => scrutSym.ref(), body, prefixSuccess(params)) match
case Split.Else(Term.Error) =>
makeMatcher("unapplyStringPrefix", scrutSym, failure)
Expand Down
10 changes: 8 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/utils/utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ extension (s: String)

import hkmc2.semantics.TermDefFlags
import hkmc2.semantics.FldFlags
import hkmc2.semantics.ParamListFlags
import scala.collection.mutable.Buffer
import mlscript.utils.StringOps
import hkmc2.semantics.CtxArg
Expand All @@ -42,9 +43,10 @@ extension (t: Product)
case xs: List[_] => "Ls of \n" + xs.iterator.map(aux(_)).mkString("\n").indent(" ")
case xs: Vector[_] => "Vector of \n" + xs.iterator.map(aux(_)).mkString("\n").indent(" ")
case s: String => s.escaped
case TermDefFlags(mod) =>
case TermDefFlags(isModMember, isModTyped) =>
val flags = Buffer.empty[String]
if mod then flags += "module"
if isModMember then flags += "modMember"
if isModMember then flags += "modTyped"
flags.mkString("(", ", ", ")")
case FldFlags(mut, spec, genGetter, mod, pat) =>
val flags = Buffer.empty[String]
Expand All @@ -54,6 +56,10 @@ extension (t: Product)
if mod then flags += "module"
if pat then flags += "pat"
flags.mkString("(", ", ", ")")
case ParamListFlags(ctx) =>
val flags = Buffer.empty[String]
if ctx then flags += "ctx"
flags.mkString("(", ", ", ")")
case Loc(start, end, origin) =>
val (sl, _, sc) = origin.fph.getLineColAt(start)
val (el, _, ec) = origin.fph.getLineColAt(end)
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/test/mlscript-compile/apps/Accounting.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ Accounting1 = class Accounting {
this.Report = function Report(fileName1) { return new Report.class(fileName1); };
this.Report.class = class Report {
constructor(fileName) {
this.fileName = fileName;
let tmp;
this.fileName = fileName;
tmp = fs.writeFileSync(this.fileName, "# Accounting\n");
}
w(txt) {
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/test/mlscript-compile/apps/CSV.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ let CSV1;
CSV1 = function CSV(strDelimiter1) { return new CSV.class(strDelimiter1); };
CSV1.class = class CSV {
constructor(strDelimiter) {
this.strDelimiter = strDelimiter;
let tmp, tmp1, tmp2, tmp3, tmp4, tmp5, tmp6, tmp7;
this.strDelimiter = strDelimiter;
tmp = this.strDelimiter || ",";
this.strDelimiter = tmp;
tmp1 = "(\\" + this.strDelimiter;
Expand Down
Loading
Loading