Skip to content

Commit bbae1e4

Browse files
committed
Merge branch 'graph-ir' into tidy-ir
# Conflicts: # compiler/shared/main/scala/mlscript/compiler/codegen/CppCodeGen.scala # compiler/shared/main/scala/mlscript/compiler/ir/IR.scala # compiler/shared/test/scala/mlscript/compiler/TestIR.scala
2 parents 98fd820 + a597a89 commit bbae1e4

File tree

4 files changed

+68
-73
lines changed

4 files changed

+68
-73
lines changed

compiler/shared/main/scala/mlscript/compiler/codegen/CppAst.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ object Expr:
122122
enum Expr:
123123
case Var(name: Str)
124124
case IntLit(value: BigInt)
125-
case FloatLit(value: Float)
125+
case DoubleLit(value: Float)
126126
case StrLit(value: Str)
127127
case CharLit(value: Char)
128128
case Call(func: Expr, args: Ls[Expr])
@@ -137,7 +137,7 @@ enum Expr:
137137
def aux(x: Expr): Document = x match
138138
case Var(name) => name |> raw
139139
case IntLit(value) => value.toString |> raw
140-
case FloatLit(value) => value.toString |> raw
140+
case DoubleLit(value) => value.toString |> raw
141141
case StrLit(value) => s"\"$value\"" |> raw // need more reliable escape utils
142142
case CharLit(value) => s"'$value'" |> raw
143143
case Call(func, args) => aux(func) <#> raw("(") <#> Expr.toDocuments(args, sep = raw(", ")) <#> raw(")")

compiler/shared/main/scala/mlscript/compiler/codegen/CppCodeGen.scala

Lines changed: 63 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -4,63 +4,67 @@ import mlscript.compiler.ir.{Expr => IExpr, _}
44
import mlscript.compiler.utils._
55
import mlscript.utils._
66
import mlscript.utils.shorthands._
7+
import scala.collection.mutable.ListBuffer
78

9+
def codegen(prog: Program): CompilationUnit =
10+
val codegen = CppCodeGen()
11+
codegen.codegen(prog)
812

9-
class CppCodeGen:
10-
private def mapName(name: Name): Str = "_mls_" + name.str.replace('$', '_').replace('\'', '_')
11-
private def mapName(name: Str): Str = "_mls_" + name.replace('$', '_').replace('\'', '_')
12-
private val freshName = Fresh(div = '_');
13-
private val mlsValType = Type.Prim("_mlsValue")
14-
private val mlsUnitValue = Expr.Call(Expr.Var("_mlsValue::create<_mls_Unit>"), Ls());
15-
private val mlsRetValue = "_mls_retval"
16-
private val mlsRetValueDecl = Decl.VarDecl(mlsRetValue, mlsValType)
17-
private val mlsMainName = "_mlsMain"
18-
private val mlsPrelude = "#include \"mlsprelude.h\""
19-
private val mlsPreludeImpl = "#include \"mlsprelude.cpp\""
20-
private val mlsInternalClass = Set("True", "False", "Boolean", "Callable")
21-
private val mlsObject = "_mlsObject"
22-
private val mlsBuiltin = "builtin"
23-
private val mlsEntryPoint = s"int main() { return _mlsLargeStack(_mlsMainWrapper); }";
24-
private def mlsIntLit(x: BigInt) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.IntLit(x)))
25-
private def mlsStrLit(x: Str) = Expr.Call(Expr.Var("_mlsValue::fromStrLit"), Ls(Expr.StrLit(x)))
26-
private def mlsCharLit(x: Char) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.CharLit(x)))
27-
private def mlsNewValue(cls: Str, args: Ls[Expr]) = Expr.Call(Expr.Var(s"_mlsValue::create<$cls>"), args)
28-
private def mlsIsValueOf(cls: Str, scrut: Expr) = Expr.Call(Expr.Var(s"_mlsValue::isValueOf<$cls>"), Ls(scrut))
29-
private def mlsIsIntLit(scrut: Expr, lit: mlscript.IntLit) = Expr.Call(Expr.Var("_mlsValue::isIntLit"), Ls(scrut, Expr.IntLit(lit.value)))
30-
private def mlsDebugPrint(x: Expr) = Expr.Call(Expr.Var("_mlsValue::print"), Ls(x))
31-
private def mlsTupleValue(init: Expr) = Expr.Constructor("_mlsValue::tuple", init)
32-
private def mlsAs(name: Str, cls: Str) = Expr.Var(s"_mlsValue::as<$cls>($name)")
33-
private def mlsAsUnchecked(name: Str, cls: Str) = Expr.Var(s"_mlsValue::cast<$cls>($name)")
34-
private def mlsObjectNameMethod(name: Str) = s"constexpr static inline const char *typeName = \"${name}\";"
35-
private def mlsTypeTag() = s"constexpr static inline uint32_t typeTag = nextTypeTag();"
36-
private def mlsTypeTag(n: Int) = s"constexpr static inline uint32_t typeTag = $n;"
37-
private def mlsCommonCreateMethod(cls: Str, fields: Ls[Str], id: Int) =
13+
private class CppCodeGen:
14+
def mapName(name: Name): Str = "_mls_" + name.str.replace('$', '_').replace('\'', '_')
15+
def mapName(name: Str): Str = "_mls_" + name.replace('$', '_').replace('\'', '_')
16+
val freshName = Fresh(div = '_');
17+
val mlsValType = Type.Prim("_mlsValue")
18+
val mlsUnitValue = Expr.Call(Expr.Var("_mlsValue::create<_mls_Unit>"), Ls());
19+
val mlsRetValue = "_mls_retval"
20+
val mlsRetValueDecl = Decl.VarDecl(mlsRetValue, mlsValType)
21+
val mlsMainName = "_mlsMain"
22+
val mlsPrelude = "#include \"mlsprelude.h\""
23+
val mlsPreludeImpl = "#include \"mlsprelude.cpp\""
24+
val mlsInternalClass = Set("True", "False", "Boolean", "Callable")
25+
val mlsObject = "_mlsObject"
26+
val mlsBuiltin = "builtin"
27+
val mlsEntryPoint = s"int main() { return _mlsLargeStack(_mlsMainWrapper); }";
28+
def mlsIntLit(x: BigInt) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.IntLit(x)))
29+
def mlsStrLit(x: Str) = Expr.Call(Expr.Var("_mlsValue::fromStrLit"), Ls(Expr.StrLit(x)))
30+
def mlsCharLit(x: Char) = Expr.Call(Expr.Var("_mlsValue::fromIntLit"), Ls(Expr.CharLit(x)))
31+
def mlsNewValue(cls: Str, args: Ls[Expr]) = Expr.Call(Expr.Var(s"_mlsValue::create<$cls>"), args)
32+
def mlsIsValueOf(cls: Str, scrut: Expr) = Expr.Call(Expr.Var(s"_mlsValue::isValueOf<$cls>"), Ls(scrut))
33+
def mlsIsIntLit(scrut: Expr, lit: mlscript.IntLit) = Expr.Call(Expr.Var("_mlsValue::isIntLit"), Ls(scrut, Expr.IntLit(lit.value)))
34+
def mlsDebugPrint(x: Expr) = Expr.Call(Expr.Var("_mlsValue::print"), Ls(x))
35+
def mlsTupleValue(init: Expr) = Expr.Constructor("_mlsValue::tuple", init)
36+
def mlsAs(name: Str, cls: Str) = Expr.Var(s"_mlsValue::as<$cls>($name)")
37+
def mlsAsUnchecked(name: Str, cls: Str) = Expr.Var(s"_mlsValue::cast<$cls>($name)")
38+
def mlsObjectNameMethod(name: Str) = s"constexpr static inline const char *typeName = \"${name}\";"
39+
def mlsTypeTag() = s"constexpr static inline uint32_t typeTag = nextTypeTag();"
40+
def mlsTypeTag(n: Int) = s"constexpr static inline uint32_t typeTag = $n;"
41+
def mlsCommonCreateMethod(cls: Str, fields: Ls[Str], id: Int) =
3842
val parameters = fields.map{x => s"_mlsValue $x"}.mkString(", ")
3943
val fieldsAssignment = fields.map{x => s"_mlsVal->$x = $x; "}.mkString
4044
s"static _mlsValue create($parameters) { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) $cls; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; $fieldsAssignment return _mlsValue(_mlsVal); }"
41-
private def mlsCommonPrintMethod(fields: Ls[Str]) =
45+
def mlsCommonPrintMethod(fields: Ls[Str]) =
4246
if fields.isEmpty then s"virtual void print() const override { std::printf(\"%s\", typeName); }"
4347
else
4448
val fieldsPrint = fields.map{x => s"this->$x.print(); "}.mkString("std::printf(\", \"); ")
4549
s"virtual void print() const override { std::printf(\"%s\", typeName); std::printf(\"(\"); $fieldsPrint std::printf(\")\"); }"
46-
private def mlsCommonDestructorMethod(cls: Str, fields: Ls[Str]) =
50+
def mlsCommonDestructorMethod(cls: Str, fields: Ls[Str]) =
4751
val fieldsDeletion = fields.map{x => s"_mlsValue::destroy(this->$x); "}.mkString
4852
s"virtual void destroy() override { $fieldsDeletion operator delete (this, std::align_val_t(_mlsAlignment)); }"
49-
private def mlsThrowNonExhaustiveMatch = Stmt.Raw("_mlsNonExhaustiveMatch();");
50-
private def mlsCall(fn: Str, args: Ls[Expr]) = Expr.Call(Expr.Var("_mlsCall"), Expr.Var(fn) :: args)
51-
private def mlsMethodCall(cls: ClassRef, method: Str, args: Ls[Expr]) =
53+
def mlsThrowNonExhaustiveMatch = Stmt.Raw("_mlsNonExhaustiveMatch();");
54+
def mlsCall(fn: Str, args: Ls[Expr]) = Expr.Call(Expr.Var("_mlsCall"), Expr.Var(fn) :: args)
55+
def mlsMethodCall(cls: ClassRef, method: Str, args: Ls[Expr]) =
5256
Expr.Call(Expr.Member(Expr.Call(Expr.Var(s"_mlsMethodCall<${cls.name |> mapName}>"), Ls(args.head)), method), args.tail)
53-
private def mlsFnWrapperName(fn: Str) = s"_mlsFn_$fn"
54-
private def mlsFnCreateMethod(fn: Str) = s"static _mlsValue create() { static _mlsFn_$fn mlsFn alignas(_mlsAlignment); mlsFn.refCount = stickyRefCount; mlsFn.tag = typeTag; return _mlsValue(&mlsFn); }"
55-
private def mlsNeverValue(n: Int) = if (n <= 1) then Expr.Call(Expr.Var(s"_mlsValue::never"), Ls()) else Expr.Call(Expr.Var(s"_mlsValue::never<$n>"), Ls())
57+
def mlsFnWrapperName(fn: Str) = s"_mlsFn_$fn"
58+
def mlsFnCreateMethod(fn: Str) = s"static _mlsValue create() { static _mlsFn_$fn mlsFn alignas(_mlsAlignment); mlsFn.refCount = stickyRefCount; mlsFn.tag = typeTag; return _mlsValue(&mlsFn); }"
59+
def mlsNeverValue(n: Int) = if (n <= 1) then Expr.Call(Expr.Var(s"_mlsValue::never"), Ls()) else Expr.Call(Expr.Var(s"_mlsValue::never<$n>"), Ls())
5660

57-
private case class Ctx(
58-
val defnCtx: Set[Str],
61+
case class Ctx(
62+
defnCtx: Set[Str],
5963
)
6064

61-
private def codegenClassInfo(using ctx: Ctx)(cls: ClassInfo): (Opt[Def], Decl) =
65+
def codegenClassInfo(using ctx: Ctx)(cls: ClassInfo): (Opt[Def], Decl) =
6266
val fields = cls.fields.map{x => (x |> mapName, mlsValType)}
63-
val parents = if cls.parents.nonEmpty then cls.parents.toList.map{x => x |> mapName} else mlsObject :: Nil
67+
val parents = if cls.parents.nonEmpty then cls.parents.toList.map(mapName) else mlsObject :: Nil
6468
val decl = Decl.StructDecl(cls.name |> mapName)
6569
if mlsInternalClass.contains(cls.name) then return (None, decl)
6670
val theDef = Def.StructDef(
@@ -80,28 +84,28 @@ class CppCodeGen:
8084
)
8185
(S(theDef), decl)
8286

83-
private def toExpr(texpr: TrivialExpr, reifyUnit: Bool = false)(using ctx: Ctx): Opt[Expr] = texpr match
87+
def toExpr(texpr: TrivialExpr, reifyUnit: Bool = false)(using ctx: Ctx): Opt[Expr] = texpr match
8488
case IExpr.Ref(name) => S(Expr.Var(name |> mapName))
8589
case IExpr.Literal(mlscript.IntLit(x)) => S(mlsIntLit(x))
8690
case IExpr.Literal(mlscript.DecLit(x)) => S(mlsIntLit(x.toBigInt))
8791
case IExpr.Literal(mlscript.StrLit(x)) => S(mlsStrLit(x))
8892
case IExpr.Literal(mlscript.UnitLit(_)) => if reifyUnit then S(mlsUnitValue) else None
8993

90-
private def toExpr(texpr: TrivialExpr)(using ctx: Ctx): Expr = texpr match
94+
def toExpr(texpr: TrivialExpr)(using ctx: Ctx): Expr = texpr match
9195
case IExpr.Ref(name) => Expr.Var(name |> mapName)
9296
case IExpr.Literal(mlscript.IntLit(x)) => mlsIntLit(x)
9397
case IExpr.Literal(mlscript.DecLit(x)) => mlsIntLit(x.toBigInt)
9498
case IExpr.Literal(mlscript.StrLit(x)) => mlsStrLit(x)
9599
case IExpr.Literal(mlscript.UnitLit(_)) => mlsUnitValue
96100

97101

98-
private def wrapMultiValues(exprs: Ls[TrivialExpr])(using ctx: Ctx): Expr = exprs match
102+
def wrapMultiValues(exprs: Ls[TrivialExpr])(using ctx: Ctx): Expr = exprs match
99103
case x :: Nil => toExpr(x, reifyUnit = true).get
100104
case _ =>
101105
val init = Expr.Initializer(exprs.map{x => toExpr(x)})
102106
mlsTupleValue(init)
103107

104-
private def codegenCaseWithIfs(scrut: Name, cases: Ls[(Pat, Node)], default: Opt[Node], storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) =
108+
def codegenCaseWithIfs(scrut: Name, cases: Ls[(Pat, Node)], default: Opt[Node], storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) =
105109
val scrutName = mapName(scrut)
106110
val init: Stmt =
107111
default.fold(mlsThrowNonExhaustiveMatch)(x => {
@@ -121,12 +125,12 @@ class CppCodeGen:
121125
}
122126
(decls, stmt.fold(stmts)(x => stmts :+ x))
123127

124-
private def codegenJumpWithCall(defn: DefnRef, args: Ls[TrivialExpr], storeInto: Opt[Str])(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) =
128+
def codegenJumpWithCall(defn: DefnRef, args: Ls[TrivialExpr], storeInto: Opt[Str])(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) =
125129
val call = Expr.Call(Expr.Var(defn.name |> mapName), args.map(toExpr))
126130
val stmts2 = stmts ++ Ls(storeInto.fold(Stmt.Return(call))(x => Stmt.Assign(x, call)))
127131
(decls, stmts2)
128132

129-
private def codegenOps(op: Str, args: Ls[TrivialExpr])(using ctx: Ctx) = op match
133+
def codegenOps(op: Str, args: Ls[TrivialExpr])(using ctx: Ctx) = op match
130134
case "+" => Expr.Binary("+", toExpr(args(0)), toExpr(args(1)))
131135
case "-" => Expr.Binary("-", toExpr(args(0)), toExpr(args(1)))
132136
case "*" => Expr.Binary("*", toExpr(args(0)), toExpr(args(1)))
@@ -141,22 +145,21 @@ class CppCodeGen:
141145
case "&&" => Expr.Binary("&&", toExpr(args(0)), toExpr(args(1)))
142146
case "||" => Expr.Binary("||", toExpr(args(0)), toExpr(args(1)))
143147
case "!" => Expr.Unary("!", toExpr(args(0)))
144-
case _ => ???
148+
case _ => mlscript.utils.TODO("codegenOps")
145149

146150

147-
private def codegen(expr: IExpr)(using ctx: Ctx): Expr = expr match
151+
def codegen(expr: IExpr)(using ctx: Ctx): Expr = expr match
148152
case x @ (IExpr.Ref(_) | IExpr.Literal(_)) => toExpr(x, reifyUnit = true).get
149153
case IExpr.CtorApp(cls, args) => mlsNewValue(cls.name |> mapName, args.map(toExpr))
150154
case IExpr.Select(name, cls, field) => Expr.Member(mlsAsUnchecked(name |> mapName, cls.name |> mapName), field |> mapName)
151155
case IExpr.BasicOp(name, args) => codegenOps(name, args)
152-
// TODO: Implement this
153-
case IExpr.AssignField(assignee, cls, field, value) => ???
156+
case IExpr.AssignField(assignee, cls, field, value) => mlscript.utils.TODO("Assign field in the backend")
154157

155-
private def codegenBuiltin(names: Ls[Name], builtin: Str, args: Ls[TrivialExpr])(using ctx: Ctx): Ls[Stmt] = builtin match
158+
def codegenBuiltin(names: Ls[Name], builtin: Str, args: Ls[TrivialExpr])(using ctx: Ctx): Ls[Stmt] = builtin match
156159
case "error" => Ls(Stmt.Raw("throw std::runtime_error(\"Error\");"), Stmt.AutoBind(names.map(mapName), mlsNeverValue(names.size)))
157160
case _ => Ls(Stmt.AutoBind(names.map(mapName), Expr.Call(Expr.Var("_mls_builtin_" + builtin), args.map(toExpr))))
158161

159-
private def codegen(body: Node, storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = body match
162+
def codegen(body: Node, storeInto: Str)(using decls: Ls[Decl], stmts: Ls[Stmt])(using ctx: Ctx): (Ls[Decl], Ls[Stmt]) = body match
160163
case Node.Result(res) =>
161164
val expr = wrapMultiValues(res)
162165
val stmts2 = stmts ++ Ls(Stmt.Assign(storeInto, expr))
@@ -173,15 +176,6 @@ class CppCodeGen:
173176
val call = mlsMethodCall(cls, method.str |> mapName, args.map(toExpr))
174177
val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call))
175178
codegen(body, storeInto)(using decls, stmts2)
176-
// Use method calls instead of apply
177-
//
178-
// case Node.LetApply(names, fn, args, body) if fn.str == "builtin" =>
179-
// val stmts2 = stmts ++ codegenBuiltin(names, args.head.toString.replace("\"", ""), args.tail)
180-
// codegen(body, storeInto)(using decls, stmts2)
181-
// case Node.LetApply(names, fn, args, body) =>
182-
// val call = mlsCall(fn.str |> mapName, args.map(toExpr))
183-
// val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call))
184-
// codegen(body, storeInto)(using decls, stmts2)
185179
case Node.LetCall(names, defn, args, _, body) =>
186180
val call = Expr.Call(Expr.Var(defn.name |> mapName), args.map(toExpr))
187181
val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call))
@@ -199,7 +193,7 @@ class CppCodeGen:
199193
val decl = Decl.FuncDecl(mlsValType, name |> mapName, params.map(x => mlsValType))
200194
(theDef, decl)
201195

202-
private def codegenTopNode(node: Node)(using ctx: Ctx): (Def, Decl) =
196+
def codegenTopNode(node: Node)(using ctx: Ctx): (Def, Decl) =
203197
val decls = Ls(mlsRetValueDecl)
204198
val stmts = Ls.empty[Stmt]
205199
val (decls2, stmts2) = codegen(node, mlsRetValue)(using decls, stmts)
@@ -208,27 +202,28 @@ class CppCodeGen:
208202
val decl = Decl.FuncDecl(mlsValType, mlsMainName, Ls())
209203
(theDef, decl)
210204

211-
private def sortClasses(prog: Program): Ls[ClassInfo] =
205+
// Topological sort of classes based on inheritance relationships
206+
def sortClasses(prog: Program): Ls[ClassInfo] =
212207
var depgraph = prog.classes.map(x => (x.name, x.parents)).toMap
213208
var degree = depgraph.view.mapValues(_.size).toMap
214209
def removeNode(node: Str) =
215210
degree -= node
216211
depgraph -= node
217212
depgraph = depgraph.view.mapValues(_.filter(_ != node)).toMap
218213
degree = depgraph.view.mapValues(_.size).toMap
219-
var sorted = Ls.empty[ClassInfo]
214+
val sorted = ListBuffer.empty[ClassInfo]
220215
var work = degree.filter(_._2 == 0).keys.toSet
221216
while work.nonEmpty do
222217
val node = work.head
223218
work -= node
224-
sorted = sorted :+ prog.classes.find(_.name == node).get
219+
sorted.addOne(prog.classes.find(_.name == node).get)
225220
removeNode(node)
226221
val next = degree.filter(_._2 == 0).keys
227-
work = work ++ next
222+
work ++= next
228223
if depgraph.nonEmpty then
229224
val cycle = depgraph.keys.mkString(", ")
230225
throw new Exception(s"Cycle detected in class hierarchy: $cycle")
231-
sorted
226+
sorted.toList
232227

233228
def codegen(prog: Program): CompilationUnit =
234229
val sortedClasses = sortClasses(prog)

compiler/shared/test/scala/mlscript/compiler/TestIR.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import mlscript.compiler.ir._
66
import scala.collection.mutable.StringBuilder
77
import mlscript.{DiffTests, ModeType, TypingUnit}
88
import mlscript.compiler.ir.{Fresh, FreshInt, Builder}
9-
import mlscript.compiler.codegen.cpp.{CppCodeGen, CppCompilerHost}
9+
import mlscript.compiler.codegen.cpp._
1010
import mlscript.Diagnostic
1111
import mlscript.compiler.optimizer.TailRecOpt
1212

@@ -43,7 +43,7 @@ class IRDiffTestCompiler extends DiffTests {
4343
interp_result = Some(ir)
4444
output(ir)
4545
if (mode.genCpp)
46-
val cpp = CppCodeGen().codegen(graph)
46+
val cpp = codegen(graph)
4747
if (mode.showCpp)
4848
output("\nCpp:")
4949
output(cpp.toDocument.print)

shared/src/main/scala/mlscript/NewParser.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1475,7 +1475,7 @@ abstract class NewParser(origin: Origin, tokens: Ls[Stroken -> Loc], newDefs: Bo
14751475

14761476
}
14771477
}
1478-
1478+
14791479
final def bindings(acc: Ls[Var -> Term])(implicit fe: FoundErr): Ls[Var -> Term] =
14801480
cur match {
14811481
case (SPACE, _) :: _ =>

0 commit comments

Comments
 (0)