Skip to content

Module Methods #233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 22, 2024
6 changes: 3 additions & 3 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -407,13 +407,13 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
effBuff += eff
nestCtx += sym -> rhsTy
goStats(stats)
case TermDefinition(_, Fun, sym, ParamList(_, ps) :: Nil, sig, Some(body), _) :: stats =>
case TermDefinition(_, Fun, sym, ParamList(_, ps) :: Nil, sig, Some(body), _, _) :: stats =>
typeFunDef(sym, Term.Lam(ps, body), sig, ctx)
goStats(stats)
case TermDefinition(_, Fun, sym, Nil, sig, Some(body), _) :: stats =>
case TermDefinition(_, Fun, sym, Nil, sig, Some(body), _, _) :: stats =>
typeFunDef(sym, body, sig, ctx) // * may be a case expressions
goStats(stats)
case TermDefinition(_, Fun, sym, _, S(sig), None, _) :: stats =>
case TermDefinition(_, Fun, sym, _, S(sig), None, _, _) :: stats =>
ctx += sym -> typeType(sig)
goStats(stats)
case (clsDef: ClassDef) :: stats =>
Expand Down
1 change: 1 addition & 0 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class Lowering(using TL, Raise, Elaborator.State):
case Tup(fs) =>
val as = fs.map:
case sem.Fld(sem.FldFlags.empty, value, N) => value
case sem.Fld(sem.FldFlags(false, false, false, true), value, N) => value
case sem.Fld(flags, value, asc) =>
TODO("Other argument forms")
val l = new TempSymbol(summon[Elaborator.State].nextUid, S(t))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class JSBuilder extends CodeBuilder:
} + ")""""
}; }"""
} #} # }"
if clsDefn.kind is syntax.Mod then
if (clsDefn.kind is syntax.Mod) || (clsDefn.kind is syntax.Obj) then
val clsTmp = summon[Scope].allocateName(new semantics.TempSymbol(0/*TODO rm this useless param*/, N, sym.nme+"$"+"class"))
clsDefn.owner match
case S(owner) =>
Expand Down
142 changes: 128 additions & 14 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,34 @@ extends Importer:
term(rhs)
case tree @ App(lhs, rhs) =>
val sym = FlowSymbol("‹app-res›", nextUid)
Term.App(term(lhs), term(rhs))(tree, sym)
val lt = term(lhs)
val rt = term(rhs)

// Check if module arguments match module parameters
val args = rt match
case Term.Tup(fields) => S(fields)
case _ => N
val params = lt.symbol
.collect:
case sym: BlockMemberSymbol => sym.trmTree
.flatten
.collect:
case td: TermDef => td.paramLists.headOption
.flatten
for
(args, params) <- (args zip params)
(arg, param) <- (args zip params.fields)
do
val argMod = arg.flags.mod
val paramMod = param match
case Tree.TypeDef(Mod, _, N, N) => true
case _ => false
if argMod && !paramMod then raise:
ErrorReport:
msg"Only module parameters may receive module arguments (values)." ->
arg.toLoc :: Nil

Term.App(lt, rt)(tree, sym)
case Sel(pre, nme) =>
val preTrm = term(pre)
val sym = resolveField(nme, preTrm.symbol, nme)
Expand Down Expand Up @@ -300,6 +327,8 @@ extends Importer:
Term.Throw(term(body))
case Modified(Keyword.`do`, kwLoc, body) =>
Term.Blk(term(body) :: Nil, unit)
case TypeDef(Mod, head, N, N) =>
term(head)
case Tree.Region(id: Tree.Ident, body) =>
val sym = VarSymbol(id, nextUid)
val nestCtx = ctx + (id.name -> sym)
Expand Down Expand Up @@ -354,7 +383,12 @@ extends Importer:
def fld(tree: Tree): Ctxl[Fld] = tree match
case InfixApp(lhs, Keyword.`:`, rhs) =>
Fld(FldFlags.empty, term(lhs), S(term(rhs)))
case _ => Fld(FldFlags.empty, term(tree), N)
case _ =>
val t = term(tree)
val flags = FldFlags.empty
if ModuleChecker.evalsToModule(t)
then Fld(flags.copy(mod = true), t, N)
else Fld(flags, t, N)

def unit: Term.Lit = Term.Lit(UnitLit(true))

Expand Down Expand Up @@ -385,7 +419,10 @@ extends Importer:
raise(ErrorReport(msg"Multiple declarations of symbol '$name'" -> N ::
decls.map(msg"declared here" -> _.toLoc)))
val sig = decls.collectFirst:
case td if td.signature.isDefined => td.signature.get
case td
if td.annotatedResultType.isDefined
&& td.paramLists.isEmpty
=> td.annotatedResultType.get
sig.foreach: sig =>
newSignatureTrees += name -> sig

Expand Down Expand Up @@ -503,29 +540,61 @@ extends Importer:
case R(id) =>
val sym = members.getOrElse(id.name, die)
val owner = ctx.outer
val isModMember = owner.exists(_.isInstanceOf[ModuleSymbol])
val tdf = ctx.nest(N).givenIn:
// * Add type parameters to context
val (tps, newCtx1) = td.typeParams match
case S(t) => typeParams(t)
case N => (N, ctx)
// * Add parameters to context
val (pss, newCtx) =
val (pss, newCtx) =
td.paramLists.foldLeft(Ls[ParamList](), newCtx1):
case ((pss, ctx), ps) =>
val (qs, newCtx) = params(ps)(using ctx)
(pss :+ ParamList(ParamListFlags.empty, qs), newCtx)
// * Elaborate signature
val st = td.annotatedResultType.orElse(newSignatureTrees.get(id.name))
val s = st.map(term(_)(using newCtx))
val b = rhs.map(term(_)(using newCtx))
val r = FlowSymbol(s"‹result of ${sym}›", nextUid)
val tdf = TermDefinition(owner, k, sym, pss,
td.signature.orElse(newSignatureTrees.get(id.name)).map(term), b, r)
val tdf = TermDefinition(owner, k, sym, pss, s, b, r,
TermDefFlags.empty.copy(isModMember = isModMember))
sym.defn = S(tdf)

// indicates if the function really returns a module
val em = b.exists(ModuleChecker.evalsToModule)
// indicates if the function marks its result as "module"
val mm = st match
case Some(TypeDef(Mod, _, N, N)) => true
case _ => false

// checks rules regarding module methods
s match
case N if em => raise:
ErrorReport:
msg"Function returning module values must have explicit return types." ->
td.head.toLoc :: Nil
case S(t) if em && ModuleChecker.isTypeParam(t) => raise:
ErrorReport:
msg"Function returning module values must have concrete return types." ->
td.head.toLoc :: Nil
case S(_) if em && !mm => raise:
ErrorReport:
msg"The return type of functions returning module values must be prefixed with module keyword." ->
td.head.toLoc :: Nil
case S(_) if mm && !isModMember => raise:
ErrorReport:
msg"Only module methods may return module values." ->
td.head.toLoc :: Nil
case _ => ()

tdf
go(sts, tdf :: acc)
case L(d) =>
raise(d)
go(sts, acc)
case (td @ TypeDef(k, head, extension, body)) :: sts =>
assert((k is Als) || (k is Cls) || (k is Mod), k)
assert((k is Als) || (k is Cls) || (k is Mod) || (k is Obj), k)
val nme = td.name match
case R(id) => id
case L(d) =>
Expand Down Expand Up @@ -577,7 +646,7 @@ extends Importer:
semantics.TypeDef(alsSym, tps, extension.map(term), N)
alsSym.defn = S(d)
d
case Mod =>
case k: (Mod.type | Obj.type) =>
val clsSym = td.symbol.asInstanceOf[ModuleSymbol] // TODO: improve `asInstanceOf`
val owner = ctx.outer
newCtx.nest(S(clsSym)).givenIn:
Expand All @@ -588,7 +657,7 @@ extends Importer:
// case S(t) => block(t :: Nil)
case S(t) => ???
case N => (new Term.Blk(Nil, Term.Lit(UnitLit(true))), ctx)
ModuleDef(owner, clsSym, tps, ps, ObjBody(bod))
ModuleDef(owner, clsSym, tps, ps, k, ObjBody(bod))
clsSym.defn = S(cd)
cd
case Cls =>
Expand All @@ -612,7 +681,6 @@ extends Importer:
// TODO: pass abstract to `go`
go(body :: sts, acc)
case Modified(Keyword.`declare`, absLoc, body) :: sts =>
???
// TODO: pass declare to `go`
go(body :: sts, acc)
case (result: Tree) :: Nil =>
Expand All @@ -631,8 +699,18 @@ extends Importer:
if ctx.outer.isDefined then TermSymbol(k, ctx.outer, id)
else VarSymbol(id, nextUid)

def param(t: Tree): Ctxl[Ls[Param]] = t.param.map: (p, t) =>
Param(FldFlags.empty, fieldOrVarSym(ParamBind, p), t.map(term))
def param(t: Tree): Ctxl[Ls[Param]] = t match
case TypeDef(Mod, inner, N, N) =>
val ps = param(inner).map(p => p.copy(flags = p.flags.copy(mod = true)))
for p <- ps if p.flags.mod do p.sign match
case N =>
raise(ErrorReport(msg"Module parameters must have explicit types." -> t.toLoc :: Nil))
case S(ret) if ModuleChecker.isTypeParam(ret) =>
raise(ErrorReport(msg"Module parameters must have concrete types." -> t.toLoc :: Nil))
case _ => ()
ps
case _ => t.param.map: (p, t) =>
Param(FldFlags.empty, fieldOrVarSym(ParamBind, p), t.map(term))

def params(t: Tree): Ctxl[(Ls[Param], Ctx)] = t match
case Tup(ps) =>
Expand Down Expand Up @@ -681,7 +759,7 @@ extends Importer:
def computeVariances(s: Statement): Unit =
val trav = VarianceTraverser()
def go(s: Statement): Unit = s match
case TermDefinition(_, k, sym, pss, sign, body, r) =>
case TermDefinition(_, k, sym, pss, sign, body, r, _) =>
pss.foreach(ps => ps.params.foreach(trav.traverseType(S(false))))
sign.foreach(trav.traverseType(S(true)))
body match
Expand All @@ -699,11 +777,36 @@ extends Importer:
while trav.changed do
trav.changed = false
go(s)

object ModuleChecker:

/** Checks if a term is a reference to a type parameter. */
def isTypeParam(t: Term): Bool = t.symbol
.filter(_.isInstanceOf[VarSymbol])
.flatMap(_.asInstanceOf[VarSymbol].decl)
.exists(_.isInstanceOf[TyParam])

/** Checks if a term evaluates to a module value. */
def evalsToModule(t: Term): Bool =
def isModule(t: Tree): Bool = t match
case TypeDef(Mod, _, _, _) => true
case _ => false
def returnsModule(t: TermDef): Bool = t.annotatedResultType match
case S(TypeDef(Mod, _, N, N)) => true
case _ => false
t match
case Term.Blk(_, res) => evalsToModule(res)
case Term.App(lhs, rhs) => lhs.symbol match
case S(sym: BlockMemberSymbol) => sym.trmTree.exists(returnsModule)
case _ => false
case t => t.symbol match
case S(sym: BlockMemberSymbol) => sym.modTree.exists(isModule)
case _ => false

class VarianceTraverser(var changed: Bool = true) extends Traverser:
override def traverseType(pol: Pol)(trm: Term): Unit = trm match
case Term.TyApp(lhs, targs) =>
lhs.symbol.flatMap(_.asTpe) match
lhs.symbol.flatMap(sym => sym.asTpe orElse sym.asMod) match
case S(sym: ClassSymbol) =>
sym.defn match
case S(td: ClassDef) =>
Expand All @@ -715,6 +818,17 @@ extends Importer:
if !tp.isCovariant then traverseType(pol.!)(targ)
case N =>
// TODO(sym->sym.uid)
case S(sym: ModuleSymbol) =>
sym.defn match
case S(td: ModuleDef) =>
if td.tparams.sizeCompare(targs) =/= 0 then
raise(ErrorReport(msg"Wrong number of type arguments" -> trm.toLoc :: Nil)) // TODO BE
td.tparams.zip(targs).foreach:
case (tp, targ) =>
if !tp.isContravariant then traverseType(pol)(targ)
if !tp.isCovariant then traverseType(pol.!)(targ)
case N =>
// TODO(sym->sym.uid)
case S(sym: TypeAliasSymbol) =>
// TODO dedup with above...
sym.defn match
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/semantics/Symbol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class BlockMemberSymbol(val nme: Str, val trees: Ls[Tree]) extends MemberSymbol[
def clsTree: Opt[Tree.TypeDef] = trees.collectFirst:
case t: Tree.TypeDef if t.k is Cls => t
def modTree: Opt[Tree.TypeDef] = trees.collectFirst:
case t: Tree.TypeDef if t.k is Mod => t
case t: Tree.TypeDef if (t.k is Mod) || (t.k is Obj) => t
def alsTree: Opt[Tree.TypeDef] = trees.collectFirst:
case t: Tree.TypeDef if t.k is Als => t
def trmTree: Opt[Tree.TermDef] = trees.collectFirst:
Expand Down
39 changes: 31 additions & 8 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Term.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package semantics

import mlscript.utils.*, shorthands.*
import syntax.*
import scala.collection.mutable.Buffer


final case class QuantVar(sym: VarSymbol, ub: Opt[Term], lb: Opt[Term])
Expand Down Expand Up @@ -104,7 +105,7 @@ sealed trait Statement extends AutoLocated:
case RegRef(reg, value) => reg :: value :: Nil
case Assgn(lhs, rhs) => lhs :: rhs :: Nil
case Deref(term) => term :: Nil
case TermDefinition(_, k, _, ps, sign, body, res) =>
case TermDefinition(_, k, _, ps, sign, body, res, _) =>
ps.toList.flatMap(_.subTerms) ::: sign.toList ::: body.toList
case cls: ClassDef =>
cls.paramsOpt.toList.flatMap(_.flatMap(_.subTerms)) ::: cls.body.blk :: Nil
Expand All @@ -123,6 +124,7 @@ sealed trait Statement extends AutoLocated:
case t: Tup => t.tree :: Nil
case l: Lam => l.params.map(_.sym.id) ::: l.body :: Nil
case t: App => t.tree :: Nil
case Sel(pre, nme) => pre :: nme :: Nil
case SelProj(prefix, cls, proj) => prefix :: cls :: proj :: Nil
case _ =>
subTerms // TODO more precise (include located things that aren't terms)
Expand Down Expand Up @@ -171,7 +173,7 @@ sealed trait Statement extends AutoLocated:
case CompType(lhs, rhs, pol) => s"${lhs.showDbg} ${if pol then "|" else "&"} ${rhs.showDbg}"
case Error => "<error>"
case Tup(fields) => fields.map(_.showDbg).mkString("[", ", ", "]")
case TermDefinition(_, k, sym, ps, sign, body, res) => s"${k.str} ${sym}${
case TermDefinition(_, k, sym, ps, sign, body, res, flags) => s"${flags} ${k.str} ${sym}${
ps.map(_.showDbg).mkString("")
}${sign.fold("")(": "+_.showDbg)}${
body match
Expand All @@ -188,6 +190,15 @@ final case class LetDecl(sym: LocalSymbol) extends Statement

final case class DefineVar(sym: LocalSymbol, rhs: Term) extends Statement

final case class TermDefFlags(isModMember: Bool):
def showDbg: Str =
val flags = Buffer.empty[String]
if isModMember then flags += "module"
flags.mkString(" ")
override def toString: String = "‹" + showDbg + "›"

object TermDefFlags { val empty: TermDefFlags = TermDefFlags(false) }

final case class TermDefinition(
owner: Opt[InnerSymbol],
k: TermDefKind,
Expand All @@ -196,6 +207,7 @@ final case class TermDefinition(
sign: Opt[Term],
body: Opt[Term],
resSym: FlowSymbol,
flags: TermDefFlags,
) extends Companion

case class ObjBody(blk: Term.Blk):
Expand Down Expand Up @@ -226,9 +238,14 @@ sealed abstract class ClassLikeDef extends TypeLikeDef:
val body: ObjBody


case class ModuleDef(owner: Opt[InnerSymbol], sym: ModuleSymbol, tparams: Ls[TyParam], paramsOpt: Opt[Ls[Param]], body: ObjBody) extends ClassLikeDef with Companion:
self =>
val kind: ClsLikeKind = Mod
case class ModuleDef(
owner: Opt[InnerSymbol],
sym: ModuleSymbol,
tparams: Ls[TyParam],
paramsOpt: Opt[Ls[Param]],
kind: ClsLikeKind,
body: ObjBody,
) extends ClassLikeDef with Companion


sealed abstract class ClassDef extends ClassLikeDef:
Expand Down Expand Up @@ -278,8 +295,14 @@ case class TypeDef(


// TODO Store optional source locations for the flags instead of booleans
final case class FldFlags(mut: Bool, spec: Bool, genGetter: Bool):
def showDbg: Str = (if mut then "mut " else "") + (if spec then "spec " else "") + (if genGetter then "val " else "")
final case class FldFlags(mut: Bool, spec: Bool, genGetter: Bool, mod: Bool):
def showDbg: Str =
val flags = Buffer.empty[String]
if mut then flags += "mut"
if spec then flags += "spec"
if genGetter then flags += "gen"
if mod then flags += "module"
flags.mkString(" ")
override def toString: String = "‹" + showDbg + "›"

final case class Fld(flags: FldFlags, value: Term, asc: Opt[Term]) extends FldImpl
Expand All @@ -303,7 +326,7 @@ final case class Param(flags: FldFlags, sym: LocalSymbol & NamedSymbol, sign: Op
// def showDbg: Str = flags.showDbg + sym.name + ": " + sign.showDbg
def showDbg: Str = flags.showDbg + sym + sign.fold("")(": " + _.showDbg)

object FldFlags { val empty: FldFlags = FldFlags(false, false, false) }
object FldFlags { val empty: FldFlags = FldFlags(false, false, false, false) }

final case class ParamListFlags(ctx: Bool):
def showDbg: Str = (if ctx then "ctx " else "")
Expand Down
Loading
Loading