diff --git a/compiler/shared/main/scala/mlscript/compiler/codegen/CppAst.scala b/compiler/shared/main/scala/mlscript/compiler/codegen/CppAst.scala index ef054b0268..954a91672e 100644 --- a/compiler/shared/main/scala/mlscript/compiler/codegen/CppAst.scala +++ b/compiler/shared/main/scala/mlscript/compiler/codegen/CppAst.scala @@ -161,7 +161,7 @@ case class CompilationUnit(includes: Ls[Str], decls: Ls[Decl], defs: Ls[Def]): "HiddenTheseEntities", "True", "False", "Callable", "List", "Cons", "Nil", "Option", "Some", "None", "Pair", "Tuple2", "Tuple3", "Nat", "S", "O" ) stack_list(defs.filterNot { - case Def.StructDef(name, _, _, _) => hiddenNames.contains(name.stripPrefix("_mls_")) + case d: Def.StructDef => hiddenNames.contains(d.name.stripPrefix("_mls_")) case _ => false }.map(_.toDocument)) diff --git a/compiler/shared/main/scala/mlscript/compiler/codegen/CppCompilerHost.scala b/compiler/shared/main/scala/mlscript/compiler/codegen/CppCompilerHost.scala index 3897648dbc..4bbfa19b7b 100644 --- a/compiler/shared/main/scala/mlscript/compiler/codegen/CppCompilerHost.scala +++ b/compiler/shared/main/scala/mlscript/compiler/codegen/CppCompilerHost.scala @@ -41,4 +41,4 @@ final class CppCompilerHost(val auxPath: Str): return output("Execution succeeded: ") - for line <- stdout do output(line) \ No newline at end of file + for line <- stdout do output(line) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index da45361b67..3e609e6d5b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -120,27 +120,24 @@ sealed abstract class Block extends Product with AutoLocated: (bod.freeVars - lhs) ++ rst.freeVars ++ hdr.flatMap(_.freeVars) case End(msg) => Set.empty - // TODO: freeVarsLLIR skips `fun` and `cls` in `Call` and `Instantiate` respectively, which is needed in some - // other places. However, adding them breaks some LLIR tests. Supposedly, once the IR uses the new symbol system, - // this should no longer happen. This version should be removed once that is resolved. lazy val freeVarsLLIR: Set[Local] = this match case Match(scrut, arms, dflt, rest) => scrut.freeVarsLLIR ++ dflt.toList.flatMap(_.freeVarsLLIR) ++ rest.freeVarsLLIR ++ arms.flatMap: - (pat, arm) => arm.freeVarsLLIR -- pat.freeVars + (pat, arm) => arm.freeVarsLLIR -- pat.freeVarsLLIR case Return(res, implct) => res.freeVarsLLIR case Throw(exc) => exc.freeVarsLLIR case Label(label, body, rest) => (body.freeVarsLLIR - label) ++ rest.freeVarsLLIR - case Break(label) => Set(label) - case Continue(label) => Set(label) + case Break(label) => Set.empty + case Continue(label) => Set.empty case Begin(sub, rest) => sub.freeVarsLLIR ++ rest.freeVarsLLIR case TryBlock(sub, finallyDo, rest) => sub.freeVarsLLIR ++ finallyDo.freeVarsLLIR ++ rest.freeVarsLLIR - case Assign(lhs, rhs, rest) => Set(lhs) ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR + case Assign(lhs, rhs, rest) => rhs.freeVarsLLIR ++ (rest.freeVarsLLIR - lhs) case AssignField(lhs, nme, rhs, rest) => lhs.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => lhs.freeVarsLLIR ++ fld.freeVarsLLIR ++ rhs.freeVarsLLIR ++ rest.freeVarsLLIR - case Define(defn, rest) => defn.freeVarsLLIR ++ rest.freeVarsLLIR + case Define(defn, rest) => defn.freeVarsLLIR ++ (rest.freeVarsLLIR - defn.sym) case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) => - (bod.freeVarsLLIR - lhs) ++ rst.freeVarsLLIR ++ hdr.flatMap(_.freeVars) + (bod.freeVarsLLIR - lhs) ++ rst.freeVarsLLIR ++ hdr.flatMap(_.freeVarsLLIR) case End(msg) => Set.empty lazy val subBlocks: Ls[Block] = this match @@ -385,8 +382,8 @@ final case class Handler( params: Ls[ParamList], body: Block, ): - lazy val freeVarsLLIR: Set[Local] = body.freeVarsLLIR -- params.flatMap(_.paramSyms) - sym - resumeSym lazy val freeVars: Set[Local] = body.freeVars -- params.flatMap(_.paramSyms) - sym - resumeSym + lazy val freeVarsLLIR: Set[Local] = body.freeVarsLLIR -- params.flatMap(_.paramSyms) - sym - resumeSym /* Represents either unreachable code (for functions that must return a result) * or the end of a non-returning function or a REPL block */ @@ -446,9 +443,13 @@ sealed abstract class Result extends AutoLocated: case Value.Rcd(args) => args.flatMap(arg => arg.idx.fold(Set.empty)(_.freeVars) ++ arg.value.freeVars).toSet lazy val freeVarsLLIR: Set[Local] = this match - case Call(fun, args) => args.flatMap(_.value.freeVarsLLIR).toSet - case Instantiate(cls, args) => args.flatMap(_.freeVarsLLIR).toSet + case Call(fun, args) => fun.freeVarsLLIR ++ args.flatMap(_.value.freeVarsLLIR).toSet + case Instantiate(cls, args) => cls.freeVarsLLIR ++ args.flatMap(_.freeVarsLLIR).toSet case Select(qual, name) => qual.freeVarsLLIR + case Value.Ref(l: (BuiltinSymbol | TopLevelSymbol | ClassSymbol | TermSymbol)) => Set.empty + case Value.Ref(l: MemberSymbol[?]) => l.defn match + case Some(d: ClassLikeDef) => Set.empty + case _ => Set(l) case Value.Ref(l) => Set(l) case Value.This(sym) => Set.empty case Value.Lit(lit) => Set.empty diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala index e4be496c17..dfd194f4d2 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala @@ -31,7 +31,7 @@ object Printer: .map{ case (c, b) => doc"${case_doc(c)} => #{ # ${mkDocument(b)} #} " } .mkDocument(sep = doc" # ") val docDefault = dflt.map(mkDocument).getOrElse(doc"") - doc"match ${mkDocument(scrut)} #{ # ${docCases} # else #{ # ${docDefault} #} #} # in # ${mkDocument(rest)} " + doc"match ${mkDocument(scrut)} #{ # ${docCases} # else #{ # ${docDefault} #} #} # in # ${mkDocument(rest)}" case Return(res, implct) => doc"return ${mkDocument(res)}" case Throw(exc) => doc"throw ${mkDocument(exc)}" case Label(label, body, rest) => diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala index c05e4f1d3e..e2218e3295 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/Ast.scala @@ -1,14 +1,13 @@ -package hkmc2.codegen.cpp +package hkmc2 +package codegen.cpp import mlscript._ import mlscript.utils._ import mlscript.utils.shorthands._ - -import hkmc2.Message.MessageContext -import hkmc2.document._ - import scala.language.implicitConversions +import document._ + private def raw(x: String): Document = doc"$x" given Conversion[String, Document] = x => doc"$x" @@ -130,14 +129,13 @@ enum Expr: case Unary(op: Str, expr: Expr) case Binary(op: Str, lhs: Expr, rhs: Expr) case Initializer(exprs: Ls[Expr]) - case Constructor(name: Str, init: Expr) def toDocument: Document = def aux(x: Expr): Document = x match case Var(name) => name case IntLit(value) => value.toString case DoubleLit(value) => value.toString - case StrLit(value) => s"\"$value\"" // need more reliable escape utils + case StrLit(value) => value.escaped case CharLit(value) => value.toInt.toString case Call(func, args) => doc"${func.toDocument}(${Expr.toDocuments(args, sep = ", ")})" @@ -151,8 +149,6 @@ enum Expr: doc"(${lhs.toDocument} $op ${rhs.toDocument})" case Initializer(exprs) => doc"{${Expr.toDocuments(exprs, sep = ", ")}}" - case Constructor(name, init) => - doc"$name(${init.toDocument})" aux(this) case class CompilationUnit(includes: Ls[Str], decls: Ls[Decl], defs: Ls[Def]): @@ -161,43 +157,48 @@ case class CompilationUnit(includes: Ls[Str], decls: Ls[Decl], defs: Ls[Def]): def toDocumentWithoutHidden: Document = val hiddenNames: Set[Str] = Set() defs.filterNot { - case Def.StructDef(name, _, _, _) => hiddenNames.contains(name.stripPrefix("_mls_")) + case Def.StructDef(name, _, _, _, _) => hiddenNames.contains(name.stripPrefix("_mls_")) case _ => false }.map(_.toDocument).mkDocument(doc" # ") enum Decl: case StructDecl(name: Str) case EnumDecl(name: Str) - case FuncDecl(ret: Type, name: Str, args: Ls[Type]) + case FuncDecl(ret: Type, name: Str, args: Ls[Type], isOverride: Bool, isVirtual: Bool) case VarDecl(name: Str, typ: Type) def toDocument: Document = def aux(x: Decl): Document = x match case StructDecl(name) => doc"struct $name;" case EnumDecl(name) => doc"enum $name;" - case FuncDecl(ret, name, args) => - doc"${ret.toDocument()} $name(${Type.toDocuments(args, sep = ", ")});" + case FuncDecl(ret, name, args, or, virt) => + val docVirt = (if virt then doc"virtual " else doc"") + val docSpecRet = ret.toDocument() + val docArgs = Type.toDocuments(args, sep = ", ") + val docOverride = if or then doc" override" else doc"" + doc"$docVirt$docSpecRet $name($docArgs)$docOverride;" case VarDecl(name, typ) => doc"${typ.toDocument()} $name;" aux(this) enum Def: - case StructDef(name: Str, fields: Ls[(Str, Type)], inherit: Opt[Ls[Str]], methods: Ls[Def] = Ls.empty) + case StructDef(name: Str, fields: Ls[(Str, Type)], inherit: Opt[Ls[Str]], methods: Ls[Def], methodsDecl: Ls[Decl]) case EnumDef(name: Str, fields: Ls[(Str, Opt[Int])]) - case FuncDef(specret: Type, name: Str, args: Ls[(Str, Type)], body: Stmt.Block, or: Bool = false, virt: Bool = false) + case FuncDef(specret: Type, name: Str, args: Ls[(Str, Type)], body: Stmt.Block, isOverride: Bool, isVirtual: Bool, in_scope: Opt[Str]) case VarDef(typ: Type, name: Str, init: Opt[Expr]) case RawDef(raw: Str) def toDocument: Document = def aux(x: Def): Document = x match - case StructDef(name, fields, inherit, defs) => + case StructDef(name, fields, inherit, defs, decls) => val docFirst = doc"struct $name${inherit.fold(doc"")(x => doc": public ${x.mkDocument(doc", ")}")} {" val docFields = fields.map { case (name, typ) => doc"${typ.toDocument()} $name;" }.mkDocument(doc" # ") val docDefs = defs.map(_.toDocument).mkDocument(doc" # ") + val docDecls = decls.map(_.toDocument).mkDocument(doc" # ") val docLast = "};" - doc"$docFirst #{ # $docFields # $docDefs #} # $docLast" + doc"$docFirst #{ # $docFields # $docDefs # $docDecls #} # $docLast" case EnumDef(name, fields) => val docFirst = doc"enum $name {" val docFields = fields.map { @@ -205,13 +206,14 @@ enum Def: }.mkDocument(doc" # ") val docLast = "};" doc"$docFirst #{ # $docFields #} # $docLast" - case FuncDef(specret, name, args, body, or, virt) => + case FuncDef(specret, name, args, body, or, virt, scope) => val docVirt = (if virt then doc"virtual " else doc"") val docSpecRet = specret.toDocument() val docArgs = Type.toDocuments(args, sep = ", ") val docOverride = if or then doc" override" else doc"" val docBody = body.toDocument - doc"$docVirt$docSpecRet $name($docArgs)$docOverride ${body.toDocument}" + val docScope = scope.fold(doc"")(x => doc"$x::") + doc"$docVirt$docSpecRet $docScope$name($docArgs)$docOverride ${body.toDocument}" case VarDef(typ, name, init) => val docTyp = typ.toDocument() val docInit = init.fold(raw(""))(x => doc" = ${x.toDocument}") diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala index 80f83ae7c3..b3d0e7462c 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/cpp/CodeGen.scala @@ -1,36 +1,55 @@ -package hkmc2.codegen.cpp +package hkmc2 +package codegen +package cpp import mlscript.utils._ import mlscript.utils.shorthands._ import scala.collection.mutable.ListBuffer -import hkmc2.codegen.llir.{Expr => IExpr, _} -import hkmc2.codegen.cpp._ +import llir.{Expr => IExpr, _} +import utils.{Scope, TraceLogger} +import semantics._ -object CppCodeGen: - def mapName(name: Name): Str = "_mls_" + name.str.replace('$', '_').replace('\'', '_') - def mapName(name: Str): Str = "_mls_" + name.replace('$', '_').replace('\'', '_') - val freshName = Fresh(div = '_'); +class CppCodeGen(builtinClassSymbols: Set[Local], tl: TraceLogger): + import tl.{trace, log, logs} + def mapName(name: Str): Str = "_mls_" + Scope.replaceInvalidCharacters(name) + def mapClsLikeName(sym: Local)(using Raise, Scope): Str = + if builtinClassSymbols.contains(sym) then sym.nme |> mapName + else allocIfNew(sym) + def directName(sym: Local): Str = + sym.nme |> mapName val mlsValType = Type.Prim("_mlsValue") val mlsUnitValue = Expr.Call(Expr.Var("_mlsValue::create<_mls_Unit>"), Ls()); val mlsRetValue = "_mls_retval" - val mlsRetValueDecl = Decl.VarDecl(mlsRetValue, mlsValType) + def mlsRetValType(n: Int) = + if n === 1 then + mlsValType + else + Type.Template("std::tuple", Ls.fill(n)(mlsValType)) + def mlsRetValueDecl(n: Int) = + if n === 1 then + Decl.VarDecl(mlsRetValue, mlsValType) + else + Decl.VarDecl(mlsRetValue, mlsRetValType(n)) val mlsMainName = "_mlsMain" val mlsPrelude = "#include \"mlsprelude.h\"" val mlsPreludeImpl = "#include \"mlsprelude.cpp\"" - val mlsInternalClass = Set("True", "False", "Boolean", "Callable") + val builtinClassSymbolNames = Set("Callable", "Lazy") + def mlsIsInternalClass(sym: Local) = builtinClassSymbolNames.contains(sym.nme) val mlsObject = "_mlsObject" val mlsBuiltin = "builtin" val mlsEntryPoint = s"int main() { return _mlsLargeStack(_mlsMainWrapper); }"; + def mlsCallEntry(s: Str) = s"_mlsValue _mlsMain() { return $s(); }" def mlsIntLit(x: BigInt) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.IntLit(x))) - def mlsStrLit(x: Str) = Expr.Call(Expr.Var("_mlsValue::fromStrLit"), Ls(Expr.StrLit(x))) + def mlsStrLit(x: Str) = Expr.Call(Expr.Var("_mlsValue::create<_mls_Str>"), Ls(Expr.StrLit(x))) + def mlsDecLit(x: BigDecimal) = Expr.Call(Expr.Var("_mlsValue::create<_mls_Float>"), Ls(Expr.DoubleLit(x.toDouble))) def mlsCharLit(x: Char) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.CharLit(x))) def mlsNewValue(cls: Str, args: Ls[Expr]) = Expr.Call(Expr.Var(s"_mlsValue::create<$cls>"), args) def mlsIsValueOf(cls: Str, scrut: Expr) = Expr.Call(Expr.Var(s"_mlsValue::isValueOf<$cls>"), Ls(scrut)) def mlsIsBoolLit(scrut: Expr, lit: hkmc2.syntax.Tree.BoolLit) = Expr.Call(Expr.Var("_mlsValue::isIntLit"), Ls(scrut, Expr.IntLit(if lit.value then 1 else 0))) def mlsIsIntLit(scrut: Expr, lit: hkmc2.syntax.Tree.IntLit) = Expr.Call(Expr.Var("_mlsValue::isIntLit"), Ls(scrut, Expr.IntLit(lit.value))) def mlsDebugPrint(x: Expr) = Expr.Call(Expr.Var("_mlsValue::print"), Ls(x)) - def mlsTupleValue(init: Expr) = Expr.Constructor("_mlsValue::tuple", init) + def mlsTupleValue(init: Ls[Expr]) = Expr.Call(Expr.Var("std::make_tuple"), init) def mlsAs(name: Str, cls: Str) = Expr.Var(s"_mlsValue::as<$cls>($name)") def mlsAsUnchecked(name: Str, cls: Str) = Expr.Var(s"_mlsValue::cast<$cls>($name)") def mlsObjectNameMethod(name: Str) = s"constexpr static inline const char *typeName = \"${name}\";" @@ -50,62 +69,96 @@ object CppCodeGen: s"virtual void destroy() override { $fieldsDeletion operator delete (this, std::align_val_t(_mlsAlignment)); }" def mlsThrowNonExhaustiveMatch = Stmt.Raw("_mlsNonExhaustiveMatch();"); def mlsCall(fn: Str, args: Ls[Expr]) = Expr.Call(Expr.Var("_mlsCall"), Expr.Var(fn) :: args) - def mlsMethodCall(cls: ClassRef, method: Str, args: Ls[Expr]) = - Expr.Call(Expr.Member(Expr.Call(Expr.Var(s"_mlsMethodCall<${cls.name |> mapName}>"), Ls(args.head)), method), args.tail) + def mlsMethodCall(cls: Local, method: Str, args: Ls[Expr])(using Raise, Scope) = + Expr.Call(Expr.Member(Expr.Call(Expr.Var(s"_mlsMethodCall<${cls |> mapClsLikeName}>"), Ls(args.head)), method), args.tail) + def mlsThisCall(cls: Local, method: Str, args: Ls[Expr])(using Raise, Scope) = + Expr.Call(Expr.Member(Expr.Var("this"), method), args) def mlsFnWrapperName(fn: Str) = s"_mlsFn_$fn" def mlsFnCreateMethod(fn: Str) = s"static _mlsValue create() { static _mlsFn_$fn mlsFn alignas(_mlsAlignment); mlsFn.refCount = stickyRefCount; mlsFn.tag = typeTag; return _mlsValue(&mlsFn); }" def mlsNeverValue(n: Int) = if (n <= 1) then Expr.Call(Expr.Var(s"_mlsValue::never"), Ls()) else Expr.Call(Expr.Var(s"_mlsValue::never<$n>"), Ls()) + val mlsThis = Expr.Var("_mlsValue(this, _mlsValue::inc_ref_tag{})") // first construct a value, then incRef() case class Ctx( - defnCtx: Set[Str], + fieldCtx: Set[Local], ) - def codegenClassInfo(using ctx: Ctx)(cls: ClassInfo): (Opt[Def], Decl) = - val fields = cls.fields.map{x => (x |> mapName, mlsValType)} - val parents = if cls.parents.nonEmpty then cls.parents.toList.map(mapName) else mlsObject :: Nil - val decl = Decl.StructDecl(cls.name |> mapName) - if mlsInternalClass.contains(cls.name) then return (None, decl) - val theDef = Def.StructDef( - cls.name |> mapName, fields, - if parents.nonEmpty then Some(parents) else None, - Ls(Def.RawDef(mlsObjectNameMethod(cls.name)), - Def.RawDef(mlsTypeTag()), - Def.RawDef(mlsCommonPrintMethod(cls.fields.map(mapName))), - Def.RawDef(mlsCommonDestructorMethod(cls.name |> mapName, cls.fields.map(mapName))), - Def.RawDef(mlsCommonCreateMethod(cls.name |> mapName, cls.fields.map(mapName), cls.id))) - ++ cls.methods.map{case (name, defn) => { - val (theDef, decl) = codegenDefn(using Ctx(ctx.defnCtx + cls.name))(defn) - theDef match - case x @ Def.FuncDef(_, name, _, _, _, _) => x.copy(virt = true) - case _ => theDef - }} - ) - (S(theDef), decl) + def getVar(l: Local)(using Raise, Scope): String = l match + case ts: hkmc2.semantics.TermSymbol => + ts.owner match + case S(owner) => summon[Scope].lookup_!(ts) + case N => summon[Scope].lookup_!(ts) + case ts: hkmc2.semantics.InnerSymbol => + summon[Scope].lookup_!(ts) + case _ => summon[Scope].lookup_!(l) + + def allocIfNew(l: Local)(using Raise, Scope): Str = + trace[Str](s"allocIfNew $l begin", r => s"allocIfNew $l end -> $r"): + if summon[Scope].lookup(l).isDefined then + getVar(l) |> mapName + else + summon[Scope].allocateName(l) |> mapName - def toExpr(texpr: TrivialExpr, reifyUnit: Bool = false)(using ctx: Ctx): Opt[Expr] = texpr match - case IExpr.Ref(name) => S(Expr.Var(name |> mapName)) + def codegenClassInfo(using Ctx, Raise, Scope)(cls: ClassInfo) = + trace[(Opt[Def], Decl, Ls[Def])](s"codegenClassInfo ${cls.symbol} begin"): + val fields = cls.fields.map{x => (x |> directName, mlsValType)} + cls.fields.foreach(x => summon[Scope].allocateName(x)) + val parents = if cls.parents.nonEmpty then cls.parents.toList.map(mapClsLikeName) else mlsObject :: Nil + val decl = Decl.StructDecl(cls.symbol |> mapClsLikeName) + if mlsIsInternalClass(cls.symbol) then (None, decl, Ls.empty) + else + val methods = cls.methods.map: + case (name, defn) => + val (cdef, decl) = codegenDefn(using Ctx(summon[Ctx].fieldCtx ++ cls.fields))(defn) + val cdef2 = cdef match + case x: Def.FuncDef if builtinApply.contains(defn.name.nme) => x.copy(name = defn.name |> directName, in_scope = Some(cls.symbol |> mapClsLikeName)) + case x: Def.FuncDef => x.copy(in_scope = Some(cls.symbol |> mapClsLikeName)) + case _ => throw new Exception(s"codegenClassInfo: unexpected def $cdef") + val decl2 = decl match + case x: Decl.FuncDecl if builtinApply.contains(defn.name.nme) => x.copy(isVirtual = true, name = defn.name |> directName) + case x: Decl.FuncDecl => x.copy(isVirtual = true) + case _ => throw new Exception(s"codegenClassInfo: unexpected decl $decl") + log(s"codegenClassInfo: ${cls.symbol} method ${defn.name} $decl2") + (cdef2, decl2) + val theDef = Def.StructDef( + cls.symbol |> mapClsLikeName, fields, + if parents.nonEmpty then Some(parents) else None, + Ls(Def.RawDef(mlsObjectNameMethod(cls.symbol.nme)), + Def.RawDef(mlsTypeTag()), + Def.RawDef(mlsCommonPrintMethod(cls.fields.map(directName))), + Def.RawDef(mlsCommonDestructorMethod(cls.symbol |> mapClsLikeName, cls.fields.map(directName))), + Def.RawDef(mlsCommonCreateMethod(cls.symbol |> mapClsLikeName, cls.fields.map(directName), cls.id))), + methods.iterator.map(_._2).toList + ) + (S(theDef), decl, methods.iterator.map(_._1).toList) + + def toExpr(texpr: TrivialExpr, reifyUnit: Bool = false)(using Ctx, Raise, Scope): Opt[Expr] = texpr match + case IExpr.Ref(name) if summon[Ctx].fieldCtx.contains(name) => S(Expr.Var(name |> directName)) + case IExpr.Ref(name: BuiltinSymbol) if name.nme === "" => S(mlsThis) + case IExpr.Ref(name) => S(Expr.Var(name |> allocIfNew)) case IExpr.Literal(hkmc2.syntax.Tree.BoolLit(x)) => S(mlsIntLit(if x then 1 else 0)) case IExpr.Literal(hkmc2.syntax.Tree.IntLit(x)) => S(mlsIntLit(x)) - case IExpr.Literal(hkmc2.syntax.Tree.DecLit(x)) => S(mlsIntLit(x.toBigInt)) + case IExpr.Literal(hkmc2.syntax.Tree.DecLit(x)) => S(mlsDecLit(x)) case IExpr.Literal(hkmc2.syntax.Tree.StrLit(x)) => S(mlsStrLit(x)) case IExpr.Literal(hkmc2.syntax.Tree.UnitLit(_)) => if reifyUnit then S(mlsUnitValue) else None - def toExpr(texpr: TrivialExpr)(using ctx: Ctx): Expr = texpr match - case IExpr.Ref(name) => Expr.Var(name |> mapName) + def toExpr(texpr: TrivialExpr)(using Ctx, Raise, Scope): Expr = texpr match + case IExpr.Ref(name) if summon[Ctx].fieldCtx.contains(name) => Expr.Var(name |> directName) + case IExpr.Ref(name: BuiltinSymbol) if name.nme === "" => mlsThis + case IExpr.Ref(name) => Expr.Var(name |> allocIfNew) case IExpr.Literal(hkmc2.syntax.Tree.BoolLit(x)) => mlsIntLit(if x then 1 else 0) case IExpr.Literal(hkmc2.syntax.Tree.IntLit(x)) => mlsIntLit(x) - case IExpr.Literal(hkmc2.syntax.Tree.DecLit(x)) => mlsIntLit(x.toBigInt) + case IExpr.Literal(hkmc2.syntax.Tree.DecLit(x)) => mlsDecLit(x) case IExpr.Literal(hkmc2.syntax.Tree.StrLit(x)) => mlsStrLit(x) case IExpr.Literal(hkmc2.syntax.Tree.UnitLit(_)) => mlsUnitValue - def wrapMultiValues(exprs: Ls[TrivialExpr])(using ctx: Ctx): Expr = exprs match + def wrapMultiValues(exprs: Ls[TrivialExpr])(using Ctx, Raise, Scope): Expr = exprs match case x :: Nil => toExpr(x, reifyUnit = true).get case _ => - val init = Expr.Initializer(exprs.map{x => toExpr(x)}) + val init = exprs.map{x => toExpr(x)} mlsTupleValue(init) - def codegenCaseWithIfs(scrut: TrivialExpr, cases: Ls[(Pat, Node)], default: Opt[Node], storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = + def codegenCaseWithIfs(scrut: TrivialExpr, cases: Ls[(Pat, Node)], default: Opt[Node], storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using Ctx, Raise, Scope): (Ls[Decl], Ls[Stmt]) = val scrut2 = toExpr(scrut) val init: Stmt = default.fold(mlsThrowNonExhaustiveMatch)(x => { @@ -115,7 +168,7 @@ object CppCodeGen: val stmt = cases.foldRight(S(init)) { case ((Pat.Class(cls), arm), nextarm) => val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) - val stmt = Stmt.If(mlsIsValueOf(cls.name |> mapName, scrut2), Stmt.Block(decls2, stmts2), nextarm) + val stmt = Stmt.If(mlsIsValueOf(cls |> mapClsLikeName, scrut2), Stmt.Block(decls2, stmts2), nextarm) S(stmt) case ((Pat.Lit(i @ hkmc2.syntax.Tree.IntLit(_)), arm), nextarm) => val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) @@ -125,116 +178,131 @@ object CppCodeGen: val (decls2, stmts2) = codegen(arm, storeInto)(using Ls.empty, Ls.empty[Stmt]) val stmt = Stmt.If(mlsIsBoolLit(scrut2, i), Stmt.Block(decls2, stmts2), nextarm) S(stmt) - case _ => ??? + case _ => TODO("codegenCaseWithIfs doesn't support these patterns currently") } (decls, stmt.fold(stmts)(x => stmts :+ x)) - def codegenJumpWithCall(func: FuncRef, args: Ls[TrivialExpr], storeInto: Opt[Str])(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = - val call = Expr.Call(Expr.Var(func.name |> mapName), args.map(toExpr)) + def codegenJumpWithCall(func: Local, args: Ls[TrivialExpr], storeInto: Opt[Str])(using decls: Ls[Decl], stmts: Ls[Stmt])(using Ctx, Raise, Scope): (Ls[Decl], Ls[Stmt]) = + val call = Expr.Call(Expr.Var(func |> allocIfNew), args.map(toExpr)) val stmts2 = stmts ++ Ls(storeInto.fold(Stmt.Return(call))(x => Stmt.Assign(x, call))) (decls, stmts2) - def codegenOps(op: Str, args: Ls[TrivialExpr])(using ctx: Ctx) = op match - case "+" => Expr.Binary("+", toExpr(args(0)), toExpr(args(1))) - case "-" => Expr.Binary("-", toExpr(args(0)), toExpr(args(1))) - case "*" => Expr.Binary("*", toExpr(args(0)), toExpr(args(1))) - case "/" => Expr.Binary("/", toExpr(args(0)), toExpr(args(1))) - case "%" => Expr.Binary("%", toExpr(args(0)), toExpr(args(1))) - case "==" => Expr.Binary("==", toExpr(args(0)), toExpr(args(1))) - case "!=" => Expr.Binary("!=", toExpr(args(0)), toExpr(args(1))) - case "<" => Expr.Binary("<", toExpr(args(0)), toExpr(args(1))) - case "<=" => Expr.Binary("<=", toExpr(args(0)), toExpr(args(1))) - case ">" => Expr.Binary(">", toExpr(args(0)), toExpr(args(1))) - case ">=" => Expr.Binary(">=", toExpr(args(0)), toExpr(args(1))) - case "&&" => Expr.Binary("&&", toExpr(args(0)), toExpr(args(1))) - case "||" => Expr.Binary("||", toExpr(args(0)), toExpr(args(1))) - case "!" => Expr.Unary("!", toExpr(args(0))) - case _ => TODO("codegenOps") - - - def codegen(expr: IExpr)(using ctx: Ctx): Expr = expr match + def codegenOps(op: BuiltinSymbol, args: Ls[TrivialExpr])(using Ctx, Raise, Scope) = + trace[Expr](s"codegenOps $op begin"): + var op2 = op.nme + if op2 === "===" then + op2 = "==" + else if op2 === "!===" then + op2 = "!=" + if op.binary && args.length === 2 then + Expr.Binary(op2, toExpr(args(0)), toExpr(args(1))) + else if op.unary && args.length === 1 then + Expr.Unary(op2, toExpr(args(0))) + else + TODO(s"codegenOps ${op.nme} ${args.size} ${op.binary} ${op.unary} ${args.map(_.show)}") + + + def codegen(expr: IExpr)(using Ctx, Raise, Scope): Expr = expr match case x @ (IExpr.Ref(_) | IExpr.Literal(_)) => toExpr(x, reifyUnit = true).get - case IExpr.CtorApp(cls, args) => mlsNewValue(cls.name |> mapName, args.map(toExpr)) - case IExpr.Select(name, cls, field) => Expr.Member(mlsAsUnchecked(name |> mapName, cls.name |> mapName), field |> mapName) + case IExpr.CtorApp(cls, args) => mlsNewValue(cls |> mapClsLikeName, args.map(toExpr)) + case IExpr.Select(name, cls, field) if field.forall(_.isDigit) => + Expr.Member(mlsAsUnchecked(name |> allocIfNew, cls |> mapClsLikeName), s"field${field}" |> mapName) + case IExpr.Select(name, cls, field) => Expr.Member(mlsAsUnchecked(name |> allocIfNew, cls |> mapClsLikeName), field |> mapName) case IExpr.BasicOp(name, args) => codegenOps(name, args) case IExpr.AssignField(assignee, cls, field, value) => TODO("codegen assign field") - def codegenBuiltin(names: Ls[Name], builtin: Str, args: Ls[TrivialExpr])(using ctx: Ctx): Ls[Stmt] = builtin match - case "error" => Ls(Stmt.Raw("throw std::runtime_error(\"Error\");"), Stmt.AutoBind(names.map(mapName), mlsNeverValue(names.size))) - case _ => Ls(Stmt.AutoBind(names.map(mapName), Expr.Call(Expr.Var("_mls_builtin_" + builtin), args.map(toExpr)))) - - def codegen(body: Node, storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = body match - case Node.Result(res) => - val expr = wrapMultiValues(res) - val stmts2 = stmts ++ Ls(Stmt.Assign(storeInto, expr)) - (decls, stmts2) - case Node.Jump(defn, args) => - codegenJumpWithCall(defn, args, S(storeInto)) - case Node.Panic(msg) => (decls, stmts :+ Stmt.Raw(s"throw std::runtime_error(\"$msg\");")) - case Node.LetExpr(name, expr, body) => - val stmts2 = stmts ++ Ls(Stmt.AutoBind(Ls(name |> mapName), codegen(expr))) - codegen(body, storeInto)(using decls, stmts2) - case Node.LetMethodCall(names, cls, method, IExpr.Ref(Name("builtin")) :: args, body) => - val stmts2 = stmts ++ codegenBuiltin(names, args.head.toString.replace("\"", ""), args.tail) - codegen(body, storeInto)(using decls, stmts2) - case Node.LetMethodCall(names, cls, method, args, body) => - val call = mlsMethodCall(cls, method.str |> mapName, args.map(toExpr)) - val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call)) - codegen(body, storeInto)(using decls, stmts2) - case Node.LetCall(names, defn, args, body) => - val call = Expr.Call(Expr.Var(defn.name |> mapName), args.map(toExpr)) - val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call)) - codegen(body, storeInto)(using decls, stmts2) - case Node.Case(scrut, cases, default) => - codegenCaseWithIfs(scrut, cases, default, storeInto) + def codegenBuiltin(names: Ls[Local], builtin: Str, args: Ls[TrivialExpr])(using Ctx, Raise, Scope): Ls[Stmt] = builtin match + case "error" => Ls(Stmt.Raw("throw std::runtime_error(\"Error\");"), Stmt.AutoBind(names.map(allocIfNew), mlsNeverValue(names.size))) + case _ => Ls(Stmt.AutoBind(names.map(allocIfNew), Expr.Call(Expr.Var("_mls_builtin_" + builtin), args.map(toExpr)))) + + lazy val builtinApply = Set( + "apply0", + "apply1", + "apply2", + "apply3", + "apply4", + "apply5", + "apply6", + "apply7", + "apply8", + "apply9", + ) + + def codegen(body: Node, storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using Ctx, Raise, Scope): (Ls[Decl], Ls[Stmt]) = + trace[(Ls[Decl], Ls[Stmt])](s"codegen $body begin"): + body match + case Node.Result(res) => + val expr = wrapMultiValues(res) + val stmts2 = stmts ++ Ls(Stmt.Assign(storeInto, expr)) + (decls, stmts2) + case Node.Jump(defn, args) => + codegenJumpWithCall(defn, args, S(storeInto)) + case Node.Panic(msg) => (decls, stmts :+ Stmt.Raw(s"throw std::runtime_error(\"$msg\");")) + case Node.LetExpr(name, expr, body) => + val stmts2 = stmts ++ Ls(Stmt.AutoBind(Ls(name |> allocIfNew), codegen(expr))) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetCall(names, bin: BuiltinSymbol, args, body) if bin.nme === "" => + val stmts2 = stmts ++ codegenBuiltin(names, args.head.toString.replace("\"", ""), args.tail) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetMethodCall(names, cls, method, IExpr.Ref(bin: BuiltinSymbol) :: args, body) if bin.nme === "" => + val call = mlsThisCall(cls, method |> directName, args.map(toExpr)) + val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(allocIfNew), call)) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetMethodCall(names, cls, method, args, body) if builtinApply.contains(method.nme) => + val call = mlsMethodCall(cls, method |> directName, args.map(toExpr)) + val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(allocIfNew), call)) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetMethodCall(names, cls, method, args, body) => + val call = mlsMethodCall(cls, method |> allocIfNew, args.map(toExpr)) + val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(allocIfNew), call)) + codegen(body, storeInto)(using decls, stmts2) + case Node.LetCall(names, defn, args, body) => + val call = Expr.Call(Expr.Var(defn |> allocIfNew), args.map(toExpr)) + val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(allocIfNew), call)) + codegen(body, storeInto)(using decls, stmts2) + case Node.Case(scrut, cases, default) => + codegenCaseWithIfs(scrut, cases, default, storeInto) - def codegenDefn(using ctx: Ctx)(defn: Func): (Def, Decl) = defn match + def codegenDefn(using Ctx, Raise, Scope)(defn: Func): (Def, Decl) = defn match case Func(id, name, params, resultNum, body) => - val decls = Ls(mlsRetValueDecl) + val decls = Ls(mlsRetValueDecl(resultNum)) val stmts = Ls.empty[Stmt] val (decls2, stmts2) = codegen(body, mlsRetValue)(using decls, stmts) val stmtsWithReturn = stmts2 :+ Stmt.Return(Expr.Var(mlsRetValue)) - val theDef = Def.FuncDef(mlsValType, name |> mapName, params.map(x => (x |> mapName, mlsValType)), Stmt.Block(decls2, stmtsWithReturn)) - val decl = Decl.FuncDecl(mlsValType, name |> mapName, params.map(x => mlsValType)) + val theDef = Def.FuncDef(mlsRetValType(resultNum), name |> allocIfNew, params.map(x => (x |> allocIfNew, mlsValType)), Stmt.Block(decls2, stmtsWithReturn), false, false, None) + val decl = Decl.FuncDecl(mlsRetValType(resultNum), name |> allocIfNew, params.map(x => mlsValType), false, false) (theDef, decl) - def codegenTopNode(node: Node)(using ctx: Ctx): (Def, Decl) = - val decls = Ls(mlsRetValueDecl) - val stmts = Ls.empty[Stmt] - val (decls2, stmts2) = codegen(node, mlsRetValue)(using decls, stmts) - val stmtsWithReturn = stmts2 :+ Stmt.Return(Expr.Var(mlsRetValue)) - val theDef = Def.FuncDef(mlsValType, mlsMainName, Ls(), Stmt.Block(decls2, stmtsWithReturn)) - val decl = Decl.FuncDecl(mlsValType, mlsMainName, Ls()) - (theDef, decl) - // Topological sort of classes based on inheritance relationships - def sortClasses(prog: Program): Ls[ClassInfo] = - var depgraph = prog.classes.map(x => (x.name, x.parents)).toMap + def sortClasses(prog: Program)(using Raise, Scope): Ls[ClassInfo] = + var depgraph = prog.classes.map(x => (x.symbol, x.parents)).toMap + ++ builtinClassSymbols.map(x => (x, Set.empty[Symbol])) + log(s"depgraph: $depgraph") var degree = depgraph.view.mapValues(_.size).toMap - def removeNode(node: Str) = + def removeNode(node: Symbol) = degree -= node depgraph -= node - depgraph = depgraph.view.mapValues(_.filter(_ != node)).toMap + depgraph = depgraph.view.mapValues(_.filter(_ =/= node)).toMap degree = depgraph.view.mapValues(_.size).toMap val sorted = ListBuffer.empty[ClassInfo] - var work = degree.filter(_._2 == 0).keys.toSet + var work = degree.filter(_._2 === 0).keys.toSet while work.nonEmpty do val node = work.head work -= node - sorted.addOne(prog.classes.find(_.name == node).get) + prog.classes.find(x => (x.symbol) === node).foreach(sorted.addOne) removeNode(node) - val next = degree.filter(_._2 == 0).keys + val next = degree.filter(_._2 === 0).keys work ++= next if depgraph.nonEmpty then val cycle = depgraph.keys.mkString(", ") throw new Exception(s"Cycle detected in class hierarchy: $cycle") sorted.toList - def codegen(prog: Program): CompilationUnit = + def codegen(prog: Program)(using Raise, Scope): CompilationUnit = val sortedClasses = sortClasses(prog) - val defnCtx = prog.defs.map(_.name) - val (defs, decls) = sortedClasses.map(codegenClassInfo(using Ctx(defnCtx))).unzip - val (defs2, decls2) = prog.defs.map(codegenDefn(using Ctx(defnCtx))).unzip - val (defMain, declMain) = codegenTopNode(prog.main)(using Ctx(defnCtx)) - CompilationUnit(Ls(mlsPrelude), decls ++ decls2 :+ declMain, defs.flatten ++ defs2 :+ defMain :+ Def.RawDef(mlsEntryPoint)) + val fieldCtx = Set.empty[Local] + given Ctx = Ctx(fieldCtx) + val (defs, decls, methodsDef) = sortedClasses.map(codegenClassInfo).unzip3 + val (defs2, decls2) = prog.defs.map(codegenDefn).unzip + CompilationUnit(Ls(mlsPrelude), decls ++ decls2, defs.flatten ++ defs2 ++ methodsDef.flatten :+ Def.RawDef(mlsCallEntry(prog.entry |> allocIfNew)) :+ Def.RawDef(mlsEntryPoint)) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala index 042295874a..befe5ac59a 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Analysis.scala @@ -1,36 +1,33 @@ -package hkmc2.codegen.llir - -import mlscript._ -import hkmc2.codegen._ -import hkmc2.codegen.llir.{ Program => LlirProgram, Node, Func } -import mlscript.utils._ -import mlscript.utils.shorthands._ -import hkmc2.semantics.BuiltinSymbol -import hkmc2.syntax.Tree.UnitLit -import hkmc2.{Raise, raise, Diagnostic, ErrorReport, Message} -import hkmc2.Message.MessageContext -import hkmc2.semantics.InnerSymbol -import hkmc2.codegen.llir.FuncRef.fromName -import scala.collection.mutable.ListBuffer +package hkmc2 +package codegen +package llir import scala.annotation.tailrec import scala.collection.immutable.* +import scala.collection.mutable.ListBuffer import scala.collection.mutable.{HashMap => MutHMap} import scala.collection.mutable.{HashSet => MutHSet, Set => MutSet} +import mlscript._ +import mlscript.utils._ +import mlscript.utils.shorthands._ + +import syntax.Tree.UnitLit +import semantics.{BuiltinSymbol, InnerSymbol} + class UsefulnessAnalysis(verbose: Bool = false): import Expr._ import Node._ def log(x: Any) = if verbose then println(x) - val uses = MutHMap[(Name, Int), Int]() - val defs = MutHMap[Name, Int]() + val uses = MutHMap[(Local, Int), Int]() + val defs = MutHMap[Local, Int]() - private def addDef(x: Name) = + private def addDef(x: Local) = defs.update(x, defs.getOrElse(x, 0) + 1) - private def addUse(x: Name) = + private def addUse(x: Local) = val def_count = defs.get(x) match case None => throw Exception(s"Use of undefined variable $x") case Some(value) => value @@ -68,58 +65,49 @@ class UsefulnessAnalysis(verbose: Bool = false): f(x.body) uses.toMap -class FreeVarAnalysis(extended_scope: Bool = true, verbose: Bool = false): +class FreeVarAnalysis(ctx: Local => Func): import Expr._ import Node._ - private val visited = MutHSet[Str]() - private def f(using defined: Set[Str])(defn: Func, fv: Set[Str]): Set[Str] = - val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param.str) + private val visited = MutHSet[Local]() + private def f(using defined: Set[Local])(defn: Func, fv: Set[Local]): Set[Local] = + val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param) f(using defined2)(defn.body, fv) - private def f(using defined: Set[Str])(expr: Expr, fv: Set[Str]): Set[Str] = expr match - case Ref(name) => if defined.contains(name.str) then fv else fv + name.str + private def f(using defined: Set[Local])(expr: Expr, fv: Set[Local]): Set[Local] = expr match + case Ref(name) => if defined.contains(name) then fv else fv + name case Literal(lit) => fv case CtorApp(name, args) => args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) - case Select(name, cls, field) => if defined.contains(name.str) then fv else fv + name.str + case Select(name, cls, field) => if defined.contains(name) then fv else fv + name case BasicOp(name, args) => args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) case AssignField(assignee, _, _, value) => f(using defined)( value.toExpr, - if defined.contains(assignee.str) then fv + assignee.str else fv + if defined.contains(assignee) then fv + assignee else fv ) - private def f(using defined: Set[Str])(node: Node, fv: Set[Str]): Set[Str] = node match + private def f(using defined: Set[Local])(node: Node, fv: Set[Local]): Set[Local] = node match case Result(res) => res.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) - case Jump(defnref, args) => - var fv2 = args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) - if extended_scope && !visited.contains(defnref.name) then - val defn = defnref.expectFn - visited.add(defn.name) - val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param.str) - fv2 = f(using defined2)(defn, fv2) - fv2 + case Jump(defn, args) => + args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) case Case(scrut, cases, default) => val fv2 = scrut match - case Ref(name) => if defined.contains(name.str) then fv else fv + name.str + case Ref(name) => if defined.contains(name) then fv else fv + name case _ => fv - val fv3 = cases.foldLeft(fv2) { + val fv3 = cases.foldLeft(fv2): case (acc, (cls, body)) => f(using defined)(body, acc) - } - fv3 + default match + case Some(body) => f(using defined)(body, fv3) + case None => fv3 + case Panic(msg) => fv case LetMethodCall(resultNames, cls, method, args, body) => var fv2 = args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) - val defined2 = resultNames.foldLeft(defined)((acc, name) => acc + name.str) + val defined2 = resultNames.foldLeft(defined)((acc, name) => acc + name) f(using defined2)(body, fv2) case LetExpr(name, expr, body) => val fv2 = f(using defined)(expr, fv) - val defined2 = defined + name.str + val defined2 = defined + name f(using defined2)(body, fv2) - case LetCall(resultNames, defnref, args, body) => + case LetCall(resultNames, defn, args, body) => var fv2 = args.foldLeft(fv)((acc, arg) => f(using defined)(arg.toExpr, acc)) - val defined2 = resultNames.foldLeft(defined)((acc, name) => acc + name.str) - if extended_scope && !visited.contains(defnref.name) then - val defn = defnref.expectFn - visited.add(defn.name) - val defined2 = defn.params.foldLeft(defined)((acc, param) => acc + param.str) - fv2 = f(using defined2)(defn, fv2) + val defined2 = resultNames.foldLeft(defined)((acc, name) => acc + name) f(using defined2)(body, fv2) def run(node: Node) = f(using Set.empty)(node, Set.empty) - def run_with(node: Node, defined: Set[Str]) = f(using defined)(node, Set.empty) + def run_with(node: Node, defined: Set[Local]) = f(using defined)(node, Set.empty) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala index 7705fa01a3..8592dfd090 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Builder.scala @@ -3,70 +3,80 @@ package codegen package llir import scala.collection.mutable.ListBuffer +import scala.collection.mutable.{HashMap => MutMap} import mlscript.utils.* import mlscript.utils.shorthands.* -import hkmc2.utils.* -import hkmc2.document.* -import hkmc2.Message.MessageContext +import utils.* +import document.* +import Message.MessageContext -import hkmc2.syntax.Tree -import hkmc2.semantics.* -import hkmc2.codegen.llir.{ Program => LlirProgram, Node, Func } -import hkmc2.codegen.Program +import syntax.Tree +import semantics.* +import codegen.llir.{ Program => LlirProgram, Node, Func } +import codegen.Program +import cpp.Expr.StrLit - -def err(msg: Message)(using Raise): Unit = +private def bErrStop(msg: Message)(using Raise) = raise(ErrorReport(msg -> N :: Nil, source = Diagnostic.Source.Compilation)) - -def errStop(msg: Message)(using Raise) = - err(msg) throw LowLevelIRError("stopped") +final case class FuncInfo(paramsSize: Int) + +final case class BuiltinSymbols( + var callableSym: Opt[Local] = None, + var thisSym: Opt[Local] = None, + var builtinSym: Opt[Local] = None, + fieldSym: MutMap[Int, VarSymbol] = MutMap.empty, + applySym: MutMap[Int, BlockMemberSymbol] = MutMap.empty, + tupleSym: MutMap[Int, MemberSymbol[? <: ClassLikeDef]] = MutMap.empty, + runtimeSym: Opt[TempSymbol] = None, +): + def hiddenClasses = callableSym.toSet final case class Ctx( - runtimeSymbol: TempSymbol, def_acc: ListBuffer[Func], class_acc: ListBuffer[ClassInfo], - symbol_ctx: Map[Str, Name] = Map.empty, - fn_ctx: Map[Local, Name] = Map.empty, // is a known function - class_ctx: Map[Local, Name] = Map.empty, - block_ctx: Map[Local, Name] = Map.empty, - is_top_level: Bool = true, + symbol_ctx: Map[Local, Local] = Map.empty, + fn_ctx: Map[Local, FuncInfo] = Map.empty, // is a known function + class_ctx: Map[MemberSymbol[? <: ClassLikeDef], ClassInfo] = Map.empty, + class_sym_ctx: Map[BlockMemberSymbol, MemberSymbol[? <: ClassLikeDef]] = Map.empty, + flow_ctx: Map[Path, Local] = Map.empty, + isTopLevel: Bool = true, + method_class: Opt[MemberSymbol[? <: ClassLikeDef]] = None, + builtinSym: BuiltinSymbols = BuiltinSymbols() ): - def addFuncName(n: Local, m: Name) = copy(fn_ctx = fn_ctx + (n -> m)) + def addFuncName(n: Local, paramsSize: Int) = copy(fn_ctx = fn_ctx + (n -> FuncInfo(paramsSize))) def findFuncName(n: Local)(using Raise) = fn_ctx.get(n) match - case None => - errStop(msg"Function name not found: ${n.toString()}") - Name("error") + case None => bErrStop(msg"Function name not found: ${n.toString()}") case Some(value) => value - def addClassName(n: Local, m: Name) = copy(class_ctx = class_ctx + (n -> m)) - def findClassName(n: Local)(using Raise) = class_ctx.get(n) match - case None => - errStop(msg"Class not found: ${n.toString}") + def addClassInfo(n: MemberSymbol[? <: ClassLikeDef], bsym: BlockMemberSymbol, m: ClassInfo) = + copy(class_ctx = class_ctx + (n -> m), class_sym_ctx = class_sym_ctx + (bsym -> n)) + def addName(n: Local, m: Local) = copy(symbol_ctx = symbol_ctx + (n -> m)) + def findName(n: Local)(using Raise) = symbol_ctx.get(n) match + case None => bErrStop(msg"Name not found: ${n.toString}") case Some(value) => value - def addName(n: Str, m: Name) = copy(symbol_ctx = symbol_ctx + (n -> m)) - def findName(n: Str)(using Raise): Name = symbol_ctx.get(n) match - case None => - errStop(msg"Name not found: $n") + def findClassInfo(n: MemberSymbol[? <: ClassLikeDef])(using Raise) = class_ctx.get(n) match + case None => bErrStop(msg"Class not found: ${n.toString}") case Some(value) => value - def nonTopLevel = copy(is_top_level = false) + def addKnownClass(n: Path, m: Local) = copy(flow_ctx = flow_ctx + (n -> m)) + def setClass(c: MemberSymbol[? <: ClassLikeDef]) = copy(method_class = Some(c)) + def nonTopLevel = copy(isTopLevel = false) object Ctx: def empty(using Elaborator.State) = - Ctx(Elaborator.State.runtimeSymbol, ListBuffer.empty, ListBuffer.empty) - + Ctx(ListBuffer.empty, ListBuffer.empty, builtinSym = BuiltinSymbols( + runtimeSym = Some(Elaborator.State.runtimeSymbol) + )) -final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: FreshInt): +final class LlirBuilder(using Elaborator.State)(tl: TraceLogger, uid: FreshInt): import tl.{trace, log, logs} def er = Expr.Ref def nr = Node.Result - def nme(x: Str) = Name(x) - def sr(x: Str) = er(Name(x)) - def sr(x: Name) = er(x) - def nsr(xs: Ls[Name]) = xs.map(er(_)) + def sr(x: Local) = er(x) + def nsr(xs: Ls[Local]) = xs.map(er(_)) private def allocIfNew(l: Local)(using Raise, Scope): String = trace[Str](s"allocIfNew begin: $l", x => s"allocIfNew end: $x"): @@ -88,19 +98,58 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: case ts: semantics.InnerSymbol => summon[Scope].findThis_!(ts) case _ => summon[Scope].lookup_!(l) - - private def bBind(name: Opt[Str], e: Result, body: Block)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope): Node = + + private def symMap(s: Local)(using ctx: Ctx)(using Raise, Scope) = + ctx.findName(s) + + private def newTemp = TempSymbol(N, "x") + private def newNamedTemp(name: Str) = TempSymbol(N, name) + private def newNamedBlockMem(name: Str) = BlockMemberSymbol(name, Nil) + private def newNamed(name: Str) = VarSymbol(Tree.Ident(name)) + private def newClassSym(name: Str) = + ClassSymbol(Tree.TypeDef(hkmc2.syntax.Cls, Tree.Empty(), N, N), Tree.Ident(name)) + private def newTupleSym(len: Int) = + ClassSymbol(Tree.TypeDef(hkmc2.syntax.Cls, Tree.Empty(), N, N), Tree.Ident(s"Tuple$len")) + private def newVarSym(name: Str) = VarSymbol(Tree.Ident(name)) + private def newFunSym(name: Str) = BlockMemberSymbol(name, Nil) + private def newBuiltinSym(name: Str) = BuiltinSymbol(name, false, false, false, false) + private def builtinField(n: Int)(using Ctx) = summon[Ctx].builtinSym.fieldSym.getOrElseUpdate(n, newVarSym(s"field$n")) + private def builtinApply(n: Int)(using Ctx) = summon[Ctx].builtinSym.applySym.getOrElseUpdate(n, newFunSym(s"apply$n")) + private def builtinTuple(n: Int)(using Ctx) = summon[Ctx].builtinSym.tupleSym.getOrElseUpdate(n, newTupleSym(n)) + private def builtinCallable(using ctx: Ctx) : Local = + ctx.builtinSym.callableSym match + case None => + val sym = newBuiltinSym("Callable") + ctx.builtinSym.callableSym = Some(sym); + sym + case Some(value) => value + private def builtinThis(using ctx: Ctx) : Local = + ctx.builtinSym.thisSym match + case None => + val sym = newBuiltinSym("") + ctx.builtinSym.thisSym = Some(sym); + sym + case Some(value) => value + private def builtin(using ctx: Ctx) : Local = + ctx.builtinSym.builtinSym match + case None => + val sym = newBuiltinSym("") + ctx.builtinSym.builtinSym = Some(sym); + sym + case Some(value) => value + + private def bBind(name: Opt[Local], e: Result, body: Block)(k: TrivialExpr => Ctx ?=> Node)(ct: Block)(using ctx: Ctx)(using Raise, Scope): Node = trace[Node](s"bBind begin: $name", x => s"bBind end: ${x.show}"): bResult(e): case r: Expr.Ref => - given Ctx = ctx.addName(name.getOrElse(fresh.make.str), r.name) + given Ctx = ctx.addName(name.getOrElse(newTemp), r.sym) log(s"bBind ref: $name -> $r") - bBlock(body)(k) + bBlock(body)(k)(ct) case l: Expr.Literal => - val v = fresh.make - given Ctx = ctx.addName(name.getOrElse(fresh.make.str), v) + val v = newTemp + given Ctx = ctx.addName(name.getOrElse(newTemp), v) log(s"bBind lit: $name -> $v") - Node.LetExpr(v, l, bBlock(body)(k)) + Node.LetExpr(v, l, bBlock(body)(k)(ct)) private def bArgs(e: Ls[Arg])(k: Ls[TrivialExpr] => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope): Node = trace[Node](s"bArgs begin", x => s"bArgs end: ${x.show}"): @@ -118,82 +167,206 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: case r: TrivialExpr => bPaths(xs): case rs: Ls[TrivialExpr] => k(r :: rs) + private def bNestedFunDef(e: FunDefn)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope): Node = + val FunDefn(_own, sym, params, body) = e + // generate it as a single named lambda expression that may be self-recursing + if params.length === 0 then + bErrStop(msg"Function without arguments not supported: ${params.length.toString}") + else + val fstParams = params.head + val wrappedLambda = params.tail.foldRight(body)((params, acc) => Return(Value.Lam(params, acc), false)) + bLam(Value.Lam(fstParams, wrappedLambda), S(sym.nme), S(sym))(k)(using ctx) + private def bFunDef(e: FunDefn)(using ctx: Ctx)(using Raise, Scope): Func = trace[Func](s"bFunDef begin: ${e.sym}", x => s"bFunDef end: ${x.show}"): val FunDefn(_own, sym, params, body) = e - if !ctx.is_top_level then - errStop(msg"Non top-level definition ${sym.nme} not supported") - else if params.length != 1 then - errStop(msg"Curried function or zero arguments function not supported: ${params.length.toString}") + assert(ctx.isTopLevel) + if params.length === 0 then + bErrStop(msg"Function without arguments not supported: ${params.length.toString}") else - val paramsList = params.head.params.map(x => x -> summon[Scope].allocateName(x.sym)) - val ctx2 = paramsList.foldLeft(ctx)((acc, x) => acc.addName(getVar_!(x._1.sym), x._2 |> nme)) - val ctx3 = ctx2.nonTopLevel - val pl = paramsList.map(_._2).map(nme) + val paramsList = params.head.params + val ctx2 = paramsList.foldLeft(ctx)((acc, x) => acc.addName(x.sym, x.sym)).nonTopLevel + val pl = paramsList.map(_.sym) + val wrappedLambda = params.tail.foldRight(body)((params, acc) => Return(Value.Lam(params, acc), false)) Func( - fnUid.make, - sym.nme, - params = pl, - resultNum = 1, - body = bBlock(body)(x => Node.Result(Ls(x)))(using ctx3) + uid.make, sym, params = pl, resultNum = 1, + body = bBlockWithEndCont(wrappedLambda)(x => Node.Result(Ls(x)))(using ctx2) ) - + + private def bMethodDef(e: FunDefn)(using ctx: Ctx)(using Raise, Scope): Func = + trace[Func](s"bFunDef begin: ${e.sym}", x => s"bFunDef end: ${x.show}"): + val FunDefn(_own, sym, params, body) = e + if !ctx.isTopLevel then + bErrStop(msg"Non top-level definition ${sym.nme} not supported") + else if params.length === 0 then + bErrStop(msg"Function without arguments not supported: ${params.length.toString}") + else + val paramsList = params.head.params + val ctx2 = paramsList.foldLeft(ctx)((acc, x) => acc.addName(x.sym, x.sym)).nonTopLevel + val pl = paramsList.map(_.sym) + val wrappedLambda = params.tail.foldRight(body)((params, acc) => Return(Value.Lam(params, acc), false)) + Func( + uid.make, sym, params = pl, resultNum = 1, + body = bBlockWithEndCont(wrappedLambda)(x => Node.Result(Ls(x)))(using ctx2) + ) + private def bClsLikeDef(e: ClsLikeDefn)(using ctx: Ctx)(using Raise, Scope): ClassInfo = trace[ClassInfo](s"bClsLikeDef begin", x => s"bClsLikeDef end: ${x.show}"): val ClsLikeDefn( - _own, _isym, sym, kind, paramsOpt, auxParams, parentSym, methods, privateFields, publicFields, preCtor, ctor) = e - if !ctx.is_top_level then - errStop(msg"Non top-level definition ${sym.nme} not supported") - else if !auxParams.isEmpty then - errStop(msg"The class ${sym.nme} has auxiliary parameters, which are not yet supported") + _own, isym, _sym, kind, paramsOpt, auxParams, parentSym, methods, privateFields, publicFields, preCtor, ctor) = e + if !ctx.isTopLevel then + bErrStop(msg"Non top-level definition ${isym.toString()} not supported") else - val clsDefn = sym.defn.getOrElse(die) val clsParams = paramsOpt.fold(Nil)(_.paramSyms) - val clsFields = publicFields + given Ctx = ctx.setClass(isym) + val funcs = methods.map(bMethodDef) + def parentFromPath(p: Path): Set[Local] = p match + case Value.Ref(l) => Set(fromMemToClass(l)) + case Select(Value.Ref(l), Tree.Ident("class")) => Set(fromMemToClass(l)) + case _ => bErrStop(msg"Unsupported parent path ${p.toString()}") ClassInfo( - clsUid.make, - sym.nme, - clsParams.map(_.nme) ++ clsFields.map(_.nme), + uid.make, + isym, + clsParams, + parentSym.fold(Set.empty)(parentFromPath), + funcs.map(f => f.name -> f).toMap, ) + private def bLam(lam: Value.Lam, nameHint: Opt[Str], recName: Opt[Local])(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + trace[Node](s"bLam begin", x => s"bLam end: ${x.show}"): + val Value.Lam(params, body) = lam + // Generate an auxiliary class inheriting from Callable + val freeVars = lam.freeVarsLLIR -- body.definedVars -- recName.iterator -- ctx.fn_ctx.keySet + log(s"Defined vars: ${body.definedVars}") + log(s"Match free vars: ${lam.freeVarsLLIR -- body.definedVars} ${ctx.fn_ctx.keySet} ${params.params.map(p => p.sym)}") + log(s"Lot: $lam") + val name = newClassSym(s"Lambda${nameHint.fold("")(x => "_" + x)}") + val freeVarsList = freeVars.toList + val args = freeVarsList.map(symMap) + // args may have the same name (with different uid) + // it's not allowed when generating the names of fields in the backend + val clsParams = args.zipWithIndex.map: + case (arg, i) => newVarSym(s"lam_arg$i") + val applyParams = params.params + // add the parameters of lambda expression to the context + val ctx2 = applyParams.foldLeft(ctx)((acc, x) => acc.addName(x.sym, x.sym)) + // add the free variables (class parameters) to the context + val ctx3 = clsParams.iterator.zip(freeVarsList).foldLeft(ctx2): + case (acc, (param, arg)) => acc.addName(arg, param) + val ctx4 = recName.fold(ctx3)(x => ctx3.addName(x, builtinThis)).nonTopLevel + val pl = applyParams.map(_.sym) + val method = Func( + uid.make, + builtinApply(params.params.length), + params = pl, + resultNum = 1, + body = bBlockWithEndCont(body)(x => Node.Result(Ls(x)))(using ctx4) + ) + ctx.class_acc += ClassInfo( + uid.make, + name, + clsParams, + Set(builtinCallable), + Map(method.name -> method), + ) + val v: Local = newTemp + val new_ctx = recName.fold(ctx)(x => ctx.addName(x, v)) + Node.LetExpr(v, Expr.CtorApp(name, args.map(sr)), k(v |> sr)(using new_ctx)) + private def bValue(v: Value)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = trace[Node](s"bValue { $v } begin", x => s"bValue end: ${x.show}"): v match + case Value.Ref(l: TermSymbol) if l.owner.nonEmpty => + k(l |> sr) case Value.Ref(sym) if sym.nme.isCapitalized => - val v = fresh.make - Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(sym.nme), Ls()), k(v |> sr)) - case Value.Ref(l) => k(ctx.findName(getVar_!(l)) |> sr) - case Value.This(sym) => errStop(msg"Unsupported value: This"); Node.Result(Ls()) + val v: Local = newTemp + Node.LetExpr(v, Expr.CtorApp(fromMemToClass(sym), Ls()), k(v |> sr)) + case Value.Ref(l) => + ctx.fn_ctx.get(l) match + case Some(f) => + val tempSymbols = (0 until f.paramsSize).map(x => newNamed("arg")) + val paramsList = PlainParamList( + (0 until f.paramsSize).zip(tempSymbols).map((_n, sym) => + Param(FldFlags.empty, sym, N)).toList) + val app = Call(v, tempSymbols.map(x => Arg(false, Value.Ref(x))).toList)(true, false) + bLam(Value.Lam(paramsList, Return(app, false)), S(l.nme), N)(k) + case None => + k(ctx.findName(l) |> sr) + case Value.This(sym) => bErrStop(msg"Unsupported value: This") case Value.Lit(lit) => k(Expr.Literal(lit)) - case Value.Lam(params, body) => errStop(msg"Unsupported value: Lam"); Node.Result(Ls()) - case Value.Arr(elems) => errStop(msg"Unsupported value: Arr"); Node.Result(Ls()) + case lam @ Value.Lam(params, body) => bLam(lam, N, N)(k) + case Value.Arr(elems) => + bArgs(elems): + case args: Ls[TrivialExpr] => + val v: Local = newTemp + Node.LetExpr(v, Expr.CtorApp(builtinTuple(elems.length), args), k(v |> sr)) + case Value.Rcd(fields) => bErrStop(msg"Unsupported value: Rcd") + - private def getClassOfMem(p: FieldSymbol)(using ctx: Ctx)(using Raise, Scope): Local = - trace[Local](s"bMemSym { $p } begin", x => s"bMemSym end: $x"): + private def getClassOfField(p: FieldSymbol)(using ctx: Ctx)(using Raise, Scope): Local = + trace[Local](s"bClassOfField { $p } begin", x => s"bClassOfField end: $x"): p match case ts: TermSymbol => ts.owner.get - case ms: MemberSymbol[?] => ms.defn.get.sym + case ms: MemberSymbol[?] => + ms.defn match + case Some(d: ClassLikeDef) => d.owner.get + case Some(d: TermDefinition) => d.owner.get + case Some(value) => bErrStop(msg"Member symbol without class definition ${value.toString}") + case None => bErrStop(msg"Member symbol without definition ${ms.toString}") + + private def fromMemToClass(m: Symbol)(using ctx: Ctx)(using Raise, Scope): MemberSymbol[? <: ClassLikeDef] = + trace[MemberSymbol[? <: ClassLikeDef]](s"bFromMemToClass $m", x => s"bFromMemToClass end: $x"): + m match + case ms: MemberSymbol[?] => + ms.defn match + case Some(d: ClassLikeDef) => d.sym.asClsLike.getOrElse(bErrStop(msg"Class definition without symbol")) + case Some(value) => bErrStop(msg"Member symbol without class definition ${value.toString}") + case None => bErrStop(msg"Member symbol without definition ${ms.toString}") + case _ => bErrStop(msg"Unsupported symbol kind ${m.toString}") + private def bPath(p: Path)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = trace[Node](s"bPath { $p } begin", x => s"bPath end: ${x.show}"): p match - case s @ Select(Value.Ref(sym), Tree.Ident("Unit")) if sym is ctx.runtimeSymbol => + case s @ Select(Value.Ref(sym), Tree.Ident("Unit")) if sym is ctx.builtinSym.runtimeSym.get => bPath(Value.Lit(Tree.UnitLit(false)))(k) + case s @ Select(Value.Ref(cls: ClassSymbol), name) if ctx.method_class.contains(cls) => + s.symbol match + case None => + ctx.flow_ctx.get(p) match + case Some(cls) => + k(cls |> sr) + case None => + bErrStop(msg"Unsupported selection by users") + case Some(s) => + k(s |> sr) + case s @ DynSelect(qual, fld, arrayIdx) => + bErrStop(msg"Unsupported dynamic selection") case s @ Select(qual, name) => log(s"bPath Select: $qual.$name with ${s.symbol}") s.symbol match case None => - errStop(msg"Unsupported selection by users") + ctx.flow_ctx.get(qual) match + case Some(cls) => + bPath(qual): + case q: Expr.Ref => + val v: Local = newTemp + val field = name.name + Node.LetExpr(v, Expr.Select(q.sym, cls, field), k(v |> sr)) + case q: Expr.Literal => + bErrStop(msg"Unsupported select on literal") + case None => + log(s"${ctx.flow_ctx}") + bErrStop(msg"Unsupported selection by users") case Some(value) => bPath(qual): case q: Expr.Ref => - val v = fresh.make - val cls = ClassRef.fromName(getClassOfMem(s.symbol.get).nme) + val v: Local = newTemp + val cls = getClassOfField(s.symbol.get) val field = name.name - Node.LetExpr(v, Expr.Select(q.name, cls, field), k(v |> sr)) + Node.LetExpr(v, Expr.Select(q.sym, cls, field), k(v |> sr)) case q: Expr.Literal => - errStop(msg"Unsupported select on literal") - Node.Result(Ls()) + bErrStop(msg"Unsupported select on literal") case x: Value => bValue(x)(k) private def bResult(r: Result)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = @@ -202,72 +375,106 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: case Call(Value.Ref(sym: BuiltinSymbol), args) => bArgs(args): case args: Ls[TrivialExpr] => - val v = fresh.make - Node.LetExpr(v, Expr.BasicOp(sym.nme, args), k(v |> sr)) - case Call(Value.Ref(sym), args) if sym.nme.head.isUpper => + val v: Local = newTemp + Node.LetExpr(v, Expr.BasicOp(sym, args), k(v |> sr)) + case Call(Value.Ref(sym: MemberSymbol[?]), args) if sym.defn.exists(defn => defn match + case cls: ClassLikeDef => true + case _ => false + ) => + log(s"xxx $sym is ${sym.getClass()}") + bArgs(args): + case args: Ls[TrivialExpr] => + val v: Local = newTemp + Node.LetExpr(v, Expr.CtorApp(fromMemToClass(sym), args), k(v |> sr)) + case Call(s @ Value.Ref(sym), args) => + val v: Local = newTemp + ctx.fn_ctx.get(sym) match + case Some(f) => + bArgs(args): + case args: Ls[TrivialExpr] => + Node.LetCall(Ls(v), sym, args, k(v |> sr)) + case None => + bPath(s): + case f: TrivialExpr => + bArgs(args): + case args: Ls[TrivialExpr] => + Node.LetMethodCall(Ls(v), builtinCallable, builtinApply(args.length), f :: args, k(v |> sr)) + case Call(Select(Value.Ref(_: TopLevelSymbol), Tree.Ident("builtin")), args) => + bArgs(args): + case args: Ls[TrivialExpr] => + val v: Local = newTemp + Node.LetCall(Ls(v), builtin, args, k(v |> sr)) + case Call(Select(Select(Value.Ref(_: TopLevelSymbol), Tree.Ident("console")), Tree.Ident("log")), args) => bArgs(args): case args: Ls[TrivialExpr] => - val v = fresh.make - Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(sym.nme), args), k(v |> sr)) - case Call(Value.Ref(sym), args) => + val v: Local = newTemp + Node.LetCall(Ls(v), builtin, Expr.Literal(Tree.StrLit("println")) :: args, k(v |> sr)) + case Call(Select(Select(Value.Ref(_: TopLevelSymbol), Tree.Ident("Math")), Tree.Ident(mathPrimitive)), args) => bArgs(args): case args: Ls[TrivialExpr] => - val v = fresh.make - Node.LetCall(Ls(v), FuncRef.fromName(sym.nme), args, k(v |> sr)) - case Call(fn, args) => - bPath(fn): - case f: TrivialExpr => + val v: Local = newTemp + Node.LetCall(Ls(v), builtin, Expr.Literal(Tree.StrLit(mathPrimitive)) :: args, k(v |> sr)) + case Call(s @ Select(r @ Value.Ref(sym), Tree.Ident(fld)), args) if s.symbol.isDefined => + bPath(r): + case r => bArgs(args): case args: Ls[TrivialExpr] => - val v = fresh.make - Node.LetMethodCall(Ls(v), ClassRef(R("Callable")), Name("apply" + args.length), f :: args, k(v |> sr)) + val v: Local = newTemp + log(s"Method Call Select: $r.$fld with ${s.symbol}") + Node.LetMethodCall(Ls(v), getClassOfField(s.symbol.get), s.symbol.get, r :: args, k(v |> sr)) + case Call(_, _) => bErrStop(msg"Unsupported kind of Call ${r.toString()}") case Instantiate( Select(Value.Ref(sym), Tree.Ident("class")), args) => bPaths(args): case args: Ls[TrivialExpr] => - val v = fresh.make - Node.LetExpr(v, Expr.CtorApp(ClassRef.fromName(sym.nme), args), k(v |> sr)) + val v: Local = newTemp + Node.LetExpr(v, Expr.CtorApp(fromMemToClass(sym), args), k(v |> sr)) case Instantiate(cls, args) => - errStop(msg"Unsupported kind of Instantiate") - Node.Result(Ls()) + bErrStop(msg"Unsupported kind of Instantiate") case x: Path => bPath(x)(k) - private def bBlock(blk: Block)(k: TrivialExpr => Ctx ?=> Node)(using ctx: Ctx)(using Raise, Scope) : Node = + private def bBlockWithEndCont(blk: Block)(k: TrivialExpr => Ctx ?=> Node)(using Ctx)(using Raise, Scope) : Node = + bBlock(blk)(k)(End("")) + + private def bBlock(blk: Block)(k: TrivialExpr => Ctx ?=> Node)(ct: Block)(using ctx: Ctx)(using Raise, Scope) : Node = trace[Node](s"bBlock begin", x => s"bBlock end: ${x.show}"): blk match case Match(scrut, arms, dflt, rest) => bPath(scrut): case e: TrivialExpr => - val jp = fresh.make("j") - val fvset = (rest.freeVarsLLIR -- rest.definedVars).map(allocIfNew) + val nextCont = Begin(rest, ct) + val jp: BlockMemberSymbol = newNamedBlockMem("j") + val fvset = nextCont.freeVarsLLIR -- nextCont.definedVars -- ctx.fn_ctx.keySet val fvs1 = fvset.toList - val new_ctx = fvs1.foldLeft(ctx)((acc, x) => acc.addName(x, fresh.make)) + log(s"Match free vars: $fvset ${nextCont.freeVarsLLIR -- nextCont.definedVars} $fvs1") + val new_ctx = fvs1.foldLeft(ctx)((acc, x) => acc.addName(x, x)) val fvs = fvs1.map(new_ctx.findName(_)) def cont(x: TrivialExpr)(using ctx: Ctx) = Node.Jump( - FuncRef.fromName(jp), + jp, fvs1.map(x => ctx.findName(x)).map(sr) ) val casesList: Ls[(Pat, Node)] = arms.map: case (Case.Lit(lit), body) => - (Pat.Lit(lit), bBlock(body)(cont)(using ctx)) + (Pat.Lit(lit), bBlock(body)(cont)(nextCont)(using ctx)) case (Case.Cls(cls, _), body) => - (Pat.Class(ClassRef.fromName(cls.nme)), bBlock(body)(cont)(using ctx)) + (Pat.Class(cls), bBlock(body)(cont)(nextCont)(using ctx)) case (Case.Tup(len, inf), body) => - (Pat.Class(ClassRef.fromName("Tuple" + len.toString())), bBlock(body)(cont)(using ctx)) - val defaultCase = dflt.map(bBlock(_)(cont)) + val ctx2 = ctx.addKnownClass(scrut, builtinTuple(len)) + (Pat.Class(builtinTuple(len)), bBlock(body)(cont)(nextCont)(using ctx2)) + val defaultCase = dflt.map(bBlock(_)(cont)(nextCont)(using ctx)) val jpdef = Func( - fnUid.make, - jp.str, + uid.make, + jp, params = fvs, resultNum = 1, - bBlock(rest)(k)(using new_ctx), + bBlock(rest)(k)(ct)(using new_ctx), ) summon[Ctx].def_acc += jpdef Node.Case(e, casesList, defaultCase) case Return(res, implct) => bResult(res)(x => Node.Result(Ls(x))) - case Throw(Instantiate(Select(Value.Ref(_), ident), Ls(Value.Lit(Tree.StrLit(e))))) if ident.name == "Error" => + case Throw(Instantiate(Select(Value.Ref(_), ident), Ls(Value.Lit(Tree.StrLit(e))))) if ident.name === "Error" => Node.Panic(e) - case Label(label, body, rest) => ??? + case Label(label, body, rest) => TODO("Label not supported") case Break(label) => TODO("Break not supported") case Continue(label) => TODO("Continue not supported") case Begin(sub, rest) => @@ -276,55 +483,112 @@ final class LlirBuilder(tl: TraceLogger)(fresh: Fresh, fnUid: FreshInt, clsUid: case _: BlockTail => val definedVars = sub.definedVars definedVars.foreach(allocIfNew) - bBlock(sub): - x => bBlock(rest)(k) + bBlock(sub)(x => bBlock(rest)(k)(ct))(Begin(rest, ct)) case Assign(lhs, rhs, rest2) => - bBlock(Assign(lhs, rhs, Begin(rest2, rest)))(k) + bBlock(Assign(lhs, rhs, Begin(rest2, rest)))(k)(ct) case Begin(sub, rest2) => - bBlock(Begin(sub, Begin(rest2, rest)))(k) + bBlock(Begin(sub, Begin(rest2, rest)))(k)(ct) case Define(defn, rest2) => - bBlock(Define(defn, Begin(rest2, rest)))(k) + bBlock(Define(defn, Begin(rest2, rest)))(k)(ct) case Match(scrut, arms, dflt, rest2) => - bBlock(Match(scrut, arms, dflt, Begin(rest2, rest)))(k) + bBlock(Match(scrut, arms, dflt, Begin(rest2, rest)))(k)(ct) case _ => TODO(s"Other non-tail sub components of Begin not supported $sub") case TryBlock(sub, finallyDo, rest) => TODO("TryBlock not supported") case Assign(lhs, rhs, rest) => - val name = allocIfNew(lhs) - bBind(S(name), rhs, rest)(k) + bBind(S(lhs), rhs, rest)(k)(ct) case AssignField(lhs, nme, rhs, rest) => TODO("AssignField not supported") case Define(fd @ FunDefn(_own, sym, params, body), rest) => - val f = bFunDef(fd) - ctx.def_acc += f - val new_ctx = ctx.addFuncName(sym, Name(f.name)) - bBlock(rest)(k)(using new_ctx) - case Define(_: ClsLikeDefn, rest) => bBlock(rest)(k) + if ctx.isTopLevel then + val f = bFunDef(fd) + ctx.def_acc += f + bBlock(rest)(k)(ct) + else + bNestedFunDef(fd): + case r: TrivialExpr => + bBlock(rest)(k)(ct) + case Define(_: ClsLikeDefn, rest) => bBlock(rest)(k)(ct) case End(msg) => k(Expr.Literal(Tree.UnitLit(false))) case _: Block => val docBlock = blk.showAsTree - errStop(msg"Unsupported block: $docBlock") - Node.Result(Ls()) + bErrStop(msg"Unsupported block: $docBlock") def registerClasses(b: Block)(using ctx: Ctx)(using Raise, Scope): Ctx = b match case Define(cd @ ClsLikeDefn(_own, isym, sym, kind, _paramsOpt, auxParams, parentSym, methods, privateFields, publicFields, preCtor, ctor), rest) => if !auxParams.isEmpty then - errStop(msg"The class ${sym.nme} has auxiliary parameters, which are not yet supported") + bErrStop(msg"The class ${sym.nme} has auxiliary parameters, which are not yet supported") val c = bClsLikeDef(cd) ctx.class_acc += c - val new_ctx = ctx.addClassName(sym, Name(c.name)).addClassName(isym, Name(c.name)) - log(s"Define class: ${sym.nme} -> ${new_ctx}") + val new_ctx = ctx.addClassInfo(isym, sym, c) + log(s"Define class: ${isym.toString()} -> ${ctx}") registerClasses(rest)(using new_ctx) case _ => b.subBlocks.foldLeft(ctx)((ctx, rest) => registerClasses(rest)(using ctx)) + + def registerBuiltinClasses(using ctx: Ctx)(using Raise, Scope): Ctx = + ctx.builtinSym.tupleSym.foldLeft(ctx): + case (ctx, (len, sym)) => + val c = ClassInfo(uid.make, sym, (0 until len).map(x => builtinField(x)).toList, Set.empty, Map.empty) + ctx.class_acc += c + ctx.addClassInfo(sym, BlockMemberSymbol(sym.nme, Nil), c) + + def registerFunctions(b: Block)(using ctx: Ctx)(using Raise, Scope): Ctx = + var ctx2 = ctx + new BlockTraverser: + applyBlock(b) + + override def applyBlock(b: Block): Unit = b match + case Match(scrut, arms, dflt, rest) => applyBlock(rest) + case Return(res, implct) => + case Throw(exc) => + case Label(label, body, rest) => applyBlock(rest) + case Break(label) => + case Continue(label) => + case Begin(sub, rest) => applyBlock(rest) + case TryBlock(sub, finallyDo, rest) => applyBlock(rest) + case Assign(lhs, rhs, rest) => applyBlock(rest) + case AssignField(_, _, _, rest) => applyBlock(rest) + case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => applyBlock(rest) + case Define(defn, rest) => applyDefn(defn); applyBlock(rest) + case HandleBlock(lhs, res, par, args, cls, handlers, body, rest) => applyBlock(rest) + case End(msg) => + + override def applyDefn(defn: Defn): Unit = defn match + case f: FunDefn => applyFunDefn(f) + case _ => () + + override def applyFunDefn(fun: FunDefn): Unit = + val FunDefn(_own, sym, params, body) = fun + if params.length === 0 then + bErrStop(msg"Function without arguments not supported: ${params.length.toString}") + ctx2 = ctx2.addFuncName(sym, params.head.params.length) + log(s"Define function: ${sym.nme} -> ${ctx2}") + ctx2 - def bProg(e: Program)(using Raise, Scope, Elaborator.State): LlirProgram = - var ctx = Ctx.empty + def bProg(e: Program)(using Raise, Scope, Ctx): (LlirProgram, Ctx) = + var ctx = summon[Ctx] // * Classes may be defined after other things such as functions, // * especially now that the elaborator moves all functions to the top of the block. ctx = registerClasses(e.main)(using ctx) + ctx = registerFunctions(e.main)(using ctx) + + log(s"Classes: ${ctx.class_ctx}") + + val entryBody = bBlockWithEndCont(e.main)(x => Node.Result(Ls(x)))(using ctx) + val entryFunc = Func( + uid.make, newFunSym("entry"), params = Ls.empty, resultNum = 1, + body = entryBody + ) + ctx.def_acc += entryFunc + + ctx = registerBuiltinClasses(using ctx) + + val prog = LlirProgram(ctx.class_acc.toSet, ctx.def_acc.toSet, entryFunc.name) + + ctx.class_acc.clear() + ctx.def_acc.clear() - val entry = bBlock(e.main)(x => Node.Result(Ls(x)))(using ctx) - LlirProgram(ctx.class_acc.toSet, ctx.def_acc.toSet, entry) + (prog, ctx) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala index 0c5688eab6..ea12ef9a8b 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Fresh.scala @@ -1,23 +1,5 @@ package hkmc2.codegen.llir -import collection.mutable.{HashMap => MutHMap} -import mlscript.utils.shorthands._ - -final class Fresh(div : Char = '$'): - private val counter = MutHMap[Str, Int]() - private def gensym(s: Str) = { - val n = s.lastIndexOf(div) - val (ts, suffix) = s.splitAt(if n == -1 then s.length() else n) - var x = if suffix.stripPrefix(div.toString).forall(_.isDigit) then ts else s - val count = counter.getOrElse(x, 0) - val tmp = s"$x$div$count" - counter.update(x, count + 1) - Name(tmp) - } - - def make(s: Str) = gensym(s) - def make = gensym("x") - final class FreshInt: private var counter = 0 def make: Int = { diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala index 97af2f0bd5..33380a55dd 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Interp.scala @@ -1,4 +1,6 @@ -package hkmc2.codegen.llir +package hkmc2 +package codegen +package llir import mlscript.* import mlscript.utils.* @@ -7,8 +9,9 @@ import scala.collection.mutable.ListBuffer import shorthands.* import scala.util.boundary, boundary.break -import hkmc2.codegen.llir.* -import hkmc2.syntax.Tree +import syntax.Tree +import hkmc2.utils.TraceLogger +import semantics.BuiltinSymbol enum Stuck: case StuckExpr(expr: Expr, msg: Str) @@ -21,8 +24,8 @@ enum Stuck: final case class InterpreterError(message: String) extends Exception(message) -class Interpreter(verbose: Bool): - private def log(x: Any) = if verbose then println(x) +class Interpreter(tl: TraceLogger): + import tl.{trace, log, logs} import Stuck._ private case class Configuration( @@ -38,7 +41,7 @@ class Interpreter(verbose: Bool): override def toString: String = import hkmc2.syntax.Tree.* this match - case Class(cls, fields) => s"${cls.name}(${fields.mkString(",")})" + case Class(cls, fields) => s"${cls.symbol.nme}(${fields.mkString(",")})" case Literal(IntLit(lit)) => lit.toString case Literal(BoolLit(lit)) => lit.toString case Literal(DecLit(lit)) => lit.toString @@ -47,9 +50,10 @@ class Interpreter(verbose: Bool): if isNullNotUndefined then "null" else "undefined" private final case class Ctx( - bindingCtx: Map[Str, Value], - classCtx: Map[Str, ClassInfo], - funcCtx: Map[Str, Func], + bindingCtx: Map[Local, Value], + classCtx: Map[Local, ClassInfo], + funcCtx: Map[Local, Func], + thisVal: Opt[Value], ) import Node._ @@ -66,7 +70,10 @@ class Interpreter(verbose: Bool): case ("-", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x - y))) case ("*", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x * y))) case ("/", Li(IntLit(x)), Li(IntLit(y))) => S(Li(IntLit(x / y))) - case ("==", Li(IntLit(x)), Li(IntLit(y))) => S(if x == y then getTrue else getFalse) + case ("&&", Li(BoolLit(x)), Li(BoolLit(y))) => S(if x && y then getTrue else getFalse) + case ("||", Li(BoolLit(x)), Li(BoolLit(y))) => S(if x || y then getTrue else getFalse) + case ("==", Li(IntLit(x)), Li(IntLit(y))) => S(if x === y then getTrue else getFalse) + case ("===", Li(IntLit(x)), Li(IntLit(y))) => S(if x === y then getTrue else getFalse) case ("!=", Li(IntLit(x)), Li(IntLit(y))) => S(if x != y then getTrue else getFalse) case ("<=", Li(IntLit(x)), Li(IntLit(y))) => S(if x <= y then getTrue else getFalse) case (">=", Li(IntLit(x)), Li(IntLit(y))) => S(if x >= y then getTrue else getFalse) @@ -79,29 +86,41 @@ class Interpreter(verbose: Bool): var stuck: Opt[Stuck] = None exprs foreach { expr => stuck match - case None => eval(expr) match + case None => eval_t(expr) match case L(x) => stuck = Some(x) case R(x) => values += x case _ => () } stuck.toLeft(values.toList) - private def eval(expr: TrivialExpr)(using ctx: Ctx): Result[Value] = expr match - case e @ Ref(name) => ctx.bindingCtx.get(name.str).toRight(StuckExpr(e, s"undefined variable $name")) - case Literal(lit) => R(Value.Literal(lit)) + private def eval_t(expr: TrivialExpr)(using ctx: Ctx): Result[Value] = expr match + case Ref(x: BuiltinSymbol) => x.nme match + case "" => ctx.thisVal.toRight(StuckExpr(expr.toExpr, s"undefined this value")) + case _ => L(StuckExpr(expr.toExpr, s"undefined builtin ${x.nme}")) + case Ref(x) => ctx.bindingCtx.get(x).toRight(StuckExpr(expr.toExpr, s"undefined variable $x")) + case Literal(x) => R(Value.Literal(x)) private def eval(expr: Expr)(using ctx: Ctx): Result[Value] = expr match - case Ref(Name(x)) => ctx.bindingCtx.get(x).toRight(StuckExpr(expr, s"undefined variable $x")) - case Literal(x) => R(Value.Literal(x)) + case x: TrivialExpr => eval_t(x) case CtorApp(cls, args) => for xs <- evalArgs(args) - cls <- ctx.classCtx.get(cls.name).toRight(StuckExpr(expr, s"undefined class ${cls.name}")) + cls <- ctx.classCtx.get(cls).toRight(StuckExpr(expr, s"undefined class ${cls.nme}")) yield Value.Class(cls, xs) + case Select(name, cls, field) if field.forall(_.isDigit) => + val nth = field.toInt + ctx.bindingCtx.get(name).toRight(StuckExpr(expr, s"undefined variable $name")).flatMap { + case Value.Class(cls2, xs) if cls === cls2.symbol => + xs.lift(nth) match + case Some(x) => R(x) + case None => L(StuckExpr(expr, s"unable to find selected field $field")) + case Value.Class(cls2, xs) => L(StuckExpr(expr, s"unexpected class $cls2")) + case x => L(StuckExpr(expr, s"unexpected value $x")) + } case Select(name, cls, field) => - ctx.bindingCtx.get(name.str).toRight(StuckExpr(expr, s"undefined variable $name")).flatMap { - case Value.Class(cls2, xs) if cls.name == cls2.name => - xs.zip(cls2.fields).find{_._2 == field} match + ctx.bindingCtx.get(name).toRight(StuckExpr(expr, s"undefined variable $name")).flatMap { + case Value.Class(cls2, xs) if cls === cls2.symbol => + xs.zip(cls2.fields).find{_._2.nme === field} match case Some((x, _)) => R(x) case None => L(StuckExpr(expr, s"unable to find selected field $field")) case Value.Class(cls2, xs) => L(StuckExpr(expr, s"unexpected class $cls2")) @@ -110,21 +129,21 @@ class Interpreter(verbose: Bool): case BasicOp(name, args) => boundary: evalArgs(args).flatMap( xs => - name match - case "+" | "-" | "*" | "/" | "==" | "!=" | "<=" | ">=" | "<" | ">" => + name.nme match + case "+" | "-" | "*" | "/" | "==" | "===" | "!=" | "<=" | ">=" | "<" | ">" => if xs.length < 2 then break: L(StuckExpr(expr, s"not enough arguments for basic operation $name")) - else eval(name, xs.head, xs.tail.head).toRight(StuckExpr(expr, s"unable to evaluate basic operation")) + else eval(name.nme, xs.head, xs.tail.head).toRight(StuckExpr(expr, s"unable to evaluate basic operation")) case _ => L(StuckExpr(expr, s"unexpected basic operation $name"))) case AssignField(assignee, cls, field, value) => for - x <- eval(Ref(assignee): TrivialExpr) - y <- eval(value) + x <- eval_t(Ref(assignee): TrivialExpr) + y <- eval_t(value) res <- x match - case obj @ Value.Class(cls2, xs) if cls.name == cls2.name => - xs.zip(cls2.fields).find{_._2 == field} match + case obj @ Value.Class(cls2, xs) if cls === cls2 => + xs.zip(cls2.fields).find{_._2.nme === field} match case Some((_, _)) => - obj.fields = xs.map(x => if x == obj then y else x) + obj.fields = xs.map(x => if x === obj then y else x) // Ideally, we should return a unit value here, but here we return the assignee value for simplicity. R(obj) case None => L(StuckExpr(expr, s"unable to find selected field $field")) @@ -137,26 +156,26 @@ class Interpreter(verbose: Bool): case Jump(func, args) => for xs <- evalArgs(args) - func <- ctx.funcCtx.get(func.name).toRight(StuckNode(node, s"undefined function ${func.name}")) - ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ func.params.map{_.str}.zip(xs)) + func <- ctx.funcCtx.get(func).toRight(StuckNode(node, s"undefined function ${func.nme}")) + ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ func.params.zip(xs)) res <- eval(func.body)(using ctx1) yield res case Case(scrut, cases, default) => - eval(scrut) flatMap { + eval_t(scrut) flatMap { case Value.Class(cls, fields) => cases.find { - case (Pat.Class(cls2), _) => cls.name == cls2.name + case (Pat.Class(cls2), _) => cls.symbol === cls2 case _ => false } match { case Some((_, x)) => eval(x) case None => default match case S(x) => eval(x) - case N => L(StuckNode(node, s"can not find the matched case, ${cls.name} expected")) + case N => L(StuckNode(node, s"can not find the matched case, ${cls.symbol} expected")) } case Value.Literal(lit) => cases.find { - case (Pat.Lit(lit2), _) => lit == lit2 + case (Pat.Lit(lit2), _) => lit === lit2 case _ => false } match { case Some((_, x)) => eval(x) @@ -169,41 +188,53 @@ class Interpreter(verbose: Bool): case LetExpr(name, expr, body) => for x <- eval(expr) - ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx + (name.str -> x)) + ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx + (name -> x)) res <- eval(body)(using ctx1) yield res case LetMethodCall(names, cls, method, args, body) => + def lookup_method(cls: ClassInfo, method: Local): Option[Func] = + // The methods with the same name in a subclass will override the method in the superclass. + // But they have different symbols for the method definition. + // So, we don't directly use the method symbol to find the method. + // Instead, we fallback to use the method name. + cls.methods.find(_._1.nme === method.nme).map(_._2) for ys <- evalArgs(args).flatMap { - case Value.Class(cls2, xs) :: args => - cls2.methods.get(method.str).toRight(StuckNode(node, s"undefined method ${method.str}")).flatMap { method => - val ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ cls2.fields.zip(xs) ++ method.params.map{_.str}.zip(args)) + case (ths @ Value.Class(cls2, xs)) :: args => + lookup_method(cls2, method).toRight(StuckNode(node, s"undefined method ${method.nme}")).flatMap { method => + val ctx1 = ctx.copy( + bindingCtx = ctx.bindingCtx ++ cls2.fields.zip(xs) ++ method.params.zip(args), + thisVal = S(ths) + ) eval(method.body)(using ctx1) } case _ => L(StuckNode(node, s"not enough arguments for method call, or the first argument is not a class")) } - ctx2 = ctx.copy(bindingCtx = ctx.bindingCtx ++ names.map{_.str}.zip(ys)) + ctx2 = ctx.copy(bindingCtx = ctx.bindingCtx ++ names.zip(ys)) res <- eval(body)(using ctx2) yield res case LetCall(names, func, args, body) => for xs <- evalArgs(args) - func <- ctx.funcCtx.get(func.name).toRight(StuckNode(node, s"undefined function ${func.name}")) - ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ func.params.map{_.str}.zip(xs)) + func <- ctx.funcCtx.get(func).toRight(StuckNode(node, s"undefined function ${func.nme}")) + ctx1 = ctx.copy(bindingCtx = ctx.bindingCtx ++ func.params.zip(xs)) ys <- eval(func.body)(using ctx1) - ctx2 = ctx.copy(bindingCtx = ctx.bindingCtx ++ names.map{_.str}.zip(ys)) + ctx2 = ctx.copy(bindingCtx = ctx.bindingCtx ++ names.zip(ys)) res <- eval(body)(using ctx2) yield res case Panic(msg) => L(StuckNode(node, msg)) private def f(prog: Program): Ls[Value] = - val Program(classes, defs, main) = prog + val Program(classes, defs, entry) = prog given Ctx = Ctx( bindingCtx = Map.empty, - classCtx = classes.map{cls => cls.name -> cls}.toMap, - funcCtx = defs.map{func => func.name -> func}.toMap, + classCtx = classes.map(cls => (cls.symbol, cls)).toMap, + funcCtx = defs.map(func => (func.name, func)).toMap, + thisVal = None, ) - eval(main) match + val entryFunc = summon[Ctx].funcCtx.getOrElse(entry, throw InterpreterError("Entry doesn't exist")) + assert(entryFunc.params.isEmpty, "Entry function should not have parameters") + eval(entryFunc.body) match case R(x) => x case L(x) => throw InterpreterError("Stuck evaluation: " + x.toString) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala index c749014a1b..bfe1aff914 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Llir.scala @@ -1,44 +1,54 @@ -package hkmc2.codegen.llir +package hkmc2 +package codegen.llir import mlscript._ import mlscript.utils._ import mlscript.utils.shorthands._ -import hkmc2.syntax._ -import hkmc2.Message.MessageContext -import hkmc2.document._ +import syntax._ +import Message.MessageContext +import document._ +import codegen._ import util.Sorting import collection.immutable.SortedSet import language.implicitConversions import collection.mutable.{Map as MutMap, Set as MutSet, HashMap, ListBuffer} +import hkmc2.semantics._ private def raw(x: String): Document = doc"$x" final case class LowLevelIRError(message: String) extends Exception(message) +private def docSymWithUid(sym: Local): Document = doc"${sym.nme}$$${sym.uid.toString()}" + +val hiddenPrefixes = Set("Tuple") + +def defaultHidden(x: Str): Bool = + hiddenPrefixes.exists(x.startsWith) + case class Program( classes: Set[ClassInfo], defs: Set[Func], - main: Node, + entry: Local, ): override def toString: String = val t1 = classes.toArray val t2 = defs.toArray Sorting.quickSort(t1) Sorting.quickSort(t2) - s"Program({${t1.mkString(",\n")}}, {\n${t2.mkString("\n")}\n},\n$main)" + s"Program({${t1.mkString(",\n")}}, {\n${t2.mkString("\n")}\n},\n$entry)" - def show(hiddenNames: Set[Str] = Set.empty) = toDocument(hiddenNames).toString - def toDocument(hiddenNames: Set[Str] = Set.empty) : Document = - val t1 = classes.toArray + def show(hide: Str => Bool = defaultHidden) = toDocument(hide).toString + def toDocument(hide: Str => Bool = defaultHidden) : Document = + val t1 = classes.iterator.filterNot(c => hide(c.symbol.nme)).toArray val t2 = defs.toArray Sorting.quickSort(t1) Sorting.quickSort(t2) given Conversion[String, Document] = raw - val docClasses = t1.filter(x => !hiddenNames.contains(x.name)).map(_.toDocument).toList.mkDocument(doc" # ") + val docClasses = t1.map(_.toDocument).toList.mkDocument(doc" # ") val docDefs = t2.map(_.toDocument).toList.mkDocument(doc" # ") - val docMain = main.toDocument + val docMain = doc"entry = ${entry.nme}$$${entry.uid.toString()}" doc" #{ $docClasses\n$docDefs\n$docMain #} " implicit object ClassInfoOrdering extends Ordering[ClassInfo] { @@ -47,56 +57,31 @@ implicit object ClassInfoOrdering extends Ordering[ClassInfo] { case class ClassInfo( id: Int, - name: Str, - fields: Ls[Str], + symbol: MemberSymbol[? <: ClassLikeDef], + fields: Ls[VarSymbol], + parents: Set[Local], + methods: Map[Local, Func], ): - var parents: Set[Str] = Set.empty - var methods: Map[Str, Func] = Map.empty override def hashCode: Int = id override def toString: String = - s"ClassInfo($id, $name, [${fields mkString ","}], parents: ${parents mkString ","}, methods:\n${methods mkString ",\n"})" + s"ClassInfo($id, $symbol, [${fields mkString ","}], parents: ${parents mkString ","}, methods:\n${methods mkString ",\n"})" def show = toDocument.toString def toDocument: Document = given Conversion[String, Document] = raw - val ext = if parents.isEmpty then "" else " extends " + parents.mkString(", ") + val ext = if parents.isEmpty then "" else " extends " + parents.map(_.nme).mkString(", ") if methods.isEmpty then - doc"class $name(${fields.mkString(",")})$ext" + doc"class ${symbol.nme}(${fields.map(docSymWithUid).mkString(",")})$ext" else - val docFirst = doc"class $name (${fields.mkString(",")})$ext {" + val docFirst = doc"class ${symbol.nme}(${fields.map(docSymWithUid).mkString(",")})$ext {" val docMethods = methods.map { (_, func) => func.toDocument }.toList.mkDocument(doc" # ") val docLast = doc"}" - doc"$docFirst #{ # $docMethods # #} $docLast" - -case class Name(str: Str): - def trySubst(map: Map[Str, Name]) = map.getOrElse(str, this) - override def toString: String = str - -object FuncRef: - def fromName(name: Str) = FuncRef(Right(name)) - def fromName(name: Name) = FuncRef(Right(name.str)) - def fromFunc(func: Func) = FuncRef(Left(func)) - -class FuncRef(var func: Either[Func, Str]): - def name: String = func.fold(_.name, x => x) - def expectFn: Func = func.fold(identity, x => throw Exception(s"Expected a def, but got $x")) - def getFunc: Opt[Func] = func.left.toOption - override def equals(o: Any): Bool = o match { - case o: FuncRef => o.name == this.name - case _ => false - } - -object ClassRef: - def fromName(name: Str) = ClassRef(Right(name)) - def fromName(name: Name) = ClassRef(Right(name.str)) - def fromClass(cls: ClassInfo) = ClassRef(Left(cls)) + doc"$docFirst #{ # $docMethods #} # $docLast" -class ClassRef(var cls: Either[ClassInfo, Str]): - def name: String = cls.fold(_.name, x => x) - def expectCls: ClassInfo = cls.fold(identity, x => throw Exception(s"Expected a class, but got $x")) - def getClass: Opt[ClassInfo] = cls.left.toOption +class FuncRef(var func: Local): + def name: String = func.nme override def equals(o: Any): Bool = o match { - case o: ClassRef => o.name == this.name + case o: FuncRef => o.name === this.name case _ => false } @@ -106,8 +91,8 @@ implicit object FuncOrdering extends Ordering[Func] { case class Func( id: Int, - name: Str, - params: Ls[Name], + name: BlockMemberSymbol, + params: Ls[Local], resultNum: Int, body: Node ): @@ -121,7 +106,7 @@ case class Func( def show = toDocument def toDocument: Document = given Conversion[String, Document] = raw - val docFirst = doc"def $name(${params.map(_.toString).mkString(",")}) =" + val docFirst = doc"def ${docSymWithUid(name)}(${params.map(docSymWithUid).mkString(",")}) =" val docBody = body.toDocument doc"$docFirst #{ # $docBody #} " @@ -131,16 +116,22 @@ sealed trait TrivialExpr: def show: String def toDocument: Document def toExpr: Expr = this match { case x: Expr => x } + def foldRef(f: Local => TrivialExpr): TrivialExpr = this match + case Ref(sym) => f(sym) + case _ => this + def iterRef(f: Local => Unit): Unit = this match + case Ref(sym) => f(sym) + case _ => () private def showArguments(args: Ls[TrivialExpr]) = args map (_.show) mkString "," enum Expr: - case Ref(name: Name) extends Expr, TrivialExpr + case Ref(sym: Local) extends Expr, TrivialExpr case Literal(lit: hkmc2.syntax.Literal) extends Expr, TrivialExpr - case CtorApp(cls: ClassRef, args: Ls[TrivialExpr]) - case Select(name: Name, cls: ClassRef, field: Str) - case BasicOp(name: Str, args: Ls[TrivialExpr]) - case AssignField(assignee: Name, cls: ClassRef, field: Str, value: TrivialExpr) + case CtorApp(cls: MemberSymbol[? <: ClassLikeDef], args: Ls[TrivialExpr]) + case Select(name: Local, cls: Local, field: Str) + case BasicOp(name: BuiltinSymbol, args: Ls[TrivialExpr]) + case AssignField(assignee: Local, cls: Local, field: Str, value: TrivialExpr) override def toString: String = show @@ -149,40 +140,40 @@ enum Expr: def toDocument: Document = given Conversion[String, Document] = raw this match - case Ref(s) => s.toString + case Ref(s) => docSymWithUid(s) case Literal(Tree.BoolLit(lit)) => s"$lit" case Literal(Tree.IntLit(lit)) => s"$lit" case Literal(Tree.DecLit(lit)) => s"$lit" - case Literal(Tree.StrLit(lit)) => s"$lit" + case Literal(Tree.StrLit(lit)) => s"${lit.escaped}" case Literal(Tree.UnitLit(isNullNotUndefined)) => if isNullNotUndefined then "null" else "undefined" case CtorApp(cls, args) => - doc"${cls.name}(${args.map(_.toString).mkString(",")})" + doc"${docSymWithUid(cls)}(${args.map(_.toString).mkString(",")})" case Select(s, cls, fld) => - doc"${s.toString}.<${cls.name}:$fld>" - case BasicOp(name: Str, args) => - doc"$name(${args.map(_.toString).mkString(",")})" + doc"${docSymWithUid(s)}.<${docSymWithUid(cls)}:$fld>" + case BasicOp(sym, args) => + doc"${sym.nme}(${args.map(_.toString).mkString(",")})" case AssignField(assignee, clsInfo, fieldName, value) => - doc"${assignee.toString}.${fieldName} := ${value.toString}" + doc"${docSymWithUid(assignee)}.${fieldName} := ${value.toString}" enum Pat: case Lit(lit: hkmc2.syntax.Literal) - case Class(cls: ClassRef) + case Class(cls: Local) override def toString: String = this match case Lit(lit) => s"$lit" - case Class(cls) => s"${cls.name}" + case Class(cls) => s"${{docSymWithUid(cls)}}" enum Node: // Terminal forms: case Result(res: Ls[TrivialExpr]) - case Jump(func: FuncRef, args: Ls[TrivialExpr]) + case Jump(func: Local, args: Ls[TrivialExpr]) case Case(scrutinee: TrivialExpr, cases: Ls[(Pat, Node)], default: Opt[Node]) case Panic(msg: Str) // Intermediate forms: - case LetExpr(name: Name, expr: Expr, body: Node) - case LetMethodCall(names: Ls[Name], cls: ClassRef, method: Name, args: Ls[TrivialExpr], body: Node) - case LetCall(names: Ls[Name], func: FuncRef, args: Ls[TrivialExpr], body: Node) + case LetExpr(name: Local, expr: Expr, body: Node) + case LetMethodCall(names: Ls[Local], cls: Local, method: Local, args: Ls[TrivialExpr], body: Node) + case LetCall(names: Ls[Local], func: Local, args: Ls[TrivialExpr], body: Node) override def toString: String = show @@ -193,7 +184,7 @@ enum Node: this match case Result(res) => (res |> showArguments) case Jump(jp, args) => - doc"jump ${jp.name}(${args |> showArguments})" + doc"jump ${docSymWithUid(jp)}(${args |> showArguments})" case Case(x, cases, default) => val docFirst = doc"case ${x.toString} of" val docCases = cases.map { @@ -207,9 +198,9 @@ enum Node: case Panic(msg) => doc"panic ${s"\"$msg\""}" case LetExpr(x, expr, body) => - doc"let ${x.toString} = ${expr.toString} in # ${body.toDocument}" + doc"let ${docSymWithUid(x)} = ${expr.toString} in # ${body.toDocument}" case LetMethodCall(xs, cls, method, args, body) => - doc"let ${xs.map(_.toString).mkString(",")} = ${cls.name}.${method.toString}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}" + doc"let ${xs.map(docSymWithUid).mkString(",")} = ${cls.nme}.${docSymWithUid(method)}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}" case LetCall(xs, func, args, body) => - doc"let* (${xs.map(_.toString).mkString(",")}) = ${func.name}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}" + doc"let* (${xs.map(docSymWithUid).mkString(",")}) = ${func.nme}(${args.map(_.toString).mkString(",")}) in # ${body.toDocument}" diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala deleted file mode 100644 index 5b6da3eab9..0000000000 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/RefResolver.scala +++ /dev/null @@ -1,56 +0,0 @@ -package hkmc2.codegen.llir - -import mlscript.utils.shorthands._ - -import Node._ - -// Resolves the definition references by turning them from Right(name) to Left(Func). -private final class RefResolver(defs: Map[Str, Func], classes: Map[Str, ClassInfo], allowInlineJp: Bool): - private def f(x: Expr): Unit = x match - case Expr.Ref(name) => - case Expr.Literal(lit) => - case Expr.CtorApp(cls, args) => classes.get(cls.name) match - case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") - case Some(value) => cls.cls = Left(value) - case Expr.Select(name, cls, field) => classes.get(cls.name) match - case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") - case Some(value) => cls.cls = Left(value) - case Expr.BasicOp(name, args) => - case Expr.AssignField(name, cls, field, value) => classes.get(cls.name) match - case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") - case Some(value) => cls.cls = Left(value) - - private def f(x: Pat): Unit = x match - case Pat.Lit(lit) => - case Pat.Class(cls) => classes.get(cls.name) match - case None => throw LowLevelIRError(f"unknown class ${cls.name} in ${classes.keySet.mkString(",")}") - case Some(value) => cls.cls = Left(value) - - private def f(x: Node): Unit = x match - case Result(res) => - case Case(scrut, cases, default) => cases foreach { (_, body) => f(body) }; default foreach f - case LetExpr(name, expr, body) => f(expr); f(body) - case LetMethodCall(names, cls, method, args, body) => f(body) - case LetCall(resultNames, defnref, args, body) => - defs.get(defnref.name) match - case Some(defn) => defnref.func = Left(defn) - case None => throw LowLevelIRError(f"unknown function ${defnref.name} in ${defs.keySet.mkString(",")}") - f(body) - case Jump(defnref, _) => - // maybe not promoted yet - defs.get(defnref.name) match - case Some(defn) => defnref.func = Left(defn) - case None => - if !allowInlineJp then - throw LowLevelIRError(f"unknown function ${defnref.name} in ${defs.keySet.mkString(",")}") - case Panic(_) => - def run(node: Node) = f(node) - def run(node: Func) = f(node.body) - -def resolveRef(entry: Node, defs: Set[Func], classes: Set[ClassInfo], allowInlineJp: Bool = false): Unit = - val defsMap = defs.map(x => x.name -> x).toMap - val classesMap = classes.map(x => x.name -> x).toMap - val rl = RefResolver(defsMap, classesMap, allowInlineJp) - rl.run(entry) - defs.foreach(rl.run(_)) - diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala deleted file mode 100644 index a660f9f078..0000000000 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/llir/Validator.scala +++ /dev/null @@ -1,45 +0,0 @@ -package hkmc2.codegen.llir - -import hkmc2.utils._ - -private final class FuncRefInSet(defs: Set[Func], classes: Set[ClassInfo]): - import Node._ - import Expr._ - - private def f(x: Expr): Unit = x match - case Ref(name) => - case Literal(lit) => - case CtorApp(name, args) => - case Select(name, ref, field) => ref.getClass match { - case Some(real_class) => if !classes.exists(_ eq real_class) then throw LowLevelIRError("ref is not in the set") - case _ => - } - case BasicOp(name, args) => - case AssignField(assignee, ref, _, value) => ref.getClass match { - case Some(real_class) => if !classes.exists(_ eq real_class) then throw LowLevelIRError("ref is not in the set") - case _ => - } - - private def f(x: Node): Unit = x match - case Result(res) => - case Jump(func, args) => - case Case(x, cases, default) => cases foreach { (_, body) => f(body) }; default foreach f - case Panic(_) => - case LetExpr(name, expr, body) => f(body) - case LetMethodCall(names, cls, method, args, body) => f(body) - case LetCall(res, ref, args, body) => - ref.getFunc match { - case Some(real_func) => if !defs.exists(_ eq real_func) then throw LowLevelIRError("ref is not in the set") - case _ => - } - f(body) - def run(node: Node) = f(node) - def run(func: Func) = f(func.body) - -def validateRefInSet(entry: Node, defs: Set[Func], classes: Set[ClassInfo]): Unit = - val funcRefInSet = FuncRefInSet(defs, classes) - defs.foreach(funcRefInSet.run(_)) - -def validate(entry: Node, defs: Set[Func], classes: Set[ClassInfo]): Unit = - validateRefInSet(entry, defs, classes) - diff --git a/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile b/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile index 45aae4802c..bba9ca31e9 100644 --- a/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile +++ b/hkmc2/shared/src/test/mlscript-compile/cpp/Makefile @@ -24,4 +24,4 @@ clean: auto: $(TARGET) $(TARGET): $(SRC) $(INCLUDES) - $(CXX) $(CFLAGS) $(LDFLAGS) $(SRC) $(LDLIBS) -o $(TARGET) + $(CXX) $(CFLAGS) $(LDFLAGS) mlsaux.cxx $(SRC) $(LDLIBS) -o $(TARGET) diff --git a/hkmc2/shared/src/test/mlscript-compile/cpp/mlsaux.cxx b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsaux.cxx new file mode 100644 index 0000000000..edc8d63c3f --- /dev/null +++ b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsaux.cxx @@ -0,0 +1,4 @@ +#include "mlsprelude.h" + +#include +#include diff --git a/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h index fed52549bb..5002497f9b 100644 --- a/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h +++ b/hkmc2/shared/src/test/mlscript-compile/cpp/mlsprelude.h @@ -2,13 +2,15 @@ #include #include #include +#include #include #include #include #include +#include +#include #include #include -#include #include constexpr std::size_t _mlsAlignment = 8; @@ -47,6 +49,11 @@ consteval auto nextTypeTag() { return nextTypeTag(); } +constexpr static inline uint32_t unitTag = nextTypeTag(); +constexpr static inline uint32_t floatTag = nextTypeTag(); +constexpr static inline uint32_t strTag = nextTypeTag(); +constexpr static inline uint32_t lazyTag = nextTypeTag(); + struct _mlsObject { uint32_t refCount; uint32_t tag; @@ -67,17 +74,52 @@ struct _mlsObject { virtual void destroy() = 0; }; +#define BOOST_STACKTRACE_GNU_SOURCE_NOT_REQUIRED +#include + +class _mlsUtil { +public: + [[noreturn]] static void panic(const char *msg) { + std::fprintf(stderr, "Panic: %s\n", msg); + std::string st = + boost::stacktrace::to_string(boost::stacktrace::stacktrace()); + std::fprintf(stderr, "%s\n", st.c_str()); + std::abort(); + } + [[noreturn]] static void panic_with(const char *msg, const char *func, + const char *file, int line) { + std::fprintf(stderr, "Panic: %s at %s %s:%d\n", msg, func, file, line); + std::string st = + boost::stacktrace::to_string(boost::stacktrace::stacktrace()); + std::fprintf(stderr, "%s\n", st.c_str()); + std::abort(); + } +}; + +#define _mls_assert(e) \ + (__builtin_expect(!(e), 0) \ + ? _mlsUtil::panic_with("assertion failed", __func__, \ + __FILE__, __LINE__) \ + : (void)0) + +struct _mlsFloatShape : public _mlsObject { + double f; +}; + class _mlsValue { using uintptr_t = std::uintptr_t; - using uint64_t = std::uint64_t; + using intptr_t = std::intptr_t; + using int64_t = std::int64_t; - void *value alignas(_mlsAlignment); + void *value; bool isInt63() const { return (reinterpret_cast(value) & 1) == 1; } + bool isFloat() const { return isPtr() && asObject()->tag == floatTag; } + bool isPtr() const { return (reinterpret_cast(value) & 1) == 0; } - uint64_t asInt63() const { return reinterpret_cast(value) >> 1; } + int64_t asInt63() const { return reinterpret_cast(value) >> 1; } uintptr_t asRawInt() const { return reinterpret_cast(value); } @@ -85,17 +127,17 @@ class _mlsValue { return _mlsValue(reinterpret_cast(i)); } - static _mlsValue fromInt63(uint64_t i) { + static _mlsValue fromInt63(int64_t i) { return _mlsValue(reinterpret_cast((i << 1) | 1)); } void *asPtr() const { - assert(!isInt63()); + _mls_assert(!isInt63()); return value; } _mlsObject *asObject() const { - assert(isPtr()); + _mls_assert(isPtr()); return static_cast<_mlsObject *>(value); } @@ -119,25 +161,70 @@ class _mlsValue { return fromInt63(asInt63() / other.asInt63()); } + _mlsValue modInt63(const _mlsValue &other) const { + return fromInt63(asInt63() % other.asInt63()); + } + _mlsValue gtInt63(const _mlsValue &other) const { - return _mlsValue::fromBoolLit(asInt63() > other.asInt63()); + return fromBoolLit(asInt63() > other.asInt63()); } _mlsValue ltInt63(const _mlsValue &other) const { - return _mlsValue::fromBoolLit(asInt63() < other.asInt63()); + return fromBoolLit(asInt63() < other.asInt63()); } _mlsValue geInt63(const _mlsValue &other) const { - return _mlsValue::fromBoolLit(asInt63() >= other.asInt63()); + return fromBoolLit(asInt63() >= other.asInt63()); } _mlsValue leInt63(const _mlsValue &other) const { - return _mlsValue::fromBoolLit(asInt63() <= other.asInt63()); + return fromBoolLit(asInt63() <= other.asInt63()); + } + + _mlsValue minInt63(const _mlsValue &other) const { + int64_t a = asInt63(); + int64_t b = other.asInt63(); + return fromInt63(a < b ? a : b); + } + + _mlsValue maxInt63(const _mlsValue &other) const { + int64_t a = asInt63(); + int64_t b = other.asInt63(); + return fromInt63(a > b ? a : b); + } + + _mlsValue absInt63() const { + int64_t a = asInt63(); + return fromInt63(a < 0 ? -a : a); + } + + _mlsValue floorDivInt63(const _mlsValue &other) const { + int64_t a = asInt63(); + int64_t b = other.asInt63(); + int64_t q = a / b; + int64_t r = a % b; + if ((r > 0 && b < 0) || (r < 0 && b > 0)) + q = q - 1; + return fromInt63(q); + } + + _mlsValue floorModInt63(const _mlsValue &other) const { + int64_t a = asInt63(); + int64_t b = other.asInt63(); + long r = a % b; + if ((r > 0 && b < 0) || (r < 0 && b > 0)) + r = r + b; + return fromInt63(r); } public: + struct inc_ref_tag {}; explicit _mlsValue() : value(nullptr) {} explicit _mlsValue(void *value) : value(value) {} + explicit _mlsValue(void *value, inc_ref_tag) : value(value) { + if (isPtr()) + asObject()->incRef(); + } _mlsValue(const _mlsValue &other) : value(other.value) { if (isPtr()) asObject()->incRef(); @@ -160,12 +247,12 @@ class _mlsValue { } } - uint64_t asInt() const { - assert(isInt63()); + int64_t asInt() const { + _mls_assert(isInt63()); return asInt63(); } - static _mlsValue fromIntLit(uint64_t i) { return fromInt63(i); } + static _mlsValue fromIntLit(int64_t i) { return fromInt63(i); } static _mlsValue fromBoolLit(bool b) { return fromInt63(b); } @@ -184,11 +271,11 @@ class _mlsValue { return v.asObject()->tag == T::typeTag; } - static bool isIntLit(const _mlsValue &v, uint64_t n) { + static bool isIntLit(const _mlsValue &v, int64_t n) { return v.asInt63() == n; } - static bool isIntLit(const _mlsValue &v) { return v.isInt63(); } + static bool isInt(const _mlsValue &v) { return v.isInt63(); } template static T *as(const _mlsValue &v) { return dynamic_cast(v.asObject()); @@ -198,61 +285,43 @@ class _mlsValue { return static_cast(v.asObject()); } + _mlsValue floorDiv(const _mlsValue &other) const; + + _mlsValue floorMod(const _mlsValue &other) const; + + _mlsValue pow(const _mlsValue &other) const; + + _mlsValue abs() const; + // Operators - _mlsValue operator==(const _mlsValue &other) const { - if (isInt63() && other.isInt63()) - return _mlsValue::fromBoolLit(eqInt63(other)); - assert(false); - } + _mlsValue operator==(const _mlsValue &other) const; - _mlsValue operator+(const _mlsValue &other) const { - if (isInt63() && other.isInt63()) - return addInt63(other); - assert(false); - } + _mlsValue operator!=(const _mlsValue &other) const; - _mlsValue operator-(const _mlsValue &other) const { - if (isInt63() && other.isInt63()) - return subInt63(other); - assert(false); - } + _mlsValue operator&&(const _mlsValue &other) const; - _mlsValue operator*(const _mlsValue &other) const { - if (isInt63() && other.isInt63()) - return mulInt63(other); - assert(false); - } + _mlsValue operator||(const _mlsValue &other) const; - _mlsValue operator/(const _mlsValue &other) const { - if (isInt63() && other.isInt63()) - return divInt63(other); - assert(false); - } + _mlsValue operator+(const _mlsValue &other) const; - _mlsValue operator>(const _mlsValue &other) const { - if (isInt63() && other.isInt63()) - return gtInt63(other); - assert(false); - } + _mlsValue operator-(const _mlsValue &other) const; - _mlsValue operator<(const _mlsValue &other) const { - if (isInt63() && other.isInt63()) - return ltInt63(other); - assert(false); - } + _mlsValue operator-() const; - _mlsValue operator>=(const _mlsValue &other) const { - if (isInt63() && other.isInt63()) - return geInt63(other); - assert(false); - } + _mlsValue operator*(const _mlsValue &other) const; - _mlsValue operator<=(const _mlsValue &other) const { - if (isInt63() && other.isInt63()) - return leInt63(other); - assert(false); - } + _mlsValue operator/(const _mlsValue &other) const; + + _mlsValue operator%(const _mlsValue &other) const; + + _mlsValue operator>(const _mlsValue &other) const; + + _mlsValue operator<(const _mlsValue &other) const; + + _mlsValue operator>=(const _mlsValue &other) const; + + _mlsValue operator<=(const _mlsValue &other) const; // Auxiliary functions @@ -265,7 +334,9 @@ class _mlsValue { }; struct _mls_Callable : public _mlsObject { - virtual _mlsValue _mls_apply0() { throw std::runtime_error("Not implemented"); } + virtual _mlsValue _mls_apply0() { + throw std::runtime_error("Not implemented"); + } virtual _mlsValue _mls_apply1(_mlsValue) { throw std::runtime_error("Not implemented"); } @@ -303,8 +374,7 @@ inline static _mlsValue _mlsCall(_mlsValue f, U... args) { return _mlsToCallable(f)->_mls_apply4(args...); } -template -inline static T *_mlsMethodCall(_mlsValue self) { +template inline static T *_mlsMethodCall(_mlsValue self) { auto *ptr = _mlsValue::as(self); if (!ptr) throw std::runtime_error("unable to convert object for method calls"); @@ -315,7 +385,7 @@ inline int _mlsLargeStack(void *(*fn)(void *)) { pthread_t thread; pthread_attr_t attr; - size_t stacksize = 512 * 1024 * 1024; + size_t stacksize = 1024 * 1024 * 1024; pthread_attr_init(&attr); pthread_attr_setstacksize(&attr, stacksize); @@ -338,7 +408,7 @@ inline void *_mlsMainWrapper(void *) { struct _mls_Unit final : public _mlsObject { constexpr static inline const char *typeName = "Unit"; - constexpr static inline uint32_t typeTag = nextTypeTag(); + constexpr static inline uint32_t typeTag = unitTag; virtual void print() const override { std::printf(typeName); } static _mlsValue create() { static _mls_Unit mlsUnit alignas(_mlsAlignment); @@ -349,6 +419,160 @@ struct _mls_Unit final : public _mlsObject { virtual void destroy() override {} }; +struct _mls_Float final : public _mlsFloatShape { + constexpr static inline const char *typeName = "Float"; + constexpr static inline uint32_t typeTag = floatTag; + virtual void print() const override { + std::printf(typeName); + std::printf("("); + std::printf("%f", f); + std::printf(")"); + } + static _mlsValue create(double f) { + auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_Float; + _mlsVal->f = f; + _mlsVal->refCount = 1; + _mlsVal->tag = typeTag; + return _mlsValue(_mlsVal); + } + _mlsValue operator+(const _mls_Float &other) const { + return _mlsValue::create<_mls_Float>(f + other.f); + } + _mlsValue operator-(const _mls_Float &other) const { + return _mlsValue::create<_mls_Float>(f - other.f); + } + _mlsValue operator*(const _mls_Float &other) const { + return _mlsValue::create<_mls_Float>(f * other.f); + } + _mlsValue operator/(const _mls_Float &other) const { + return _mlsValue::create<_mls_Float>(f / other.f); + } + _mlsValue operator==(const _mls_Float &other) const { + return _mlsValue::fromBoolLit(f == other.f); + } + _mlsValue operator!=(const _mls_Float &other) const { + return _mlsValue::fromBoolLit(f == other.f); + } + _mlsValue operator>(const _mls_Float &other) const { + return _mlsValue::fromBoolLit(f > other.f); + } + _mlsValue operator<(const _mls_Float &other) const { + return _mlsValue::fromBoolLit(f < other.f); + } + _mlsValue operator>=(const _mls_Float &other) const { + return _mlsValue::fromBoolLit(f >= other.f); + } + _mlsValue operator<=(const _mls_Float &other) const { + return _mlsValue::fromBoolLit(f <= other.f); + } + virtual void destroy() override { + operator delete(this, std::align_val_t(_mlsAlignment)); + } +}; + +struct _mls_Str final : public _mlsObject { + std::string str; + constexpr static inline const char *typeName = "Str"; + constexpr static inline uint32_t typeTag = strTag; + virtual void print() const override { + std::printf("\""); + for (const auto c : str) { + switch (c) { + case '\'': + std::printf("\\\'"); + break; + case '\"': + std::printf("\\\""); + break; + case '\?': + std::printf("\\\?"); + break; + case '\\': + std::printf("\\\\"); + break; + case '\a': + std::printf("\\a"); + break; + case '\b': + std::printf("\\b"); + break; + case '\f': + std::printf("\\f"); + break; + case '\n': + std::printf("\\n"); + break; + case '\r': + std::printf("\\r"); + break; + case '\t': + std::printf("\\t"); + break; + case '\v': + std::printf("\\v"); + break; + default: + if (c < 32 || c > 126) + std::printf("\\x%02x", c); + else + std::putchar(c); + } + } + std::printf("\""); + std::fflush(stdout); + } + static _mlsValue create(const std::string_view str) { + auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_Str; + _mlsVal->str = str; + _mlsVal->refCount = 1; + _mlsVal->tag = typeTag; + return _mlsValue(_mlsVal); + } + static _mlsValue create(const std::string_view str1, + const std::string_view str2) { + auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_Str; + _mlsVal->str = str1; + _mlsVal->str += str2; + _mlsVal->refCount = 1; + _mlsVal->tag = typeTag; + return _mlsValue(_mlsVal); + } + virtual void destroy() override { + str.~basic_string(); + operator delete(this, std::align_val_t(_mlsAlignment)); + } +}; + +struct _mls_Lazy final : public _mlsObject { + _mlsValue init; + _mlsValue value; + bool evaluated; + constexpr static inline const char *typeName = "Lazy"; + constexpr static inline uint32_t typeTag = lazyTag; + virtual void print() const override { std::printf(typeName); } + static _mlsValue create(_mlsValue init) { + auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_Lazy; + _mlsVal->refCount = 1; + _mlsVal->tag = typeTag; + _mlsVal->init = init; + _mlsVal->value = _mlsValue::create<_mls_Unit>(); + _mlsVal->evaluated = false; + return _mlsValue(_mlsVal); + } + virtual void destroy() override { + _mlsValue::destroy(init); + _mlsValue::destroy(value); + operator delete(this, std::align_val_t(_mlsAlignment)); + } + _mlsValue _mls_get() { + if (!evaluated) { + value = _mlsCall(init); + evaluated = true; + } + return value; + } +}; + #include struct _mls_ZInt final : public _mlsObject { @@ -368,7 +592,7 @@ struct _mls_ZInt final : public _mlsObject { static _mlsValue create() { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_ZInt; _mlsVal->refCount = 1; - _mlsVal->tag = typeTag; + _mlsVal->tag = typeTag; return _mlsValue(_mlsVal); } static _mlsValue create(_mlsValue z) { @@ -429,85 +653,135 @@ struct _mls_ZInt final : public _mlsObject { } _mlsValue toInt() const { - return _mlsValue::fromIntLit(z.convert_to()); + return _mlsValue::fromIntLit(z.convert_to()); } - static _mlsValue fromInt(uint64_t i) { + static _mlsValue fromInt(int64_t i) { return _mlsValue::create<_mls_ZInt>(_mlsValue::fromIntLit(i)); } }; -__attribute__((noinline)) inline void _mlsNonExhaustiveMatch() { - throw std::runtime_error("Non-exhaustive match"); +[[noreturn, gnu::noinline]] inline void _mlsNonExhaustiveMatch() { + _mlsUtil::panic("Non-exhaustive match"); +} + +inline _mlsValue _mls_builtin_pow(_mlsValue a, _mlsValue b) { return a.pow(b); } + +inline _mlsValue _mls_builtin_abs(_mlsValue a) { return a.abs(); } + +inline _mlsValue _mls_builtin_floor_div(_mlsValue a, _mlsValue b) { + return a.floorDiv(b); +} + +inline _mlsValue _mls_builtin_floor_mod(_mlsValue a, _mlsValue b) { + return a.floorMod(b); +} + +inline _mlsValue _mls_builtin_trunc_div(_mlsValue a, _mlsValue b) { + _mls_assert(_mlsValue::isInt(a)); + _mls_assert(_mlsValue::isInt(b)); + return a / b; +} + +inline _mlsValue _mls_builtin_trunc_mod(_mlsValue a, _mlsValue b) { + _mls_assert(_mlsValue::isInt(a)); + _mls_assert(_mlsValue::isInt(b)); + return a % b; +} + +inline _mlsValue _mls_builtin_int2str(_mlsValue a) { + _mls_assert(_mlsValue::isInt(a)); + char buf[32]; + std::snprintf(buf, sizeof(buf), "%" PRIu64, a.asInt()); + return _mlsValue::create<_mls_Str>(buf); +} + +inline _mlsValue _mls_builtin_float2str(_mlsValue a) { + _mls_assert(_mlsValue::isValueOf<_mls_Float>(a)); + char buf[128]; + std::snprintf(buf, sizeof(buf), "%f", _mlsValue::cast<_mls_Float>(a)->f); + return _mlsValue::create<_mls_Str>(buf); +} + +inline _mlsValue _mls_builtin_int2float(_mlsValue a) { + return _mlsValue::create<_mls_Float>(a.asInt()); +} + +inline _mlsValue _mls_builtin_str_concat(_mlsValue a, _mlsValue b) { + _mls_assert(_mlsValue::isValueOf<_mls_Str>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_Str>(b)); + auto *strA = _mlsValue::cast<_mls_Str>(a); + auto *strB = _mlsValue::cast<_mls_Str>(b); + return _mlsValue::create<_mls_Str>(strA->str.c_str(), strB->str.c_str()); } inline _mlsValue _mls_builtin_z_add(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) + *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_sub(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) - *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_mul(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) * *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_div(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) / *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_mod(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) % *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_equal(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) == *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_gt(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) > *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_lt(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) < *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_geq(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) >= *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_leq(_mlsValue a, _mlsValue b) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); - assert(_mlsValue::isValueOf<_mls_ZInt>(b)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(b)); return *_mlsValue::cast<_mls_ZInt>(a) <= *_mlsValue::cast<_mls_ZInt>(b); } inline _mlsValue _mls_builtin_z_to_int(_mlsValue a) { - assert(_mlsValue::isValueOf<_mls_ZInt>(a)); + _mls_assert(_mlsValue::isValueOf<_mls_ZInt>(a)); return _mlsValue::cast<_mls_ZInt>(a)->toInt(); } inline _mlsValue _mls_builtin_z_of_int(_mlsValue a) { - assert(_mlsValue::isIntLit(a)); + _mls_assert(_mlsValue::isInt(a)); return _mlsValue::create<_mls_ZInt>(a); } @@ -527,3 +801,148 @@ inline _mlsValue _mls_builtin_debug(_mlsValue a) { std::puts(""); return a; } + +inline _mlsValue _mlsValue::floorDiv(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return floorDivInt63(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::floorMod(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return floorModInt63(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::pow(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return fromInt63(std::pow(asInt63(), other.asInt63())); + if (isFloat() && other.isFloat()) + return _mlsValue::create<_mls_Float>( + std::pow(as<_mls_Float>(*this)->f, as<_mls_Float>(other)->f)); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::abs() const { + if (isInt63()) + return absInt63(); + if (isFloat()) + return _mlsValue::create<_mls_Float>(std::abs(as<_mls_Float>(*this)->f)); + _mls_assert(false); +} + +// Operators + +inline _mlsValue _mlsValue::operator==(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return _mlsValue::fromBoolLit(eqInt63(other)); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) == *as<_mls_Float>(other); + bool sameTag = + isPtr() && other.isPtr() && asObject()->tag == other.asObject()->tag; + if (!sameTag) + return _mlsValue::fromBoolLit(false); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator!=(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return _mlsValue::fromBoolLit(!eqInt63(other)); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) != *as<_mls_Float>(other); + bool sameTag = + isPtr() && other.isPtr() && asObject()->tag == other.asObject()->tag; + if (!sameTag) + return _mlsValue::fromBoolLit(true); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator&&(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return _mlsValue::fromBoolLit(asInt63() && other.asInt63()); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator||(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return _mlsValue::fromBoolLit(asInt63() || other.asInt63()); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator+(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return addInt63(other); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) + *as<_mls_Float>(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator-(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return subInt63(other); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) - *as<_mls_Float>(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator-() const { + if (isInt63()) + return fromInt63(-asInt63()); + if (isFloat()) + return _mlsValue::create<_mls_Float>(-as<_mls_Float>(*this)->f); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator*(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return mulInt63(other); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) * *as<_mls_Float>(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator/(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return divInt63(other); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) / *as<_mls_Float>(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator%(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return modInt63(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator>(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return gtInt63(other); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) > *as<_mls_Float>(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator<(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return ltInt63(other); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) < *as<_mls_Float>(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator>=(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return geInt63(other); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) >= *as<_mls_Float>(other); + _mls_assert(false); +} + +inline _mlsValue _mlsValue::operator<=(const _mlsValue &other) const { + if (isInt63() && other.isInt63()) + return leInt63(other); + if (isFloat() && other.isFloat()) + return *as<_mls_Float>(*this) <= *as<_mls_Float>(other); + _mls_assert(false); +} diff --git a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls index 636650519a..87dd0252e6 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls @@ -34,5 +34,5 @@ let x = x + 1 //│ set tmp1 = +(x3, 1) in //│ set x4 = tmp1 in //│ set block$res3 = undefined in -//│ end +//│ end //│ x = 1 diff --git a/hkmc2/shared/src/test/mlscript/decls/Prelude.mls b/hkmc2/shared/src/test/mlscript/decls/Prelude.mls index bf2d2047de..1442d01c3f 100644 --- a/hkmc2/shared/src/test/mlscript/decls/Prelude.mls +++ b/hkmc2/shared/src/test/mlscript/decls/Prelude.mls @@ -29,8 +29,18 @@ declare class Str with fun length: Int fun concat: Str -> Str -// declare module Math // TODO: list members -declare val Math // so we can, eg, `open { pow }` in the meantime +declare module Math with + // TODO: add more functions + fun abs: Num -> Num + fun sqrt: Num -> Num + fun sin: Num -> Num + fun cos: Num -> Num + fun tan: Num -> Num + fun pow: Num -> Num + fun round: Num -> Num + fun trunc: Num -> Num + fun min: (Num, Num) -> Num + fun max: (Num, Num) -> Num declare val Reflect declare val console diff --git a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls index b9b69f1f48..4c231a2836 100644 --- a/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls +++ b/hkmc2/shared/src/test/mlscript/llir/BadPrograms.mls @@ -1,15 +1,13 @@ - -:global :llir :cpp - -:todo :ge fun oops(a) = class A with fun m = a let x = 1 +//│ ═══[COMPILATION ERROR] Function without arguments not supported: 0 +//│ Stopped due to an error during the Llir generation :ge let x = "oops" @@ -17,4 +15,3 @@ x.m //│ ═══[COMPILATION ERROR] Unsupported selection by users //│ Stopped due to an error during the Llir generation - diff --git a/hkmc2/shared/src/test/mlscript/llir/BasicCpp.mls b/hkmc2/shared/src/test/mlscript/llir/BasicCpp.mls index be1d2e5400..50b35ad37e 100644 --- a/hkmc2/shared/src/test/mlscript/llir/BasicCpp.mls +++ b/hkmc2/shared/src/test/mlscript/llir/BasicCpp.mls @@ -1,46 +1,68 @@ - -:global +:js :llir :cpp - -:scpp fun foo(a) = let x if a > 0 do x = 1 x + 1 + +:showWholeCpp +fun bar(x) = + x + 1 +foo(1) +//│ = 2 //│ -//│ Cpp: +//│ WholeProgramCpp: //│ #include "mlsprelude.h" -//│ _mlsValue _mls_j_0(_mlsValue); +//│ _mlsValue _mls_entry2(); //│ _mlsValue _mls_foo(_mlsValue); -//│ _mlsValue _mlsMain(); -//│ _mlsValue _mls_j_0(_mlsValue _mls_x_2) { +//│ _mlsValue _mls_entry(); +//│ _mlsValue _mls_j(_mlsValue); +//│ _mlsValue _mls_entry1(); +//│ _mlsValue _mls_bar(_mlsValue); +//│ _mlsValue _mls_bar(_mlsValue _mls_x8) { //│ _mlsValue _mls_retval; -//│ auto _mls_x_6 = (_mls_x_2 + _mlsValue::fromIntLit(1)); -//│ _mls_retval = _mls_x_6; +//│ auto _mls_x7 = (_mls_x8 + _mlsValue::fromIntLit(1)); +//│ _mls_retval = _mls_x7; //│ return _mls_retval; //│ } -//│ _mlsValue _mls_foo(_mlsValue _mls_a) { +//│ _mlsValue _mls_j(_mlsValue _mls_x1) { //│ _mlsValue _mls_retval; -//│ auto _mls_x_0 = _mlsValue::create<_mls_Unit>(); -//│ auto _mls_x_1 = (_mls_a > _mlsValue::fromIntLit(0)); -//│ if (_mlsValue::isIntLit(_mls_x_1, 1)) { -//│ auto _mls_x_3 = _mlsValue::fromIntLit(1); -//│ auto _mls_x_4 = _mlsValue::create<_mls_Unit>(); -//│ _mls_retval = _mls_j_0(_mls_x_3); -//│ } else { -//│ auto _mls_x_5 = _mlsValue::create<_mls_Unit>(); -//│ _mls_retval = _mls_j_0(_mls_x_0); -//│ } +//│ auto _mls_x = (_mls_x1 + _mlsValue::fromIntLit(1)); +//│ _mls_retval = _mls_x; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_entry2() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x9 = _mls_foo(_mlsValue::fromIntLit(1)); +//│ _mls_retval = _mls_x9; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_entry1() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); //│ return _mls_retval; //│ } -//│ _mlsValue _mlsMain() { +//│ _mlsValue _mls_entry() { //│ _mlsValue _mls_retval; //│ _mls_retval = _mlsValue::create<_mls_Unit>(); //│ return _mls_retval; //│ } +//│ _mlsValue _mls_foo(_mlsValue _mls_a) { +//│ _mlsValue _mls_retval; +//│ auto _mls_x2 = _mlsValue::create<_mls_Unit>(); +//│ auto _mls_x3 = (_mls_a > _mlsValue::fromIntLit(0)); +//│ if (_mlsValue::isIntLit(_mls_x3, 1)) { +//│ auto _mls_x5 = _mlsValue::fromIntLit(1); +//│ auto _mls_x6 = _mlsValue::create<_mls_Unit>(); +//│ _mls_retval = _mls_j(_mls_x5); +//│ } else { +//│ auto _mls_x4 = _mlsValue::create<_mls_Unit>(); +//│ _mls_retval = _mls_j(_mls_x2); +//│ } +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { return _mls_entry2(); } //│ int main() { return _mlsLargeStack(_mlsMainWrapper); } - - diff --git a/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls b/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls new file mode 100644 index 0000000000..12f12bb424 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/BasisLLIR.mls @@ -0,0 +1,484 @@ +:js +:llir +:cpp + +// This file contains all tests for LLIR in the original MLscript compiler. + +:intl +data class Pair[A, B](x: A, y: B) +fun mktup2(x, y) = mktup(x, y) +fun mktup(x, y) = Pair(x, y) +fun foo() = + mktup2(1, 2) +foo() +//│ = Pair(1, 2) +//│ +//│ Interpreted: +//│ Pair(1,2) + +:intl +data class Pair[A, B](x: A, y: B) +fun foo(pair) = + if pair is + Pair(x, y) then Pair(x, y) +fun bar() = + foo(Pair(1, 2)) +bar() +//│ = Pair(1, 2) +//│ +//│ Interpreted: +//│ Pair(1,2) + +:intl +data class Pair[A, B](x: A, y: B) +fun foo(pair) = + if pair is + Pair(x, y) then Pair(x, y) +fun bar() = + foo(Pair(1, 2)) +bar() +//│ = Pair(1, 2) +//│ +//│ Interpreted: +//│ Pair(1,2) + +:intl +data class Pair[A, B](x: A, y: B) +fun silly(pair) = + let x = 0 + let n = if pair is + Pair(x1, x2) then + if pair is + Pair (x3, x4) then x3 + 1 + n + 1 +fun foo() = + let a = Pair(0, 1) + let b = silly(a) + b +foo() +//│ = 2 +//│ +//│ Interpreted: +//│ 2 + +:intl +data class Pair[A, B](x: A, y: B) +fun inc_fst(pair) = + let c = 2 + if pair is + Pair(x1, x2) then x1 + c +fun foo() = + let a = Pair(0, 1) + let b = inc_fst(a) + b +foo() +//│ = 2 +//│ +//│ Interpreted: +//│ 2 + +:intl +data class Pair[A, B](x: A, y: B) +fun inc_fst(pair) = + let x = 0 + if pair is + Pair(x1, x2) then x2 + 1 +fun foo() = + let b = inc_fst(Pair(0, 1)) + b +foo() +//│ = 2 +//│ +//│ Interpreted: +//│ 2 + +:intl +abstract class Either[out A, out B]: Left[A, B] | Right[A, B] +data class Left[out A, out B](x: A) extends Either[A, B] +data class Right[out A, out B](y: B) extends Either[A, B] +fun foo(a, b) = + let t = if a is + Left(x) then Left(x + 1) + Right(y) then Right(b) + if t is + Left(x) then x + Right(y) then y +fun bar() = + foo(Right(2), 2) +bar() +//│ = 2 +//│ +//│ Interpreted: +//│ 2 + +:intl +abstract class Nat: S[Nat] | O +data class S(s: Nat) extends Nat +object O extends Nat +fun foo() = + bar(S(O)) +fun bar(x) = + baz(x) +fun baz(x) = + if x is + S(s) then s + O then x +foo() +//│ = O +//│ +//│ Interpreted: +//│ O() + +:intl +data class A(x, y, z) +data class B(m, n) +fun complex_foo(t) = + let r = if t is + A(x, y, z) then x + y * z + B(m, n) then m - n + let s = B(1, 2) + let u = if s is + A(x, y, z) then 3 + B(m, n) then 4 + r + u +fun bar() = + complex_foo(A(6, 7, 8)) + complex_foo(B(9, 10)) +bar() +//│ = 3 +//│ +//│ Interpreted: +//│ 3 + +:intl +data class A(w, x) +data class B(y) +data class C(z) +fun complex_foo(t) = + let a = 1 + 2 + let b = 1 * 2 + let x = if t is + A(x, y) then y + B(x) then B(x + b) + C(x) then C(0) + let z = A(5, x) + let v = B(6) + let y = if x is + A(x, y) then + let m = x + a + b + if y is + A(x, y) then x + B(x) then m + C(x) then 0 + B(x) then 2 + C(x) then 3 + if z is + A(x, y) then x + B(x) then 4 + C(x) then + if v is + A(x, y) then x + B(x) then 7 + C(x) then 8 +fun bar() = + complex_foo(A(10, A(9, B(10)))) +bar() +//│ = 5 +//│ +//│ Interpreted: +//│ 5 + +:intl +fun fib(n) = if n < 2 then n else fib(n-1) + fib(n-2) +fib(20) +//│ = 6765 +//│ +//│ Interpreted: +//│ 6765 + +:intl +fun odd(x) = if x == 0 then false else even(x-1) +fun even(x) = if x == 0 then true else odd(x-1) +fun foo() = odd(10) +foo() +//│ = false +//│ +//│ Interpreted: +//│ false + +:intl +abstract class Option[out T]: Some[T] | None +data class Some[out T](x: T) extends Option[T] +object None extends Option +fun not(x) = + if x then false else true +fun foo(x) = + if x then None + else Some(foo(not(x))) +fun main() = foo(false) +main() +//│ = Some(None) +//│ +//│ Interpreted: +//│ Some(None()) + +:intl +abstract class Option[out T]: Some[T] | None +data class Some[out T](x: T) extends Option[T] +object None extends Option +fun fromSome(s) = if s is Some(x) then x +abstract class Nat: S[Nat] | O +data class S(s: Nat) extends Nat +object O extends Nat +fun aaa() = + let m = 1 + let n = 2 + let p = 3 + let q = 4 + m + n - p + q +fun bbb() = + let x = aaa() + x * 100 + 4 +fun not(x) = + if x then false else true +fun foo(x) = + if x then None + else Some(foo(not(x))) +fun main() = + let x = foo(false) + if x is + None then aaa() + Some(b1) then bbb() +main() +//│ = 404 +//│ +//│ Interpreted: +//│ 404 + +:intl +abstract class Nat: S[Nat] | O +data class S(s: Nat) extends Nat +object O extends Nat +fun odd(x) = + if x is + O then false + S(s) then even(s) +fun even(x) = + if x is + O then true + S(s) then odd(s) +fun foo() = odd(S(S(S(O)))) +foo() +//│ = true +//│ +//│ Interpreted: +//│ true + +:intl +abstract class Nat: S[Nat] | O +data class S(s: Nat) extends Nat +object O extends Nat +fun odd(x) = + if x is + O then false + S(s) then even(s) +fun even(x) = + if x is + O then true + S(s) then odd(s) +fun mk(n) = if n > 0 then S(mk(n - 1)) else O +fun foo() = odd(mk(10)) +foo() +//│ = false +//│ +//│ Interpreted: +//│ false + +:intl +abstract class Nat: S[Nat] | O +data class S(s: Nat) extends Nat +object O extends Nat +fun odd(x) = + if x is + O then false + S(s) then even(s) +fun even(x) = + if x is + O then true + S(s) then odd(s) +fun mk(n) = if n > 0 then S(mk(n - 1)) else O +fun foo() = odd(S(S(mk(10)))) +foo() +//│ = false +//│ +//│ Interpreted: +//│ false + +:intl +abstract class Nat: S[Nat] | O +data class S(s: Nat) extends Nat +object O extends Nat +fun odd(x) = + if x is + O then false + S(s) then even(s) +fun even(x) = + if x is + O then true + S(s) then odd(s) +fun foo() = odd(if 10 > 0 then S(O) else O) +fun bar() = if 10 > 0 then odd(S(O)) else odd(O) +fun main() = + foo() + bar() +main() +//│ = true +//│ +//│ Interpreted: +//│ true + +:intl +abstract class Option[out T]: Some[T] | None +data class Some[out T](x: T) extends Option[T] +object None extends Option +abstract class List[out T]: Cons[T] | Nil +data class (::) Cons[out T](head: T, tail: List[T]) extends List[T] +object Nil extends List +fun head_opt(l) = + if l is + Nil then None + Cons(h, t) then Some(h) +fun is_none(o) = + if o is + None then true + Some(x) then false +fun is_empty(l) = + is_none(head_opt(l)) +fun main() = + is_empty(Cons(1, Cons(2, Nil))) +main() +//│ = false +//│ +//│ Interpreted: +//│ false + +:intl +abstract class Option[out T]: Some[T] | None +data class Some[out T](x: T) extends Option[T] +object None extends Option +abstract class List[out T]: Cons[T] | Nil +data class (::) Cons[out T](head: T, tail: List[T]) extends List[T] +object Nil extends List +fun mk_list(n) = + if n == 0 then Nil else Cons(n, mk_list(n - 1)) +fun head_opt(l) = + if l is + Nil then None + Cons(h, t) then Some(h) +fun is_none(o) = + if o is + None then true + Some(x) then false +fun is_empty(l) = + is_none(head_opt(l)) +fun main() = + is_empty(mk_list(10)) +main() +//│ = false +//│ +//│ Interpreted: +//│ false + + +:intl +abstract class Option[out T]: Some[T] | None +data class Some[out T](x: T) extends Option[T] +object None extends Option +abstract class List[out T]: Cons[T] | Nil +data class (::) Cons[out T](head: T, tail: List[T]) extends List[T] +object Nil extends List +fun mk_list(n) = + if n == 0 then Nil else Cons(n, mk_list(n - 1)) +fun last_opt(l) = + if l is + Nil then None + Cons(h, t) then + if t is + Nil then Some(h) + Cons(h2, t2) then last_opt(t) +fun main() = + last_opt(mk_list(10)) +main() +//│ = Some(1) +//│ +//│ Interpreted: +//│ Some(1) + +:intl +abstract class Option[out T]: Some[T] | None +data class Some[out T](x: T) extends Option[T] +object None extends Option +fun is_some(o) = + if o is + Some(x) then true + None then false +fun e0(w) = + w + 8 + 9 + 10 +fun e1(a, c) = + a + 1 + 2 + 3 + 4 +fun e3(c) = + let m = 4 + let n = 5 + let p = 6 + let q = 7 + if c then m + n + p + q else m + n - p + q +fun e2(x) = + x + 12 + 13 + 14 +fun f(x) = + let c1 = is_some(x) + let z = e3(c1) + let w = if x is + Some(a) then e1(a, z) + None then e2(z) + e0(w) +fun main() = + f(Some(2)) + f(None) +main() +//│ = 115 +//│ +//│ Interpreted: +//│ 115 + +:intl +abstract class Nat: S[Nat] | O +data class S(s: Nat) extends Nat +object O extends Nat +fun pred(n) = + if n is + S(p) then p + O then O +fun plus(n1, n2) = + if n1 is + O then n2 + S(p) then S(plus(p, n2)) +fun fib(n) = + if n is + O then S(O) + S(p) then + if p is + O then S(O) + S(q) then plus(fib(p), fib(q)) +fun to_int(n) = + if n is + O then 0 + S(p) then 1 + to_int(p) +fun to_nat(n) = + if n == 0 then O + else S(to_nat(n - 1)) +fun main() = + to_int(fib(to_nat(10))) +main() +//│ = 89 +//│ +//│ Interpreted: +//│ 89 diff --git a/hkmc2/shared/src/test/mlscript/llir/Classes.mls b/hkmc2/shared/src/test/mlscript/llir/Classes.mls new file mode 100644 index 0000000000..5c01fa8941 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/Classes.mls @@ -0,0 +1,49 @@ +:llir +:cpp + +:intl +abstract class Callable +object FnLike1 extends Callable with + fun apply1(x) = x * 2 +fun apply(f, x) = f(x) +fun main() = + let mul2 = FnLike1 + apply(mul2, 3) +main() +//│ +//│ Interpreted: +//│ 6 + +:intl +:sllir +class Base() with + fun get() = 1 +class Derived() extends Base with + fun get() = 2 +fun main() = + let d = Derived() + d.Base#get() * d.Derived#get() +main() +//│ LLIR: +//│ class Base() { +//│ def get$758() = +//│ 1 +//│ } +//│ class Derived() extends Base { +//│ def get$759() = +//│ 2 +//│ } +//│ def main$761() = +//│ let x$782 = Derived$767() in +//│ let x$783 = Base.get$758(x$782) in +//│ let x$784 = Derived.get$759(x$782) in +//│ let x$785 = *(x$783,x$784) in +//│ x$785 +//│ def entry$787() = +//│ let* (x$786) = main() in +//│ x$786 +//│ entry = entry$787 +//│ +//│ Interpreted: +//│ 4 + diff --git a/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls b/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls new file mode 100644 index 0000000000..aec85f5cfa --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/ControlFlow.mls @@ -0,0 +1,314 @@ +:js +:llir +:cpp + +:sllir +:intl +fun f1() = + let x = 1 + let x = 2 + x +f1() +//│ = 2 +//│ LLIR: +//│ +//│ def f1$729() = +//│ let x$737 = 1 in +//│ let x$738 = 2 in +//│ x$738 +//│ def entry$740() = +//│ let* (x$739) = f1() in +//│ x$739 +//│ entry = entry$740 +//│ +//│ Interpreted: +//│ 2 + +:sllir +:intl +fun f2() = + let x = 0 + if x == 1 then 2 else 3 +f2() +//│ = 3 +//│ LLIR: +//│ +//│ def f2$741() = +//│ let x$750 = 0 in +//│ let x$751 = ==(x$750,1) in +//│ case x$751 of +//│ BoolLit(true) => +//│ 2 +//│ _ => +//│ 3 +//│ def j$752() = +//│ undefined +//│ def entry$754() = +//│ let* (x$753) = f2() in +//│ x$753 +//│ entry = entry$754 +//│ +//│ Interpreted: +//│ 3 + + +:sllir +fun f3() = + let x1 = 0 + let x2 = 1 + if true then x1 else x2 +f3() +//│ = 0 +//│ LLIR: +//│ +//│ def f3$755() = +//│ let x$764 = 0 in +//│ let x$765 = 1 in +//│ let x$766 = true in +//│ case x$766 of +//│ BoolLit(true) => +//│ x$764 +//│ _ => +//│ x$765 +//│ def j$767() = +//│ undefined +//│ def entry$769() = +//│ let* (x$768) = f3() in +//│ x$768 +//│ entry = entry$769 + + +:sllir +:intl +fun f4() = + let x = 0 + let x = if x == 1 then 2 else 3 + x +f4() +//│ = 3 +//│ LLIR: +//│ +//│ def f4$770() = +//│ let x$782 = 0 in +//│ let x$783 = ==(x$782,1) in +//│ case x$783 of +//│ BoolLit(true) => +//│ let x$785 = 2 in +//│ jump j$784(x$785) +//│ _ => +//│ let x$786 = 3 in +//│ jump j$784(x$786) +//│ def j$784(tmp$781) = +//│ tmp$781 +//│ def entry$788() = +//│ let* (x$787) = f4() in +//│ x$787 +//│ entry = entry$788 +//│ +//│ Interpreted: +//│ 3 + +:sllir +:intl +fun f5() = + let x = 0 + let x = if x == 1 then 2 else 3 + let x = if x == 2 then 4 else 5 + x +f5() +//│ = 5 +//│ LLIR: +//│ +//│ def f5$789() = +//│ let x$806 = 0 in +//│ let x$807 = ==(x$806,1) in +//│ case x$807 of +//│ BoolLit(true) => +//│ let x$809 = 2 in +//│ jump j$808(x$809) +//│ _ => +//│ let x$810 = 3 in +//│ jump j$808(x$810) +//│ def j$808(tmp$804) = +//│ let x$811 = ==(tmp$804,2) in +//│ case x$811 of +//│ BoolLit(true) => +//│ let x$813 = 4 in +//│ jump j$812(x$813) +//│ _ => +//│ let x$814 = 5 in +//│ jump j$812(x$814) +//│ def j$812(tmp$805) = +//│ tmp$805 +//│ def entry$816() = +//│ let* (x$815) = f5() in +//│ x$815 +//│ entry = entry$816 +//│ +//│ Interpreted: +//│ 5 + +:sllir +fun test() = + if true do test() +//│ LLIR: +//│ +//│ def test$817() = +//│ let x$824 = true in +//│ case x$824 of +//│ BoolLit(true) => +//│ let* (x$826) = test() in +//│ x$826 +//│ _ => +//│ undefined +//│ def j$825() = +//│ undefined +//│ def entry$827() = +//│ undefined +//│ entry = entry$827 + +:sllir +fun test() = + (if true then test()) + 1 +//│ LLIR: +//│ +//│ def test$828() = +//│ let x$838 = true in +//│ case x$838 of +//│ BoolLit(true) => +//│ let* (x$840) = test() in +//│ jump j$839(x$840) +//│ _ => +//│ panic "match error" +//│ def j$839(tmp$837) = +//│ let x$841 = +(tmp$837,1) in +//│ x$841 +//│ def entry$842() = +//│ undefined +//│ entry = entry$842 + + +:sllir +:intl +fun f() = + let x = 10 + if true do + set x += 1 + x +f() +//│ = 11 +//│ LLIR: +//│ +//│ def f$843() = +//│ let x$856 = 10 in +//│ let x$857 = true in +//│ case x$857 of +//│ BoolLit(true) => +//│ let x$859 = +(x$856,1) in +//│ let x$860 = undefined in +//│ jump j$858(x$859) +//│ _ => +//│ let x$861 = undefined in +//│ jump j$858(x$856) +//│ def j$858(x$845) = +//│ x$845 +//│ def entry$863() = +//│ let* (x$862) = f() in +//│ x$862 +//│ entry = entry$863 +//│ +//│ Interpreted: +//│ 11 + +:sllir +:intl +data class A(x) +data class B(y) +fun f(a) = + let t = if a is + A(_) then 1 + B(_) then 2 + t +f(A(1)) +//│ = 1 +//│ LLIR: +//│ class A(x$869) +//│ class B(y$874) +//│ def f$866(a$878) = +//│ case a$878 of +//│ A$867 => +//│ let x$894 = a$878. in +//│ let x$895 = 1 in +//│ jump j$893(x$895) +//│ B$872 => +//│ let x$896 = a$878. in +//│ let x$897 = 2 in +//│ jump j$893(x$897) +//│ _ => +//│ panic "match error" +//│ def j$893(tmp$891) = +//│ tmp$891 +//│ def entry$900() = +//│ let x$898 = A$867(1) in +//│ let* (x$899) = f(x$898) in +//│ x$899 +//│ entry = entry$900 +//│ +//│ Interpreted: +//│ 1 + +:sllir +:intl +data class A(x) +data class B(y) +fun f(a) = + let t = if a is + A(_) then if a is + A(1) then 1 + B(_) then 2 + B(_) then 3 + t +f(A(1)) +//│ = 1 +//│ LLIR: +//│ class A(x$906) +//│ class B(y$911) +//│ def f$903(a$915) = +//│ case a$915 of +//│ A$904 => +//│ let x$935 = a$915. in +//│ case a$915 of +//│ A$904 => +//│ let x$937 = a$915. in +//│ case x$937 of +//│ IntLit(1) => +//│ let x$939 = 1 in +//│ jump j$938(x$939) +//│ _ => +//│ panic "match error" +//│ B$909 => +//│ let x$940 = a$915. in +//│ let x$941 = 2 in +//│ jump j$936(x$941) +//│ _ => +//│ panic "match error" +//│ B$909 => +//│ let x$942 = a$915. in +//│ let x$943 = 3 in +//│ jump j$934(x$943) +//│ _ => +//│ panic "match error" +//│ def j$938(tmp$931) = +//│ jump j$936(tmp$931) +//│ def j$936(tmp$931) = +//│ jump j$934(tmp$931) +//│ def j$934(tmp$932) = +//│ tmp$932 +//│ def entry$946() = +//│ let x$944 = A$904(1) in +//│ let* (x$945) = f(x$944) in +//│ x$945 +//│ entry = entry$946 +//│ +//│ Interpreted: +//│ 1 diff --git a/hkmc2/shared/src/test/mlscript/llir/Ctor.mls b/hkmc2/shared/src/test/mlscript/llir/Ctor.mls new file mode 100644 index 0000000000..fbbcc0b22e --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/Ctor.mls @@ -0,0 +1,11 @@ +:js +:llir +:cpp + +object None +fun testCtor1() = None +fun testCtor2() = new None + +class A(x) +fun testCtor1() = A(1) +fun testCtor2() = new A(1) diff --git a/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls b/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls new file mode 100644 index 0000000000..3a4e5397f6 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/HigherOrder.mls @@ -0,0 +1,117 @@ +:js +:llir +:cpp +:intl + +//│ +//│ Interpreted: +//│ undefined +:sllir +fun add(x) = y => x + y +fun add_curried(x)(y) = x + y +add(1)(2) +//│ = 3 +//│ LLIR: +//│ class Lambda_lambda(lam_arg0$751) extends Callable { +//│ def apply1$753(y$733) = +//│ let x$754 = +(lam_arg0$751,y$733) in +//│ x$754 +//│ } +//│ class Lambda(lam_arg0$759) extends Callable { +//│ def apply1$753(y$738) = +//│ let x$760 = +(lam_arg0$759,y$738) in +//│ x$760 +//│ } +//│ def add$729(x$732) = +//│ let x$756 = Lambda_lambda$749(x$732) in +//│ x$756 +//│ def add_curried$730(x$737) = +//│ let x$761 = Lambda$757(x$737) in +//│ x$761 +//│ def entry$764() = +//│ let* (x$762) = add(1) in +//│ let x$763 = Callable.apply1$753(x$762,2) in +//│ x$763 +//│ entry = entry$764 +//│ +//│ Interpreted: +//│ 3 + +fun add4(a, b) = (c, d) => a + b + c + d +fun add4_curried(a, b)(c, d) = a + b + c + d +add4(1, 2)(3, 4) +//│ = 10 +//│ +//│ Interpreted: +//│ 10 + +fun add(a, b) = a + b +fun dummy() = add +dummy()(1, 2) +//│ = 3 +//│ +//│ Interpreted: +//│ 3 + +abstract class List[out T]: Cons[T] | Nil +data class (::) Cons[out T](head: T, tail: List[T]) extends List[T] +object Nil extends List +fun map(f, l) = + if l is + Cons(h, t) then Cons(f(h), map(f, t)) + Nil then Nil +fun inc(x) = x + 1 +fun main() = + map(x => inc(x), 1 :: 2 :: Nil) + map(inc, 3 :: 4 :: Nil) +main() +//│ = Cons(4, Cons(5, Nil)) +//│ +//│ Interpreted: +//│ Cons(4,Cons(5,Nil())) + +abstract class List[out T]: Cons[T] | Nil +data class (::) Cons[out T](head: T, tail: List[T]) extends List[T] +object Nil extends List +fun not(c) = if c then false else true +fun filter(f, ls) = if ls is + Nil then Nil + h :: t and + f(h) then h :: filter(f, t) + else filter(f, t) +fun nubBy(eq, ls) = if ls is + Nil then Nil + h :: t then h :: nubBy(eq, filter(y => not(eq(h, y)), t)) +nubBy((x, y) => x == y, 1 :: 2 :: 3 :: 3 :: Nil) +//│ = Cons(1, Cons(2, Cons(3, Nil))) +//│ +//│ Interpreted: +//│ Cons(1,Cons(2,Cons(3,Nil()))) + +:intl +fun f(x) = + fun self_rec(x) = + if x == 0 then 0 + else x + self_rec(x - 1) + self_rec(x) +f(3) +//│ = 6 +//│ +//│ Interpreted: +//│ 6 + +fun f(x) = + fun even(x) = + fun odd(x) = + if x == 0 then true + else if x == 1 then false + else even(x - 1) + if x == 0 then true + else if x == 1 then false + else odd(x - 1) + even(x) +f(3) +//│ = false +//│ +//│ Interpreted: +//│ false diff --git a/hkmc2/shared/src/test/mlscript/llir/Lazy.mls b/hkmc2/shared/src/test/mlscript/llir/Lazy.mls new file mode 100644 index 0000000000..c0a81aa551 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/Lazy.mls @@ -0,0 +1,89 @@ +:llir + +// This should ideally be a declaration. +// Now it is a built-in class specially handled in the C++ backend, +// with related logic defined in `mlsprelude.h`. +abstract class Lazy[out A](init: () -> A) with + fun get: A +fun lazy(x) = Lazy(x) +fun force(x) = if x is Lazy then x.Lazy#get() +type LazyList[out T] = Lazy[LzList[T]] +abstract class LzList[out T]: LzCons[T] | LzNil +data class LzCons[out T](head: T, tail: LazyList[T]) extends LzList[T] +object LzNil extends LzList + +:sllir +:scpp +fun side_effect() = + console.log("executed") + 1 +fun main() = + let x = lazy(() => side_effect()) + let y = force(x) + let y1 = force(x) // force again, but should not execute side_effect again + () +main() +//│ LLIR: +//│ class Lambda_lambda() extends Callable { +//│ def apply0$797() = +//│ let* (x$798) = side_effect() in +//│ x$798 +//│ } +//│ def side_effect$772() = +//│ let* (x$792) = ("println","executed") in +//│ 1 +//│ def main$771() = +//│ let x$800 = Lambda_lambda$794() in +//│ let* (x$801) = lazy(x$800) in +//│ let* (x$802) = force(x$801) in +//│ let* (x$803) = force(x$801) in +//│ undefined +//│ def entry$805() = +//│ let* (x$804) = main() in +//│ x$804 +//│ entry = entry$805 +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ struct _mls_Lambda_lambda; +//│ _mlsValue _mls_side_effect(); +//│ _mlsValue _mls_main(); +//│ _mlsValue _mls_entry(); +//│ struct _mls_Lambda_lambda: public _mls_Callable { +//│ +//│ constexpr static inline const char *typeName = "Lambda_lambda"; +//│ constexpr static inline uint32_t typeTag = nextTypeTag(); +//│ virtual void print() const override { std::printf("%s", typeName); } +//│ virtual void destroy() override { operator delete (this, std::align_val_t(_mlsAlignment)); } +//│ static _mlsValue create() { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_Lambda_lambda; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; return _mlsValue(_mlsVal); } +//│ virtual _mlsValue _mls_apply0(); +//│ }; +//│ _mlsValue _mls_side_effect() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x1 = _mls_builtin_println(_mlsValue::create<_mls_Str>("executed")); +//│ _mls_retval = _mlsValue::fromIntLit(1); +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_main() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x2 = _mlsValue::create<_mls_Lambda_lambda>(); +//│ auto _mls_x3 = _mls_lazy(_mls_x2); +//│ auto _mls_x4 = _mls_force(_mls_x3); +//│ auto _mls_x5 = _mls_force(_mls_x3); +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_entry() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x6 = _mls_main(); +//│ _mls_retval = _mls_x6; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_Lambda_lambda::_mls_apply0() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x = _mls_side_effect(); +//│ _mls_retval = _mls_x; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { return _mls_entry(); } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } diff --git a/hkmc2/shared/src/test/mlscript/llir/LazyCycle.mls b/hkmc2/shared/src/test/mlscript/llir/LazyCycle.mls new file mode 100644 index 0000000000..10f087bdd6 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/LazyCycle.mls @@ -0,0 +1,166 @@ +:llir + +// This should ideally be a declaration. +// Now it is a built-in class specially handled in the C++ backend, +// with related logic defined in `mlsprelude.h`. +abstract class Lazy[out A](init: () -> A) with + fun get: A +fun lazy(x) = Lazy(x) +fun force(x) = if x is Lazy then x.Lazy#get() +type LazyList[out T] = Lazy[LzList[T]] +abstract class LzList[out T]: LzCons[T] | LzNil +data class LzCons[out T](head: T, tail: LazyList[T]) extends LzList[T] +object LzNil extends LzList + +:sllir +:showWholeCpp +fun llist(x) = + fun f(x) = lazy(() => LzCons(x, f(x + 1))) + f(x) +llist(1) +//│ LLIR: +//│ class Lambda_lambda(lam_arg0$794,lam_arg1$795) extends Callable { +//│ def apply0$796() = +//│ let x$797 = +(lam_arg0$794,1) in +//│ let x$798 = Callable.apply1$791(lam_arg1$795,x$797) in +//│ let x$800 = LzCons$756(lam_arg0$794,x$798) in +//│ x$800 +//│ } +//│ class Lambda_f() extends Callable { +//│ def apply1$791(x$776) = +//│ let x$801 = Lambda_lambda$792(x$776,$790) in +//│ let* (x$802) = lazy(x$801) in +//│ x$802 +//│ } +//│ def llist$772(x$774) = +//│ let x$803 = Lambda_f$788() in +//│ let x$804 = Callable.apply1$791(x$803,x$774) in +//│ x$804 +//│ def entry$806() = +//│ let* (x$805) = llist(1) in +//│ x$805 +//│ entry = entry$806 +//│ +//│ WholeProgramCpp: +//│ #include "mlsprelude.h" +//│ struct _mls_LzList; +//│ struct _mls_LzCons; +//│ struct _mls_Lambda_f; +//│ struct _mls_Lazy; +//│ struct _mls_Lambda_lambda; +//│ struct _mls_LzNil; +//│ _mlsValue _mls_entry2(); +//│ _mlsValue _mls_entry(); +//│ _mlsValue _mls_llist(_mlsValue); +//│ _mlsValue _mls_force(_mlsValue); +//│ _mlsValue _mls_lazy(_mlsValue); +//│ _mlsValue _mls_j(); +//│ _mlsValue _mls_entry1(); +//│ struct _mls_LzList: public _mlsObject { +//│ +//│ constexpr static inline const char *typeName = "LzList"; +//│ constexpr static inline uint32_t typeTag = nextTypeTag(); +//│ virtual void print() const override { std::printf("%s", typeName); } +//│ virtual void destroy() override { operator delete (this, std::align_val_t(_mlsAlignment)); } +//│ static _mlsValue create() { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_LzList; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; return _mlsValue(_mlsVal); } +//│ +//│ }; +//│ struct _mls_LzCons: public _mls_LzList { +//│ _mlsValue _mls_head; +//│ _mlsValue _mls_tail; +//│ constexpr static inline const char *typeName = "LzCons"; +//│ constexpr static inline uint32_t typeTag = nextTypeTag(); +//│ virtual void print() const override { std::printf("%s", typeName); std::printf("("); this->_mls_head.print(); std::printf(", "); this->_mls_tail.print(); std::printf(")"); } +//│ virtual void destroy() override { _mlsValue::destroy(this->_mls_head); _mlsValue::destroy(this->_mls_tail); operator delete (this, std::align_val_t(_mlsAlignment)); } +//│ static _mlsValue create(_mlsValue _mls_head, _mlsValue _mls_tail) { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_LzCons; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; _mlsVal->_mls_head = _mls_head; _mlsVal->_mls_tail = _mls_tail; return _mlsValue(_mlsVal); } +//│ +//│ }; +//│ struct _mls_Lambda_f: public _mls_Callable { +//│ +//│ constexpr static inline const char *typeName = "Lambda_f"; +//│ constexpr static inline uint32_t typeTag = nextTypeTag(); +//│ virtual void print() const override { std::printf("%s", typeName); } +//│ virtual void destroy() override { operator delete (this, std::align_val_t(_mlsAlignment)); } +//│ static _mlsValue create() { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_Lambda_f; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; return _mlsValue(_mlsVal); } +//│ virtual _mlsValue _mls_apply1(_mlsValue); +//│ }; +//│ struct _mls_Lambda_lambda: public _mls_Callable { +//│ _mlsValue _mls_lam_arg0; +//│ _mlsValue _mls_lam_arg1; +//│ constexpr static inline const char *typeName = "Lambda_lambda"; +//│ constexpr static inline uint32_t typeTag = nextTypeTag(); +//│ virtual void print() const override { std::printf("%s", typeName); std::printf("("); this->_mls_lam_arg0.print(); std::printf(", "); this->_mls_lam_arg1.print(); std::printf(")"); } +//│ virtual void destroy() override { _mlsValue::destroy(this->_mls_lam_arg0); _mlsValue::destroy(this->_mls_lam_arg1); operator delete (this, std::align_val_t(_mlsAlignment)); } +//│ static _mlsValue create(_mlsValue _mls_lam_arg0, _mlsValue _mls_lam_arg1) { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_Lambda_lambda; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; _mlsVal->_mls_lam_arg0 = _mls_lam_arg0; _mlsVal->_mls_lam_arg1 = _mls_lam_arg1; return _mlsValue(_mlsVal); } +//│ virtual _mlsValue _mls_apply0(); +//│ }; +//│ struct _mls_LzNil: public _mls_LzList { +//│ +//│ constexpr static inline const char *typeName = "LzNil"; +//│ constexpr static inline uint32_t typeTag = nextTypeTag(); +//│ virtual void print() const override { std::printf("%s", typeName); } +//│ virtual void destroy() override { operator delete (this, std::align_val_t(_mlsAlignment)); } +//│ static _mlsValue create() { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_LzNil; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; return _mlsValue(_mlsVal); } +//│ +//│ }; +//│ _mlsValue _mls_j() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_entry() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_lazy(_mlsValue _mls_x7) { +//│ _mlsValue _mls_retval; +//│ auto _mls_x6 = _mlsValue::create<_mls_Lazy>(_mls_x7); +//│ _mls_retval = _mls_x6; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_llist(_mlsValue _mls_x12) { +//│ _mlsValue _mls_retval; +//│ auto _mls_x11 = _mlsValue::create<_mls_Lambda_f>(); +//│ auto _mls_x13 = _mlsMethodCall<_mls_Callable>(_mls_x11)->_mls_apply1(_mls_x12); +//│ _mls_retval = _mls_x13; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_entry1() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x8 = _mls_llist(_mlsValue::fromIntLit(1)); +//│ _mls_retval = _mls_x8; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_entry2() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_force(_mlsValue _mls_x9) { +//│ _mlsValue _mls_retval; +//│ if (_mlsValue::isValueOf<_mls_Lazy>(_mls_x9)) { +//│ auto _mls_x10 = _mlsMethodCall<_mls_Lazy>(_mls_x9)->_mls_get(); +//│ _mls_retval = _mls_x10; +//│ } else { +//│ throw std::runtime_error("match error"); +//│ } +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_Lambda_f::_mls_apply1(_mlsValue _mls_x1) { +//│ _mlsValue _mls_retval; +//│ auto _mls_x = _mlsValue::create<_mls_Lambda_lambda>(_mls_x1, _mlsValue(this, _mlsValue::inc_ref_tag{})); +//│ auto _mls_x2 = _mls_lazy(_mls_x); +//│ _mls_retval = _mls_x2; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_Lambda_lambda::_mls_apply0() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x3 = (_mls_lam_arg0 + _mlsValue::fromIntLit(1)); +//│ auto _mls_x4 = _mlsMethodCall<_mls_Callable>(_mls_lam_arg1)->_mls_apply1(_mls_x3); +//│ auto _mls_x5 = _mlsValue::create<_mls_LzCons>(_mls_lam_arg0, _mls_x4); +//│ _mls_retval = _mls_x5; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { return _mls_entry1(); } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } diff --git a/hkmc2/shared/src/test/mlscript/llir/Method.mls b/hkmc2/shared/src/test/mlscript/llir/Method.mls new file mode 100644 index 0000000000..a6378aa72a --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/Method.mls @@ -0,0 +1,26 @@ +:js +:llir +:cpp + +:sllir +class A(m) with + fun f() = m +fun main() = + let a = A(1) + a.A#f() +main() +//│ = 1 +//│ LLIR: +//│ class A(m$734) { +//│ def f$729() = +//│ m$735 +//│ } +//│ def main$731() = +//│ let x$749 = A$732(1) in +//│ let x$750 = A.f$729(x$749) in +//│ x$750 +//│ def entry$752() = +//│ let* (x$751) = main() in +//│ x$751 +//│ entry = entry$752 + diff --git a/hkmc2/shared/src/test/mlscript/llir/Playground.mls b/hkmc2/shared/src/test/mlscript/llir/Playground.mls deleted file mode 100644 index c8d06317dd..0000000000 --- a/hkmc2/shared/src/test/mlscript/llir/Playground.mls +++ /dev/null @@ -1,438 +0,0 @@ -:js -:llir - - -:sllir -abstract class Option[out T]: Some[T] | None -data class Some[out T](x: T) extends Option[T] -object None extends Option -fun fromSome(s) = if s is Some(x) then x -data class Lazy[out A](init: () -> A) with - mut val cache: Option[A] = None -fun lazy(x) = Lazy(x) -//│ LLIR: -//│ class Option() -//│ class Some(x,x) -//│ class None() -//│ class Lazy(init,init,cache) -//│ def fromSome(s) = -//│ case s of -//│ Some => -//│ let x$0 = s. in -//│ x$0 -//│ _ => -//│ panic "match error" -//│ def j$0() = -//│ undefined -//│ def lazy(x1) = -//│ let x$1 = Lazy(x1) in -//│ x$1 -//│ undefined - -:sllir -fun testCtor1() = None -fun testCtor2() = new None -//│ LLIR: -//│ -//│ def testCtor1() = -//│ let x$0 = None() in -//│ x$0 -//│ def testCtor2() = -//│ let x$1 = None() in -//│ x$1 -//│ undefined - -:sllir -:intl -abstract class Option[out T]: Some[T] | None -data class Some[out T](x: T) extends Option[T] -object None extends Option -fun fromSome(s) = if s is Some(x) then x -abstract class Nat: S[Nat] | O -data class S(s: Nat) extends Nat -object O extends Nat -fun aaa() = - let m = 1 - let n = 2 - let p = 3 - let q = 4 - m + n - p + q -fun bbb() = - let x = aaa() - x * 100 + 4 -fun not(x) = - if x then false else true -fun foo(x) = - if x then None - else Some(foo(not(x))) -fun main() = - let x = foo(false) - if x is - None then aaa() - Some(b1) then bbb() -main() -//│ = 404 -//│ LLIR: -//│ class Option() -//│ class Some(x,x) -//│ class None() -//│ class Nat() -//│ class S(s,s) -//│ class O() -//│ def fromSome(s) = -//│ case s of -//│ Some => -//│ let x$0 = s. in -//│ x$0 -//│ _ => -//│ panic "match error" -//│ def j$0() = -//│ undefined -//│ def aaa() = -//│ let x$1 = 1 in -//│ let x$2 = 2 in -//│ let x$3 = 3 in -//│ let x$4 = 4 in -//│ let x$5 = +(x$1,x$2) in -//│ let x$6 = -(x$5,x$3) in -//│ let x$7 = +(x$6,x$4) in -//│ x$7 -//│ def bbb() = -//│ let* (x$8) = aaa() in -//│ let x$9 = *(x$8,100) in -//│ let x$10 = +(x$9,4) in -//│ x$10 -//│ def not(x2) = -//│ case x2 of -//│ BoolLit(true) => -//│ false -//│ _ => -//│ true -//│ def j$1() = -//│ undefined -//│ def foo(x3) = -//│ case x3 of -//│ BoolLit(true) => -//│ let x$11 = None() in -//│ x$11 -//│ _ => -//│ let* (x$12) = not(x3) in -//│ let* (x$13) = foo(x$12) in -//│ let x$14 = Some(x$13) in -//│ x$14 -//│ def j$2() = -//│ undefined -//│ def main() = -//│ let* (x$15) = foo(false) in -//│ case x$15 of -//│ None => -//│ let* (x$16) = aaa() in -//│ x$16 -//│ Some => -//│ let x$17 = x$15. in -//│ let* (x$18) = bbb() in -//│ x$18 -//│ _ => -//│ panic "match error" -//│ def j$3() = -//│ undefined -//│ let* (x$19) = main() in -//│ x$19 -//│ -//│ Interpreted: -//│ 404 - -:sllir -:intl -fun f1() = - let x = 1 - let x = 2 - x -f1() -//│ = 2 -//│ LLIR: -//│ -//│ def f1() = -//│ let x$0 = 1 in -//│ let x$1 = 2 in -//│ x$1 -//│ let* (x$2) = f1() in -//│ x$2 -//│ -//│ Interpreted: -//│ 2 - -:sllir -:intl -fun f2() = - let x = 0 - if x == 1 then 2 else 3 -f2() -//│ = 3 -//│ LLIR: -//│ -//│ def f2() = -//│ let x$0 = 0 in -//│ let x$1 = ==(x$0,1) in -//│ case x$1 of -//│ BoolLit(true) => -//│ 2 -//│ _ => -//│ 3 -//│ def j$0() = -//│ undefined -//│ let* (x$2) = f2() in -//│ x$2 -//│ -//│ Interpreted: -//│ 3 - - -:sllir -fun f3() = - let x1 = 0 - let x2 = 1 - if true then x1 else x2 -f3() -//│ = 0 -//│ LLIR: -//│ -//│ def f3() = -//│ let x$0 = 0 in -//│ let x$1 = 1 in -//│ let x$2 = true in -//│ case x$2 of -//│ BoolLit(true) => -//│ x$0 -//│ _ => -//│ x$1 -//│ def j$0() = -//│ undefined -//│ let* (x$3) = f3() in -//│ x$3 - - -:sllir -:intl -fun f4() = - let x = 0 - let x = if x == 1 then 2 else 3 - x -f4() -//│ = 3 -//│ LLIR: -//│ -//│ def f4() = -//│ let x$0 = 0 in -//│ let x$1 = ==(x$0,1) in -//│ case x$1 of -//│ BoolLit(true) => -//│ let x$3 = 2 in -//│ jump j$0(x$3) -//│ _ => -//│ let x$4 = 3 in -//│ jump j$0(x$4) -//│ def j$0(x$2) = -//│ x$2 -//│ let* (x$5) = f4() in -//│ x$5 -//│ -//│ Interpreted: -//│ 3 - -:sllir -:intl -fun f5() = - let x = 0 - let x = if x == 1 then 2 else 3 - let x = if x == 2 then 4 else 5 - x -f5() -//│ = 5 -//│ LLIR: -//│ -//│ def f5() = -//│ let x$0 = 0 in -//│ let x$1 = ==(x$0,1) in -//│ case x$1 of -//│ BoolLit(true) => -//│ let x$3 = 2 in -//│ jump j$0(x$3) -//│ _ => -//│ let x$4 = 3 in -//│ jump j$0(x$4) -//│ def j$0(x$2) = -//│ let x$5 = ==(x$2,2) in -//│ case x$5 of -//│ BoolLit(true) => -//│ let x$7 = 4 in -//│ jump j$1(x$7) -//│ _ => -//│ let x$8 = 5 in -//│ jump j$1(x$8) -//│ def j$1(x$6) = -//│ x$6 -//│ let* (x$9) = f5() in -//│ x$9 -//│ -//│ Interpreted: -//│ 5 - -:sllir -:scpp -fun test() = - if true do test() -//│ LLIR: -//│ -//│ def test() = -//│ let x$0 = true in -//│ case x$0 of -//│ BoolLit(true) => -//│ let* (x$1) = test() in -//│ x$1 -//│ _ => -//│ undefined -//│ def j$0() = -//│ undefined -//│ undefined -//│ -//│ Cpp: -//│ #include "mlsprelude.h" -//│ _mlsValue _mls_j_0(); -//│ _mlsValue _mls_test(); -//│ _mlsValue _mlsMain(); -//│ _mlsValue _mls_j_0() { -//│ _mlsValue _mls_retval; -//│ _mls_retval = _mlsValue::create<_mls_Unit>(); -//│ return _mls_retval; -//│ } -//│ _mlsValue _mls_test() { -//│ _mlsValue _mls_retval; -//│ auto _mls_x_0 = _mlsValue::fromIntLit(1); -//│ if (_mlsValue::isIntLit(_mls_x_0, 1)) { -//│ auto _mls_x_1 = _mls_test(); -//│ _mls_retval = _mls_x_1; -//│ } else { -//│ _mls_retval = _mlsValue::create<_mls_Unit>(); -//│ } -//│ return _mls_retval; -//│ } -//│ _mlsValue _mlsMain() { -//│ _mlsValue _mls_retval; -//│ _mls_retval = _mlsValue::create<_mls_Unit>(); -//│ return _mls_retval; -//│ } -//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } - -:sllir -:scpp -fun test() = - (if true then test()) + 1 -//│ LLIR: -//│ -//│ def test() = -//│ let x$0 = true in -//│ case x$0 of -//│ BoolLit(true) => -//│ let* (x$2) = test() in -//│ jump j$0(x$2) -//│ _ => -//│ panic "match error" -//│ def j$0(x$1) = -//│ let x$3 = +(x$1,1) in -//│ x$3 -//│ undefined -//│ -//│ Cpp: -//│ #include "mlsprelude.h" -//│ _mlsValue _mls_j_0(_mlsValue); -//│ _mlsValue _mls_test(); -//│ _mlsValue _mlsMain(); -//│ _mlsValue _mls_j_0(_mlsValue _mls_x_1) { -//│ _mlsValue _mls_retval; -//│ auto _mls_x_3 = (_mls_x_1 + _mlsValue::fromIntLit(1)); -//│ _mls_retval = _mls_x_3; -//│ return _mls_retval; -//│ } -//│ _mlsValue _mls_test() { -//│ _mlsValue _mls_retval; -//│ auto _mls_x_0 = _mlsValue::fromIntLit(1); -//│ if (_mlsValue::isIntLit(_mls_x_0, 1)) { -//│ auto _mls_x_2 = _mls_test(); -//│ _mls_retval = _mls_j_0(_mls_x_2); -//│ } else { -//│ throw std::runtime_error("match error"); -//│ } -//│ return _mls_retval; -//│ } -//│ _mlsValue _mlsMain() { -//│ _mlsValue _mls_retval; -//│ _mls_retval = _mlsValue::create<_mls_Unit>(); -//│ return _mls_retval; -//│ } -//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } - - -:sllir -:intl -:scpp -fun f() = - let x = 10 - if true do - set x += 1 - x -f() -//│ = 11 -//│ LLIR: -//│ -//│ def f() = -//│ let x$0 = 10 in -//│ let x$1 = true in -//│ case x$1 of -//│ BoolLit(true) => -//│ let x$3 = +(x$0,1) in -//│ let x$4 = undefined in -//│ jump j$0(x$3) -//│ _ => -//│ let x$5 = undefined in -//│ jump j$0(x$0) -//│ def j$0(x$2) = -//│ x$2 -//│ let* (x$6) = f() in -//│ x$6 -//│ -//│ Cpp: -//│ #include "mlsprelude.h" -//│ _mlsValue _mls_j_0(_mlsValue); -//│ _mlsValue _mls_f(); -//│ _mlsValue _mlsMain(); -//│ _mlsValue _mls_j_0(_mlsValue _mls_x_2) { -//│ _mlsValue _mls_retval; -//│ _mls_retval = _mls_x_2; -//│ return _mls_retval; -//│ } -//│ _mlsValue _mls_f() { -//│ _mlsValue _mls_retval; -//│ auto _mls_x_0 = _mlsValue::fromIntLit(10); -//│ auto _mls_x_1 = _mlsValue::fromIntLit(1); -//│ if (_mlsValue::isIntLit(_mls_x_1, 1)) { -//│ auto _mls_x_3 = (_mls_x_0 + _mlsValue::fromIntLit(1)); -//│ auto _mls_x_4 = _mlsValue::create<_mls_Unit>(); -//│ _mls_retval = _mls_j_0(_mls_x_3); -//│ } else { -//│ auto _mls_x_5 = _mlsValue::create<_mls_Unit>(); -//│ _mls_retval = _mls_j_0(_mls_x_0); -//│ } -//│ return _mls_retval; -//│ } -//│ _mlsValue _mlsMain() { -//│ _mlsValue _mls_retval; -//│ auto _mls_x_6 = _mls_f(); -//│ _mls_retval = _mls_x_6; -//│ return _mls_retval; -//│ } -//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } -//│ -//│ Interpreted: -//│ 11 - diff --git a/hkmc2/shared/src/test/mlscript/llir/Split.mls b/hkmc2/shared/src/test/mlscript/llir/Split.mls new file mode 100644 index 0000000000..af0f4ad158 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/Split.mls @@ -0,0 +1,55 @@ +:llir + +:intl +abstract class Iter[T] +object Done +data class Yield(n) +data class Map(f, it) +fun done(it) = if it is Done then true else false +fun map(f, it) = Map(f, it) +fun dec(n) = Yield(n) +fun next(it) = if it is + Yield(n) then + if n == 0 then [0, Done] + else [n, Yield(n - 1)] + Map(f, it) then + if next(it) is [n, it] then + if done(it) then [f(n), Done] + else [f(n), Map(f, it)] +fun fold(acc, it) = + if next(it) is [n, it] then + if done(it) then acc else fold(n + acc, it) +fun map_sum(f, n) = let it = map(f, dec(n)) in fold(0, it) +map_sum(x => x, 200) +//│ +//│ Interpreted: +//│ 20100 + +:intl +abstract class Iter[T] +object Done +data class Yield(n) +data class Map(f, it) +fun done(it) = if it is Done then true else false +fun map(f, it) = Map(f, it) +fun dec(n) = Yield(n) +fun next(it) = if it is + Yield(n) then + if n == 0 then [0, Done] + else [n, Yield(n - 1)] + Map(f, it) then + if next(it) is [n, it] then + if done(it) then [f(n), Done] + else [f(n), Map(f, it)] +fun fold(acc, iter) = if iter is + Yield(n) then + if n == 0 then acc + else fold(n + acc, Yield(n - 1)) + Map(f, it) then + if next(it) is [n, it] then + if it is Done then acc else fold(f(n) + acc, Map(f, it)) +fun map_sum(f, n) = let it = map(f, dec(n)) in fold(0, it) +map_sum(x => x, 200) +//│ +//│ Interpreted: +//│ 20100 diff --git a/hkmc2/shared/src/test/mlscript/llir/Tuple.mls b/hkmc2/shared/src/test/mlscript/llir/Tuple.mls new file mode 100644 index 0000000000..ebe154bf09 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/Tuple.mls @@ -0,0 +1,81 @@ +:js +:llir +:cpp + +:intl +:sllir +:scpp +fun mkTup(x, y) = [x, y] +fun fst(t) = if t is [x, y] then x +mkTup(1, 2) +//│ = [1, 2] +//│ LLIR: +//│ +//│ def mkTup$729(x$732,y$733) = +//│ let x$745 = Tuple2$746(x$732,y$733) in +//│ x$745 +//│ def fst$730(t$736) = +//│ case t$736 of +//│ Tuple2$746 => +//│ let x$749 = t$736. in +//│ let x$750 = t$736. in +//│ x$749 +//│ _ => +//│ panic "match error" +//│ def j$748() = +//│ undefined +//│ def entry$752() = +//│ let* (x$751) = mkTup(1,2) in +//│ x$751 +//│ entry = entry$752 +//│ +//│ Cpp: +//│ #include "mlsprelude.h" +//│ struct _mls_Tuple2; +//│ _mlsValue _mls_mkTup(_mlsValue, _mlsValue); +//│ _mlsValue _mls_j(); +//│ _mlsValue _mls_fst(_mlsValue); +//│ _mlsValue _mls_entry1(); +//│ struct _mls_Tuple2: public _mlsObject { +//│ _mlsValue _mls_field0; +//│ _mlsValue _mls_field1; +//│ constexpr static inline const char *typeName = "Tuple2"; +//│ constexpr static inline uint32_t typeTag = nextTypeTag(); +//│ virtual void print() const override { std::printf("%s", typeName); std::printf("("); this->_mls_field0.print(); std::printf(", "); this->_mls_field1.print(); std::printf(")"); } +//│ virtual void destroy() override { _mlsValue::destroy(this->_mls_field0); _mlsValue::destroy(this->_mls_field1); operator delete (this, std::align_val_t(_mlsAlignment)); } +//│ static _mlsValue create(_mlsValue _mls_field0, _mlsValue _mls_field1) { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) _mls_Tuple2; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; _mlsVal->_mls_field0 = _mls_field0; _mlsVal->_mls_field1 = _mls_field1; return _mlsValue(_mlsVal); } +//│ +//│ }; +//│ _mlsValue _mls_mkTup(_mlsValue _mls_x1, _mlsValue _mls_y) { +//│ _mlsValue _mls_retval; +//│ auto _mls_x = _mlsValue::create<_mls_Tuple2>(_mls_x1, _mls_y); +//│ _mls_retval = _mls_x; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_j() { +//│ _mlsValue _mls_retval; +//│ _mls_retval = _mlsValue::create<_mls_Unit>(); +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_fst(_mlsValue _mls_t) { +//│ _mlsValue _mls_retval; +//│ if (_mlsValue::isValueOf<_mls_Tuple2>(_mls_t)) { +//│ auto _mls_x2 = _mlsValue::cast<_mls_Tuple2>(_mls_t)->_mls_field0; +//│ auto _mls_x3 = _mlsValue::cast<_mls_Tuple2>(_mls_t)->_mls_field1; +//│ _mls_retval = _mls_x2; +//│ } else { +//│ throw std::runtime_error("match error"); +//│ } +//│ return _mls_retval; +//│ } +//│ _mlsValue _mls_entry1() { +//│ _mlsValue _mls_retval; +//│ auto _mls_x4 = _mls_mkTup(_mlsValue::fromIntLit(1), _mlsValue::fromIntLit(2)); +//│ _mls_retval = _mls_x4; +//│ return _mls_retval; +//│ } +//│ _mlsValue _mlsMain() { return _mls_entry1(); } +//│ int main() { return _mlsLargeStack(_mlsMainWrapper); } +//│ +//│ Interpreted: +//│ Tuple2(1,2) diff --git a/hkmc2/shared/src/test/mlscript/llir/nofib/NofibPrelude.mls b/hkmc2/shared/src/test/mlscript/llir/nofib/NofibPrelude.mls new file mode 100644 index 0000000000..599c2ad7d8 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/nofib/NofibPrelude.mls @@ -0,0 +1,354 @@ +// shadow the one in Predef +fun not(c) = if c then false else true + +type Char = String + +abstract class Option[out T]: Some[T] | None +data class Some[out T](x: T) extends Option[T] +object None extends Option + +// This should ideally be a declaration. +// Now it is a built-in class specially handled in the C++ backend, +// with related logic defined in `mlsprelude.h`. +abstract class Lazy[out A](init: () -> A) with + fun get() +fun lazy(x) = Lazy(x) +fun force(x) = if x is Lazy then x.Lazy#get() + +fun fromSome(s) = if s is Some(x) then x + +abstract class List[out T]: Cons[T] | Nil +data class (::) Cons[out T](head: T, tail: List[T]) extends List[T] +object Nil extends List +fun ltList(xs, ys, lt, gt) = if xs is + Nil and + ys is Nil then false + else true + x :: xs and ys is + Nil then false + y :: ys and + lt(x, y) then true + gt(x, y) then false + else ltList(xs, ys, lt, gt) + +type LazyList[out T] = Lazy[LzList[T]] +abstract class LzList[out T]: LzCons[T] | LzNil +data class LzCons[out T](head: T, tail: LazyList[T]) extends LzList[T] +object LzNil extends LzList + +fun ltTup2(t1, t2, lt1, gt1, lt2) = if t1 is [a, b] and t2 is [c, d] and + lt1(a, c) then true + gt1(a, c) then false + else lt2(b, d) +fun eqTup2(t1, t2) = if t1 is [a, b] and t2 is [c, d] then a == c and b == d + +fun compose(f, g) = x => f(g(x)) + +fun snd(x) = if x is [f, s] then s +fun fst(x) = if x is [f, s] then f + +fun until(p, f, i) = if p(i) then i else until(p, f, f(i)) + +fun flip(f, x, y) = f(y)(x) + +fun power(a, n) = Math.pow(a, n) + +fun intDiv(a, b) = globalThis.builtin("floor_div", a, b) +fun intQuot(a, b) = globalThis.builtin("trunc_div", a, b) + +fun intMod(a, b) = globalThis.builtin("floor_mod", a, b) +fun intRem(a, b) = globalThis.builtin("trunc_mod", a, b) + +fun quotRem(a, b) = [intQuot(a, b), intRem(a, b)] +fun divMod(a, b) = [intDiv(a, b), intMod(a, b)] + +fun head(l) = if l is h :: t then h +fun tail(l) = if l is h :: t then t + +fun while_(p, f, x) = if p(x) then while_(p, f, f(x)) else x + +fun reverse(l) = + fun r(l', l) = if l is x :: xs then r(x :: l', xs) else l' + r(Nil, l) + +fun map(f, xs) = if xs is + x :: xs then f(x) :: map(f, xs) + Nil then Nil + +fun listLen(ls) = + fun l(ls, a) = if ls is + Nil then a + h :: t then l(t, a + 1) + l(ls, 0) + +fun listEq(xs, ys) = if + xs is Nil and ys is Nil then true + xs is hx :: tx and ys is hy :: ty and (hx == hy) then listEq(tx, ty) + else false + +fun listEqBy(f, a, b) = if a is + Nil and b is Nil then true + x :: xs and b is y :: ys then f(x, y) && listEqBy(f, xs, ys) + else false + +fun listNeq(xs, ys) = if + xs is Nil and ys is Nil then false + xs is hx :: tx and ys is hy :: ty and (hx == hy) then listNeq(tx, ty) + else true + +fun enumFromTo(a, b) = if a <= b then a :: enumFromTo(a + 1, b) else Nil + +fun enumFromThenTo(a, t, b) = if a <= b then a :: enumFromThenTo(t, 2 * t - a, b) else Nil + +fun drop(n, ls) = if ls is + Nil then Nil + h :: t and + n <= 0 then ls + else drop(n - 1, t) + +fun take(n, ls) = if ls is + Nil then Nil + h :: t and + n <= 0 then Nil + else h :: take(n - 1, t) + +fun splitAt(n, ls) = [take(n, ls), drop(n, ls)] + +fun zip(xs, ys) = if xs is + x :: xs and ys is y :: ys then [x, y] :: zip(xs, ys) + else Nil + +fun inList(x, ls) = if ls is + h :: t and + x === h then true + else inList(x, t) + Nil then false + +fun notElem(x, ls) = not(inList(x, ls)) + +fun (+:) append(xs, ys) = if xs is + Nil then ys + x :: xs then x :: append(xs, ys) + +fun concat(ls) = if ls is + Nil then Nil + x :: xs then append(x, concat(xs)) + +fun filter(f, ls) = if ls is + Nil then Nil + h :: t and + f(h) then h :: filter(f, t) + else filter(f, t) + +fun filterCurried(f)(ls) = if ls is + Nil then Nil + h :: t and + f(h) then h :: filter(f, t) + else filter(f, t) + +fun all(p, ls) = if ls is + Nil then true + h :: t and + p(h) then all(p, t) + else false + +fun orList(ls) = if ls is + Nil then false + h :: t and + h then true + else orList(t) + +fun dropWhile(f, ls) = if ls is + Nil then Nil + h :: t and + f(h) then dropWhile(f, t) + else h :: t + +fun foldl(f, a, xs) = if xs is + Nil then a + h :: t then foldl(f, f(a, h), t) + +fun scanl(f, q, ls) = if ls is + Nil then q :: Nil + x :: xs then q :: scanl(f, f(q, x), xs) + +fun scanr(f, q, ls) = if ls is + Nil then q :: Nil + x :: xs and scanr(f, q, xs) is q :: t then f(x, q) :: q :: t + +fun foldr(f, z, xs) = if xs is + Nil then z + h :: t then f(h, foldr(f, z, t)) + +fun foldl1(f, ls) = if + ls is x :: xs then foldl(f, x, xs) + +fun foldr1(f, ls) = if ls is + x :: Nil then x + x :: xs then f(x, foldr1(f, xs)) + +fun maximum(xs) = foldl1((x, y) => if x > y then x else y, xs) + +fun nubBy(eq, ls) = if ls is + Nil then Nil + h :: t then h :: nubBy(eq, filter(y => not(eq(h, y)), t)) + +fun zipWith(f, xss, yss) = if + xss is x :: xs and yss is y :: ys then f(x, y) :: zipWith(f, xs, ys) + else Nil + +fun deleteBy(eq, x, ys) = if ys is + Nil then Nil + y :: ys and + eq(x, y) then ys + else y :: deleteBy(eq, x, ys) + +fun unionBy(eq, xs, ys) = append(xs, foldl((acc, y) => deleteBy(eq, y, acc), nubBy(eq, ys), xs)) + +fun union(xs, ys) = unionBy((x, y) => x == y, xs, ys) + +fun atIndex(i, ls) = if ls is + h :: t and + i == 0 then h + else atIndex(i - 1, t) + +fun sum(xs) = + fun go(xs, a) = if xs is + Nil then a + h :: t then go(t, a + h) + go(xs, 0) + +fun null_(ls) = if ls is + Nil then true + else false + +fun replicate(n, x) = if n == 0 then Nil else x :: replicate(n - 1, x) + +fun unzip(l) = + fun f(l, a, b) = if l is + Nil then [reverse(a), reverse(b)] + [x, y] :: t then f(t, x :: a, y :: b) + f(l, Nil, Nil) + +fun zip3(xs, ys, zs) = if + xs is x :: xs and ys is y :: ys and zs is z :: zs then [x, y, z] :: zip3(xs, ys, zs) + else Nil + +fun transpose(xss) = + fun lscomp(ls) = if ls is + Nil then Nil + h :: t and h is + hd :: tl then [hd, tl] :: lscomp(t) + else lscomp(t) + fun combine(y, h, ys, t) = (y :: h) :: transpose(ys :: t) + if xss is + Nil then Nil + Nil :: xss then transpose(xss) + (x :: xs) :: xss and unzip(lscomp(xss)) is [hds, tls] then combine(x, hds, xs, tls) + +fun break_(p, ls) = if ls is + Nil then [Nil, Nil] + x :: xs and + p(x) then [Nil, x :: xs] + break_(p, xs) is [ys, zs] then [x :: ys, zs] + +fun flatMap(f, ls) = if ls is + Nil then Nil + h :: t then append(f(h), flatMap(f, t)) + + +// ===================== + +fun map_lz(f, ls) = lazy of () => + if force(ls) is + LzNil then LzNil + LzCons(h, t) then LzCons(f(h), map_lz(f, t)) + +fun filter_lz(p, ls) = Lazy of () => + if force(ls) is + LzNil then LzNil + LzCons(h, t) and + p(h) then LzCons(h, filter_lz(p, t)) + else force(filter_lz(p, t)) + +fun nubBy_lz(eq, ls) = Lazy of () => + if force(ls) is + LzNil then LzNil + LzCons(h, t) then LzCons(h, nubBy_lz(eq, filter_lz(y => not(eq(h, y)), t))) + +fun nub_lz(ls) = nubBy_lz((x, y) => x == y, ls) + +fun take_lz(n, ls) = if + n > 0 and force(ls) is + LzNil then Nil + LzCons(h, t) then h :: take_lz(n - 1, t) + else Nil + +fun take_lz_lz(n, ls) = lazy of () => + if n > 0 and force(ls) is + LzNil then LzNil + LzCons(h, t) then LzCons(h, take_lz_lz(n - 1, t)) + else LzNil + +fun drop_lz(n, ls) = if + n <= 0 then ls + force(ls) is + LzNil then lazy of () => LzNil + LzCons(h, t) then drop_lz(n - 1, t) + +fun splitAt_lz(n, ls) = [take_lz(n, ls), drop_lz(n, ls)] + +fun zip_lz_nl(xs, ys) = if + force(xs) is LzCons(x, xs) and ys is y :: ys then [x, y] :: zip_lz_nl(xs, ys) + else Nil + +fun zip_lz_lz(xs, ys) = if + force(xs) is LzCons(x, xs) and force(ys) is LzCons(y, ys) then lazy of () => LzCons([x, y], zip_lz_lz(xs, ys)) + else lazy of () => LzNil + +fun zipWith_lz_lz(f, xss, yss) = lazy of () => if + force(xss) is LzCons(x, xs) and (force(yss)) is LzCons(y, ys) then LzCons(f(x, y), zipWith_lz_lz(f, xs, ys)) + else LzNil + +fun zipWith_lz_nl(f, xss, yss) = if + force(xss) is LzCons(x, xs) and yss is y :: ys then f(x, y) :: zipWith_lz_nl(f, xs, ys) + else Nil + +fun iterate(f, x) = lazy of () => LzCons(x, iterate(f, f(x))) + +fun append_nl_lz(xs, ys) = if xs is + Nil then ys + h :: t then lazy of () => LzCons(h, append_nl_lz(t, ys)) + +fun append_lz_lz(xs, ys) = lazy of () => if force(xs) is + LzNil then force(ys) + LzCons(h, t) then LzCons(h, append_lz_lz(t, ys)) + +fun replicate_lz(n, x) = if n == 0 then lazy of () => LzNil else lazy of () => LzCons(x, replicate_lz(n - 1, x)) + +fun enumFrom(a) = lazy of () => LzCons(a, enumFrom(a + 1)) + +fun head_lz(ls) = if force(ls) is LzCons(h, t) then h + +fun repeat(x) = lazy of () => LzCons(x, repeat(x)) +// ===================== + +fun stringOfInt(x) = globalThis.builtin("int2str", x) +fun stringOfFloat(x) = globalThis.builtin("float2str", x) +fun stringConcat(x, y) = globalThis.builtin("str_concat", x, y) +fun stringListConcat(ls) = if ls is + Nil then "" + h :: t then stringConcat(h, stringListConcat(t)) +fun print(x) = globalThis.builtin("println", x) +fun abs(x) = globalThis.Math.abs(x) +fun floatOfInt(x) = globalThis.builtin("int2float", x) + +// fun max(a, b) = Math.min(a, b) +// fun min(a, b) = Math.max(a, b) +// fun abs(x) = Math.abs(x) +// fun sqrt(x) = Math.sqrt(x) +// fun tan(x) = Math.tan(x) +// fun sin(x) = Math.sin(x) +// fun cos(x) = Math.cos(x) +// fun round(x) = Math.round(x) + diff --git a/hkmc2/shared/src/test/mlscript/llir/nofib/atom.mls b/hkmc2/shared/src/test/mlscript/llir/nofib/atom.mls new file mode 100644 index 0000000000..9ba81686cc --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/nofib/atom.mls @@ -0,0 +1,45 @@ +:llir + +:import NofibPrelude.mls +//│ Imported 104 member(s) + +data class State(position: List[Num], velocity: List[Num]) + +fun dotPlus(fs, gs) = if + fs is Nil then gs + gs is Nil then fs + fs is f :: fs and gs is g :: gs then (f + g) :: dotPlus(fs, gs) + +fun dotMult(fs, gs) = if + fs is f :: fs and gs is g :: gs then (f * g) :: dotMult(fs, gs) + else Nil + +fun scalarMut(c, fs) = if fs is + Nil then Nil + f :: fs then (c * f) :: scalarMut(c, fs) + +fun testforce(k, ss) = lazy of () => + if force(ss) is + LzCons(State(pos, vel), atoms) then LzCons(dotMult(scalarMut(-1.0, k), pos), testforce(k, atoms)) + +fun show(s) = + fun lscomp(ls) = if ls is + Nil then Nil + component :: t then Cons(stringConcat(stringOfFloat(component), "\t"), lscomp(t)) + if s is State(pos, vel) then + stringListConcat of lscomp(pos) + +fun propagate(dt, aforce, state) = if state is State(pos, vel) then + State(dotPlus(pos, scalarMut(dt, vel)), dotPlus(vel, scalarMut(dt, aforce))) + +fun runExperiment(law, dt, param, init) = lazy of () => + let stream = runExperiment(law, dt, param, init) + LzCons(init, zipWith_lz_lz((x, y) => propagate(dt, x, y), law(param, stream), stream)) + +fun testAtom_nofib(n) = + fun lscomp(ls) = if ls is + Nil then Nil + state :: t then stringConcat(show(state), "\n") :: lscomp(t) + stringListConcat of lscomp(take_lz(n, runExperiment(testforce, 0.02, 1.0 :: Nil, State(1.0 :: Nil, 0.0 :: Nil)))) + +testAtom_nofib(20) diff --git a/hkmc2/shared/src/test/mlscript/llir/nofib/awards.mls b/hkmc2/shared/src/test/mlscript/llir/nofib/awards.mls new file mode 100644 index 0000000000..5696241d2e --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/nofib/awards.mls @@ -0,0 +1,71 @@ +:llir + +:import NofibPrelude.mls +//│ Imported 104 member(s) + +fun delete_(xs, e) = deleteBy((x, y) => x == y, e, xs) + +fun listDiff(a, ls) = foldl(delete_, a, ls) + +:... +//│ ———————————————————————————————————————————————————————————————————————————————— +fun qsort(le, ls, r) = if ls is + Nil then r + x :: Nil then x :: r + x :: xs then qpart(le, x, xs, Nil, Nil, r) + +fun qpart(le, x, ys, rlt, rge, r) = if ys is + Nil then rqsort(le, rlt, x :: rqsort(le, rge, r)) + y :: ys and + le(x, y) then qpart(le, x, ys, rlt, y :: rge, r) + else qpart(le, x, ys, y :: rlt, rge, r) + +fun rqsort(le, ls, r) = if ls is + Nil then r + x :: Nil then x :: r + x :: xs then rqpart(le, x, xs, Nil, Nil, r) + +fun rqpart(le, x, yss, rle, rgt, r) = if yss is + Nil then qsort(le, rle, x :: qsort(le, rgt, r)) + y :: ys and + le(y, x) then rqpart(le, x, ys, y :: rle, rgt, r) + else rqpart(le, x, ys, rle, y :: rgt, r) +//│ ———————————————————————————————————————————————————————————————————————————————— + +fun sort(l) = qsort((a, b) => ltTup2(a, b, (a, b) => a < b, (a, b) => a > b, (a, b) => ltList(a, b, (a, b) => a < b, (a, b) => a > b)), l, Nil) + +fun perms(m, nns) = if + nns is Nil then Nil + m == 1 then map(x => x :: Nil, nns) + nns is n :: ns then map(x => n :: x, perms(m-1, ns)) +: perms(m, ns) + +fun atleast(threshold, sumscores) = + filter(case { [sum_, p] then sum_ >= threshold }, sumscores) + +fun award(name_threshold, sumscores) = if name_threshold is [name, threshold] then + map(ps => [name, ps], sort(atleast(threshold, sumscores))) + +fun awards(scores) = + let sumscores = map(p => [sum(p), p], perms(3, scores)) + + award(["Gold", 70], sumscores) +: award(["Silver", 60], sumscores) +: award(["Bronze", 50], sumscores) + +fun findawards(scores) = if awards(scores) is + Nil then Nil + head_ :: tail_ and head_ is [award, [sum_, perm]] then + [award, [sum_, perm]] :: findawards(listDiff(scores, perm)) + +fun findallawards(competitors) = + map(case { [name, scores] then [name, findawards(scores)] }, competitors) + +fun competitors(i) = + ["Simon", (35 :: 27 :: 40 :: i :: 34 :: 21 :: Nil)] :: + ["Hans", (23 :: 19 :: 45 :: i :: 17 :: 10 :: 5 :: 8 :: 14 :: Nil)] :: + ["Phil", (1 :: 18 :: i :: 20 :: 21 :: 19 :: 34 :: 8 :: 16 :: 21 :: Nil)] :: + ["Kevin", (9 :: 23 :: 17 :: 54 :: i :: 41 :: 9 :: 18 :: 14 :: Nil)] :: + Nil + +fun testAwards_nofib(n) = + map(x => print(findallawards(competitors(intMod(x, 100)))), enumFromTo(1, n)) + +testAwards_nofib(100) diff --git a/hkmc2/shared/src/test/mlscript/llir/nofib/constraints.mls b/hkmc2/shared/src/test/mlscript/llir/nofib/constraints.mls new file mode 100644 index 0000000000..00cb73aeeb --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/nofib/constraints.mls @@ -0,0 +1,279 @@ +:llir + +:import NofibPrelude.mls +//│ Imported 104 member(s) + + +// -- Figure 1. CSPs in Haskell. +data class Assign(varr: Int, value: Int) + +data class CSP(vars: Int, vals: Int, rel: Int) + +:... +//│ ———————————————————————————————————————————————————————————————————————————————— +fun qsort(le, ls, r) = if ls is + Nil then r + x :: Nil then x :: r + x :: xs then qpart(le, x, xs, Nil, Nil, r) + +fun qpart(le, x, ls, rlt, rge, r) = if ls is + Nil then rqsort(le, rlt, x :: rqsort(le, rge, r)) + y :: ys and + le(x, y) then qpart(le, x, ys, rlt, y :: rge, r) + else qpart(le, x, ys, y :: rlt, rge, r) + +fun rqsort(le, ls, r) = if ls is + Nil then r + x :: Nil then x :: r + x :: xs then rqpart(le, x, xs, Nil, Nil, r) + +fun rqpart(le, x, ls, rle, rgt, r) = if ls is + Nil then rqsort(le, rle, x :: qsort(le, rgt, r)) + y :: ys and + le(y, x) then rqpart(le, x, ys, y :: rle, rgt, r) + else rqpart(le, x, ys, rle, y :: rgt, r) +//│ ———————————————————————————————————————————————————————————————————————————————— + +fun level(a) = if a is Assign(v, _) then v + +fun value(a) = if a is Assign(_, v) then v + +fun maxLevel(ls) = if ls is + Nil then 0 + Assign(v, _) :: t then v + +fun complete(csp, s) = if csp is CSP(v, _, _) then maxLevel(s) == v + +fun generate(csp) = + fun g(vals, var_) = + fun lscomp1(ls) = if ls is + Nil then Nil + val_ :: t1 then + fun lscomp2(ls) = if ls is + Nil then lscomp1(t1) + st :: t2 then (Assign(var_, val_) :: st) :: lscomp2(t2) + lscomp2(g(vals, var_ - 1)) + if var_ == 0 then + Nil :: Nil + else + lscomp1(enumFromTo(1, vals)) + + if csp is CSP(vars, vals, rel) then g(vals, vars) + + +fun inconsistencies(csp, as_) = if csp is CSP(vars, vals, rel) then + fun lscomp1(ls) = if ls is + Nil then Nil + a :: t1 then + fun lscomp2(ls) = if ls is + Nil then lscomp1(t1) + b :: t2 and + a > b and not(rel(a, b)) then [level(a), (b)] :: lscomp2(t2) + else lscomp2(t2) + lscomp2(reverse(as_)) + + lscomp1(as_) + +fun consistent(csp)(x) = null_(inconsistencies(csp, x)) + +fun test(csp, ls) = filter(consistent(csp), ls) + +fun solver(csp) = test(csp, generate(csp)) + +fun safe(as1, as2) = if as1 is Assign(i, m) and as2 is Assign(j, n) then not(m == n) and not(abs(i - j) == abs(m - n)) + +fun queens(n) = CSP(n, n, safe) + +// -- Figure 2. Trees in Haskell. +data class Node[out T](lab: T, children: List[Node[T]]) + +fun label(n) = if n is Node(l, _) then l + +fun mapTree(f, n) = if n is Node(l, c) then Node(f(l), map((x => mapTree(f, x)), c)) + +fun foldTree(f, n) = if n is Node(l, c) then f(l, map((x => foldTree(f, x)), c)) + +fun filterTree(p, t) = + fun f1(a, cs) = Node(a, filter(x => p(label(x)), cs)) + foldTree(f1, t) + +fun prune(p, t) = filterTree(x => not(p(x)), t) + +fun leaves(t) = if t is + Node(leaf, Nil) then leaf :: Nil + Node(_, cs) then concat(map(leaves, cs)) + +fun initTree(f, x) = Node(x, map(y => initTree(f, y), f(x))) + +// -- Figure 3. Simple backtracking solver for CSPs. +fun mkTree(csp) = if csp is CSP(vars, vals, rel) then + fun next(ss) = + if maxLevel(ss) < vars then + fun lscomp1(ls) = if ls is + Nil then Nil + j :: t1 then + (Assign(maxLevel(ss) + 1, j) :: ss) :: lscomp1(t1) + lscomp1(enumFromTo(1, vals)) + else + Nil + + initTree(next, Nil) + + +fun earliestInconsistency(csp, aas) = if csp is CSP(vars, vals, rel) and aas is + Nil then None + a :: as_ and filter(x => not(rel(a, x)), reverse(as_)) is + Nil then None + b :: _ then Some([level(a), level(b)]) + +fun labelInconsistencies(csp, t) = + fun f2(s) = [s, earliestInconsistency(csp, s)] + + mapTree(f2, t) + + +fun btsolver0(csp) = + filter of + x => complete(csp, x) + leaves of + mapTree of + fst + prune of + x => not(snd(x) is None) + labelInconsistencies(csp, mkTree(csp)) + +// -- Figure 6. Conflict-directed solving of CSPs. +abstract class ConflictSet: Known | Unknown +data class Known(vs: List[Int]) extends ConflictSet +object Unknown extends ConflictSet + +fun knownConflict(c) = if c is + Known(a :: as_) then true + else false + +fun knownSolution(c) = if c is + Known(Nil) then true + else false + +fun checkComplete(csp, s) = if complete(csp, s) then Known(Nil) else Unknown + +fun search(labeler, csp) = + map of + fst + filter of + x => knownSolution(snd(x)) + leaves of prune of + x => knownConflict(snd(x)) + labeler(csp, mkTree(csp)) + +fun bt(csp, t) = + fun f3(s) = [s, (if earliestInconsistency(csp, s) is Some([a, b]) then Known(a :: b :: Nil) else checkComplete(csp, s))] + + mapTree(f3, t) + +// -- Figure 8. Backmarking. + +fun emptyTable(csp) = if csp is CSP(vars, vals, rel) then + fun lscomp1(ls) = if ls is + Nil then Nil + n :: t1 then + fun lscomp2(ls) = if ls is + Nil then Nil + m :: t2 then + Unknown :: lscomp2(t2) + lscomp2(enumFromTo(1, vals)) :: lscomp1(t1) + + Nil :: lscomp1(enumFromTo(1, vars)) + + +fun fillTable(s, csp, tbl) = if s is + Nil then tbl + Assign(var_, val_) :: as_ and csp is CSP(vars, vals, rel) then + fun f4(cs, varval) = if varval is [varr, vall] and + cs is Unknown and not(rel(Assign(var_, val_), Assign(varr, vall))) then Known(var_ :: varr :: Nil) + else cs + + fun lscomp1(ls) = if ls is + Nil then Nil + varrr :: t1 then + fun lscomp2(ls) = if ls is + Nil then Nil + valll :: t2 then [varrr, valll] :: lscomp2(t2) + lscomp2(enumFromTo(1, vals)) :: lscomp1(t1) + + zipWith((x, y) => zipWith(f4, x, y), tbl, lscomp1(enumFromTo(var_ + 1, vars))) + + +fun lookupCache(csp, t) = + fun f5(csp, tp) = if tp is + [Nil, tbl] then [[Nil, Unknown], tbl] + [a :: as_, tbl] then + let tableEntry = atIndex(value(a) - 1, head(tbl)) + let cs = if tableEntry is Unknown then checkComplete(csp, a :: as_) else tableEntry + [[a :: as_, cs], tbl] + + mapTree(x => f5(csp, x), t) + + +fun cacheChecks(csp, tbl, n) = if n is Node(s, cs) then + Node([s, tbl], map(x => cacheChecks(csp, fillTable(s, csp, tail(tbl)), x), cs)) + +fun bm(csp, t) = mapTree(fst, lookupCache(csp, cacheChecks(csp, emptyTable(csp), t))) + +// -- Figure 10. Conflict-directed backjumping. +fun combine(ls, acc) = if ls is + Nil then acc + [s, Known(cs)] :: css and + notElem(maxLevel(s), cs) then cs + else combine(css, union(cs, acc)) + +fun bj_(csp, t) = + fun f7(tp2, chs) = if tp2 is + [a, Known(cs)] then Node([a, Known(cs)], chs) + [a, Unknown] and + let cs_ = Known(combine(map(label, chs), Nil)) + knownConflict(cs_) then Node([a, cs_], Nil) + else Node([a, cs_], chs) + + foldTree(f7, t) + + +fun bj(csp, t) = + fun f6(tp2, chs) = if tp2 is + [a, Known(cs)] then Node([a, Known(cs)], chs) + [a, Unknown] then Node([a, Known(combine(map(label, chs), Nil))], chs) + + foldTree(f6, t) + + +fun bjbt(csp, t) = bj(csp, bt(csp, t)) + +fun bjbt_(csp, t) = bj_(csp, bt(csp, t)) + +// -- Figure 11. Forward checking. +fun collect(ls) = if ls is + Nil then Nil + Known(cs) :: css then union(cs, collect(css)) + +fun domainWipeout(csp, t) = if csp is CSP(vars, vals, rel) then + fun f8(tp2) = if tp2 is [[as_, cs], tbl] then + let wipedDomains = + fun lscomp1(ls) = if ls is + Nil then Nil + vs :: t1 and + all(knownConflict, vs) then vs :: lscomp1(t1) + else lscomp1(t1) + lscomp1(tbl) + let cs_ = if null_(wipedDomains) then cs else Known(collect(head(wipedDomains))) + [as_, cs_] + + mapTree(f8, t) + + +fun fc(csp, t) = domainWipeout(csp, lookupCache(csp, cacheChecks(csp, emptyTable(csp), t))) + +fun try_(n, algorithm) = listLen(search(algorithm, queens(n))) + +fun testConstraints_nofib(n) = map(x => try_(n, x), bt :: bm :: bjbt :: bjbt_ :: fc :: Nil) + +print(testConstraints_nofib(6)) diff --git a/hkmc2/shared/src/test/mlscript/llir/nofib/scc.mls b/hkmc2/shared/src/test/mlscript/llir/nofib/scc.mls new file mode 100644 index 0000000000..48c4d12f41 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/nofib/scc.mls @@ -0,0 +1,52 @@ +:llir + +:import NofibPrelude.mls +//│ Imported 104 member(s) + + +fun dfs(r, vsns, xs) = if vsns is [vs, ns] and + xs is + Nil then [vs, ns] + x :: xs and + inList(x, vs) then dfs(r, [vs, ns], xs) + dfs(r, [x :: vs, Nil], r(x)) is [vs', ns'] then dfs(r, [vs', (x :: ns') +: ns], xs) + +fun stronglyConnComp(es, vs) = + fun swap(a) = if a is [f, s] then [s, f] + + fun new_range(xys, w) = if xys is + Nil then Nil + [x, y] :: xys and + x == w then y :: new_range(xys, w) + else new_range(xys, w) + + fun span_tree(r, vsns, xs) = if vsns is [vs, ns] and + xs is + Nil then [vs, ns] + x :: xs and + inList(x, vs) then span_tree(r, [vs, ns], xs) + dfs(r, [x :: vs, Nil], r(x)) is [vs', ns'] then span_tree(r, [vs', (x :: ns') :: ns], xs) + + snd of span_tree of + x => new_range(map(swap, es), x) + [Nil, Nil] + snd of dfs of + x => new_range(es, x) + [Nil, Nil] + vs + + +fun testScc_nofib(d) = + let a = 1 + let b = 2 + let c = 3 + let d = 4 + let f = 5 + let g = 6 + let h = 7 + let vertices = a :: b :: c :: d :: f :: g :: h :: Nil + let edges = [b, a] :: [c, b] :: [c, d] :: [c, h] :: [d, c] :: [f, a] :: [f, g] :: [f, h] :: [g, f] :: [h, g] :: Nil + + stronglyConnComp(edges, vertices) + +print(testScc_nofib(0)) diff --git a/hkmc2/shared/src/test/mlscript/llir/nofib/secretary.mls b/hkmc2/shared/src/test/mlscript/llir/nofib/secretary.mls new file mode 100644 index 0000000000..b17fe12a31 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/llir/nofib/secretary.mls @@ -0,0 +1,39 @@ +:llir + +:import NofibPrelude.mls +//│ Imported 104 member(s) + + +fun infRand(m, s) = + fun f(x) = lazy(() => LzCons((intMod(x, m) + 1), f(intMod((97 * x + 11), power(2, 7))))) + + f(s) + + +fun simulate(n, m, proc) = + fun lscomp(ls) = if ls is + Nil then Nil + seed :: t then proc(infRand(m, seed)) :: lscomp(t) + + floatOfInt(listLen(filter(x => x, lscomp(enumFromTo(1, n))))) / floatOfInt(n) + + +fun sim(n, k) = + fun proc(rs) = + let xs = take_lz(100, nub_lz(rs)) + let best = 100 + let bestk = maximum(take(k, xs)) + let afterk = dropWhile(x => x < bestk, drop(k, xs)) + listEq(best :: Nil, take(1, afterk)) + + simulate(n, 100, proc) + + +fun testSecretary_nofib(n) = + fun listcomp(ls) = if ls is + Nil then Nil + h :: t then sim(n, h) :: listcomp(t) + + listcomp(enumFromTo(35, 39)) + +print(testSecretary_nofib(50)) diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/LlirDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/LlirDiffMaker.scala index 3ad015ce22..0c90e8945c 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/LlirDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/LlirDiffMaker.scala @@ -12,66 +12,96 @@ import codegen.cpp.* import hkmc2.syntax.Tree.Ident import hkmc2.codegen.Path import hkmc2.semantics.Term.Blk -import hkmc2.codegen.llir.Fresh import hkmc2.utils.Scope -import hkmc2.codegen.llir.Ctx import hkmc2.codegen.llir._ +import hkmc2.codegen.cpp._ import hkmc2.semantics.Elaborator +import scala.collection.mutable.ListBuffer abstract class LlirDiffMaker extends BbmlDiffMaker: val llir = NullaryCommand("llir") - val cpp = NullaryCommand("cpp") val sllir = NullaryCommand("sllir") + val intl = NullaryCommand("intl") + val lprelude = NullaryCommand("lpre") + + // C++ codegen generation commands for individual blocks + val cpp = NullaryCommand("cpp") val scpp = NullaryCommand("scpp") val rcpp = NullaryCommand("rcpp") - val intl = NullaryCommand("intl") val wcpp = Command[Str]("wcpp", false)(x => x.stripLeading()) + // C++ codegen generation commands for the whole program + val wholeCpp = NullaryCommand("wholeCpp") + val sWholeCpp = NullaryCommand("showWholeCpp") + val rWholeCpp = NullaryCommand("runWholeCpp") + val wWholeCpp = Command[Str]("writeWholeCpp", false)(x => x.stripLeading()) + def printToFile(f: java.io.File)(op: java.io.PrintWriter => Unit) = val p = new java.io.PrintWriter(f) try { op(p) } finally { p.close() } given Elaborator.Ctx = curCtx + + object Llir: // Avoid polluting the namespace + val freshId = FreshInt() + var ctx = codegen.llir.Ctx.empty + val scope = Scope.empty + val wholeProg = ListBuffer.empty[Program] + import Llir.* + + def mkWholeProgram: Program = + if wholeProg.length == 0 then + throw new Exception("No program to make") + else + Program( + classes = wholeProg.iterator.flatMap(_.classes).toSet, + defs = wholeProg.iterator.flatMap(_.defs).toSet, + entry = wholeProg.last.entry + ) override def processTerm(trm: Blk, inImport: Bool)(using Config, Raise): Unit = super.processTerm(trm, inImport) if llir.isSet then val low = ltl.givenIn: codegen.Lowering() - val le = low.program(trm) - given scp: Scope = Scope.empty - scp.allocateName(Elaborator.State.runtimeSymbol) - val fresh = Fresh() - val fuid = FreshInt() - val cuid = FreshInt() - val llb = LlirBuilder(tl)(fresh, fuid, cuid) - given Ctx = Ctx.empty + var le = low.program(trm) + given Scope = scope + given Ctx = ctx + val llb = LlirBuilder(tl, freshId) try - val llirProg = llb.bProg(le) - if sllir.isSet then + val (llirProg, ctx2) = llb.bProg(le) + ctx = ctx2 + wholeProg += llirProg + if sllir.isSet && !silent.isSet then output("LLIR:") output(llirProg.show()) - if cpp.isSet || scpp.isSet || rcpp.isSet || wcpp.isSet then - val cpp = codegen.cpp.CppCodeGen.codegen(llirProg) - if scpp.isSet then - output("\nCpp:") - output(cpp.toDocument.toString) - val auxPath = os.Path(rootPath) / "hkmc2"/"shared"/"src"/"test"/"mlscript-compile"/"cpp" - if wcpp.isSet then - printToFile(java.io.File((auxPath / s"${wcpp.get.get}.cpp").toString)): - p => p.println(cpp.toDocument.toString) - if rcpp.isSet then - val cppHost = CppCompilerHost(auxPath.toString, output.apply) - if !cppHost.ready then - output("\nCpp Compilation Failed: Cpp compiler or GNU Make not found") - else - output("\n") - cppHost.compileAndRun(cpp.toDocument.toString) + def cppGen(name: String, prog: Program, gen: Bool, show: Bool, run: Bool, write: Opt[Str]): Unit = + tl.log(s"Generating $name") + if gen || show || run || write.isDefined then + val cpp = CppCodeGen(ctx.builtinSym.hiddenClasses, tl).codegen(prog) + if show then + output(s"\n$name:") + output(cpp.toDocument.toString) + val rPath = os.Path(rootPath) + val auxPath = rPath/"hkmc2"/"shared"/"src"/"test"/"mlscript-compile"/"cpp" + if write.isDefined then + printToFile(java.io.File((auxPath / s"${write.get}").toString)): + p => p.println(cpp.toDocument.toString) + if run then + val cppHost = CppCompilerHost(auxPath.toString, output.apply) + if !cppHost.ready then + output("\nCpp Compilation Failed: Cpp compiler or GNU Make not found") + else if !silent.isSet then + output("\n") + cppHost.compileAndRun(cpp.toDocument.toString) + cppGen("Cpp", llirProg, + cpp.isSet, scpp.isSet, rcpp.isSet, wcpp.get) + cppGen("WholeProgramCpp", mkWholeProgram, + wholeCpp.isSet, sWholeCpp.isSet, rWholeCpp.isSet, wWholeCpp.get) if intl.isSet then - val intr = codegen.llir.Interpreter(verbose = true) + val intr = codegen.llir.Interpreter(tl) output("\nInterpreted:") output(intr.interpret(llirProg)) catch case e: LowLevelIRError => output("Stopped due to an error during the Llir generation") - diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/Watcher.scala b/hkmc2DiffTests/src/test/scala/hkmc2/Watcher.scala index 9a9df47c9f..511b999a00 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/Watcher.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/Watcher.scala @@ -89,15 +89,16 @@ class Watcher(dirs: Ls[File]): val path = os.Path(file.pathAsString) val basePath = path.segments.drop(dirPaths.head.segmentCount).toList.init val relativeName = basePath.map(_ + "/").mkString + path.baseName - val preludePath = os.pwd/os.up/"hkmc2"/"shared"/"src"/"test"/"mlscript"/"decls"/"Prelude.mls" - val predefPath = os.pwd/os.up/"hkmc2"/"shared"/"src"/"test"/"mlscript-compile"/"Predef.mls" + val rootPath = os.pwd/os.up + val preludePath = rootPath/"hkmc2"/"shared"/"src"/"test"/"mlscript"/"decls"/"Prelude.mls" + val predefPath = rootPath/"hkmc2"/"shared"/"src"/"test"/"mlscript-compile"/"Predef.mls" val isModuleFile = path.segments.contains("mlscript-compile") if isModuleFile then given Config = Config.default MLsCompiler(preludePath, outputConsumer => outputConsumer(System.out.println)).compileModule(path) else - val dm = new MainDiffMaker((dirPaths.head/os.up).toString, path, preludePath, predefPath, relativeName): + val dm = new MainDiffMaker(rootPath.toString, path, preludePath, predefPath, relativeName): override def unhandled(blockLineNum: Int, exc: Throwable): Unit = exc.printStackTrace() super.unhandled(blockLineNum, exc)