@@ -4,63 +4,67 @@ import mlscript.compiler.ir.{Expr => IExpr, _}
4
4
import mlscript .compiler .utils ._
5
5
import mlscript .utils ._
6
6
import mlscript .utils .shorthands ._
7
+ import scala .collection .mutable .ListBuffer
7
8
9
+ def codegen (prog : Program ): CompilationUnit =
10
+ val codegen = CppCodeGen ()
11
+ codegen.codegen(prog)
8
12
9
- class CppCodeGen :
10
- private def mapName (name : Name ): Str = " _mls_" + name.str.replace('$' , '_' ).replace('\' ' , '_' )
11
- private def mapName (name : Str ): Str = " _mls_" + name.replace('$' , '_' ).replace('\' ' , '_' )
12
- private val freshName = Fresh (div = '_' );
13
- private val mlsValType = Type .Prim (" _mlsValue" )
14
- private val mlsUnitValue = Expr .Call (Expr .Var (" _mlsValue::create<_mls_Unit>" ), Ls ());
15
- private val mlsRetValue = " _mls_retval"
16
- private val mlsRetValueDecl = Decl .VarDecl (mlsRetValue, mlsValType)
17
- private val mlsMainName = " _mlsMain"
18
- private val mlsPrelude = " #include \" mlsprelude.h\" "
19
- private val mlsPreludeImpl = " #include \" mlsprelude.cpp\" "
20
- private val mlsInternalClass = Set (" True" , " False" , " Boolean" , " Callable" )
21
- private val mlsObject = " _mlsObject"
22
- private val mlsBuiltin = " builtin"
23
- private val mlsEntryPoint = s " int main() { return _mlsLargeStack(_mlsMainWrapper); } " ;
24
- private def mlsIntLit (x : BigInt ) = Expr .Call (Expr .Var (" _mlsValue::fromIntLit" ), Ls (Expr .IntLit (x)))
25
- private def mlsStrLit (x : Str ) = Expr .Call (Expr .Var (" _mlsValue::fromStrLit" ), Ls (Expr .StrLit (x)))
26
- private def mlsCharLit (x : Char ) = Expr .Call (Expr .Var (" _mlsValue::fromIntLit" ), Ls (Expr .CharLit (x)))
27
- private def mlsNewValue (cls : Str , args : Ls [Expr ]) = Expr .Call (Expr .Var (s " _mlsValue::create< $cls> " ), args)
28
- private def mlsIsValueOf (cls : Str , scrut : Expr ) = Expr .Call (Expr .Var (s " _mlsValue::isValueOf< $cls> " ), Ls (scrut))
29
- private def mlsIsIntLit (scrut : Expr , lit : mlscript.IntLit ) = Expr .Call (Expr .Var (" _mlsValue::isIntLit" ), Ls (scrut, Expr .IntLit (lit.value)))
30
- private def mlsDebugPrint (x : Expr ) = Expr .Call (Expr .Var (" _mlsValue::print" ), Ls (x))
31
- private def mlsTupleValue (init : Expr ) = Expr .Constructor (" _mlsValue::tuple" , init)
32
- private def mlsAs (name : Str , cls : Str ) = Expr .Var (s " _mlsValue::as< $cls>( $name) " )
33
- private def mlsAsUnchecked (name : Str , cls : Str ) = Expr .Var (s " _mlsValue::cast< $cls>( $name) " )
34
- private def mlsObjectNameMethod (name : Str ) = s " constexpr static inline const char *typeName = \" ${name}\" ; "
35
- private def mlsTypeTag () = s " constexpr static inline uint32_t typeTag = nextTypeTag(); "
36
- private def mlsTypeTag (n : Int ) = s " constexpr static inline uint32_t typeTag = $n; "
37
- private def mlsCommonCreateMethod (cls : Str , fields : Ls [Str ], id : Int ) =
13
+ private class CppCodeGen :
14
+ def mapName (name : Name ): Str = " _mls_" + name.str.replace('$' , '_' ).replace('\' ' , '_' )
15
+ def mapName (name : Str ): Str = " _mls_" + name.replace('$' , '_' ).replace('\' ' , '_' )
16
+ val freshName = Fresh (div = '_' );
17
+ val mlsValType = Type .Prim (" _mlsValue" )
18
+ val mlsUnitValue = Expr .Call (Expr .Var (" _mlsValue::create<_mls_Unit>" ), Ls ());
19
+ val mlsRetValue = " _mls_retval"
20
+ val mlsRetValueDecl = Decl .VarDecl (mlsRetValue, mlsValType)
21
+ val mlsMainName = " _mlsMain"
22
+ val mlsPrelude = " #include \" mlsprelude.h\" "
23
+ val mlsPreludeImpl = " #include \" mlsprelude.cpp\" "
24
+ val mlsInternalClass = Set (" True" , " False" , " Boolean" , " Callable" )
25
+ val mlsObject = " _mlsObject"
26
+ val mlsBuiltin = " builtin"
27
+ val mlsEntryPoint = s " int main() { return _mlsLargeStack(_mlsMainWrapper); } " ;
28
+ def mlsIntLit (x : BigInt ) = Expr .Call (Expr .Var (" _mlsValue::fromIntLit" ), Ls (Expr .IntLit (x)))
29
+ def mlsStrLit (x : Str ) = Expr .Call (Expr .Var (" _mlsValue::fromStrLit" ), Ls (Expr .StrLit (x)))
30
+ def mlsCharLit (x : Char ) = Expr .Call (Expr .Var (" _mlsValue::fromIntLit" ), Ls (Expr .CharLit (x)))
31
+ def mlsNewValue (cls : Str , args : Ls [Expr ]) = Expr .Call (Expr .Var (s " _mlsValue::create< $cls> " ), args)
32
+ def mlsIsValueOf (cls : Str , scrut : Expr ) = Expr .Call (Expr .Var (s " _mlsValue::isValueOf< $cls> " ), Ls (scrut))
33
+ def mlsIsIntLit (scrut : Expr , lit : mlscript.IntLit ) = Expr .Call (Expr .Var (" _mlsValue::isIntLit" ), Ls (scrut, Expr .IntLit (lit.value)))
34
+ def mlsDebugPrint (x : Expr ) = Expr .Call (Expr .Var (" _mlsValue::print" ), Ls (x))
35
+ def mlsTupleValue (init : Expr ) = Expr .Constructor (" _mlsValue::tuple" , init)
36
+ def mlsAs (name : Str , cls : Str ) = Expr .Var (s " _mlsValue::as< $cls>( $name) " )
37
+ def mlsAsUnchecked (name : Str , cls : Str ) = Expr .Var (s " _mlsValue::cast< $cls>( $name) " )
38
+ def mlsObjectNameMethod (name : Str ) = s " constexpr static inline const char *typeName = \" ${name}\" ; "
39
+ def mlsTypeTag () = s " constexpr static inline uint32_t typeTag = nextTypeTag(); "
40
+ def mlsTypeTag (n : Int ) = s " constexpr static inline uint32_t typeTag = $n; "
41
+ def mlsCommonCreateMethod (cls : Str , fields : Ls [Str ], id : Int ) =
38
42
val parameters = fields.map{x => s " _mlsValue $x" }.mkString(" , " )
39
43
val fieldsAssignment = fields.map{x => s " _mlsVal-> $x = $x; " }.mkString
40
44
s " static _mlsValue create( $parameters) { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) $cls; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; $fieldsAssignment return _mlsValue(_mlsVal); } "
41
- private def mlsCommonPrintMethod (fields : Ls [Str ]) =
45
+ def mlsCommonPrintMethod (fields : Ls [Str ]) =
42
46
if fields.isEmpty then s " virtual void print() const override { std::printf( \" %s \" , typeName); } "
43
47
else
44
48
val fieldsPrint = fields.map{x => s " this-> $x.print(); " }.mkString(" std::printf(\" , \" ); " )
45
49
s " virtual void print() const override { std::printf( \" %s \" , typeName); std::printf( \" ( \" ); $fieldsPrint std::printf( \" ) \" ); } "
46
- private def mlsCommonDestructorMethod (cls : Str , fields : Ls [Str ]) =
50
+ def mlsCommonDestructorMethod (cls : Str , fields : Ls [Str ]) =
47
51
val fieldsDeletion = fields.map{x => s " _mlsValue::destroy(this-> $x); " }.mkString
48
52
s " virtual void destroy() override { $fieldsDeletion operator delete (this, std::align_val_t(_mlsAlignment)); } "
49
- private def mlsThrowNonExhaustiveMatch = Stmt .Raw (" _mlsNonExhaustiveMatch();" );
50
- private def mlsCall (fn : Str , args : Ls [Expr ]) = Expr .Call (Expr .Var (" _mlsCall" ), Expr .Var (fn) :: args)
51
- private def mlsMethodCall (cls : ClassRef , method : Str , args : Ls [Expr ]) =
53
+ def mlsThrowNonExhaustiveMatch = Stmt .Raw (" _mlsNonExhaustiveMatch();" );
54
+ def mlsCall (fn : Str , args : Ls [Expr ]) = Expr .Call (Expr .Var (" _mlsCall" ), Expr .Var (fn) :: args)
55
+ def mlsMethodCall (cls : ClassRef , method : Str , args : Ls [Expr ]) =
52
56
Expr .Call (Expr .Member (Expr .Call (Expr .Var (s " _mlsMethodCall< ${cls.name |> mapName}> " ), Ls (args.head)), method), args.tail)
53
- private def mlsFnWrapperName (fn : Str ) = s " _mlsFn_ $fn"
54
- private def mlsFnCreateMethod (fn : Str ) = s " static _mlsValue create() { static _mlsFn_ $fn mlsFn alignas(_mlsAlignment); mlsFn.refCount = stickyRefCount; mlsFn.tag = typeTag; return _mlsValue(&mlsFn); } "
55
- private def mlsNeverValue (n : Int ) = if (n <= 1 ) then Expr .Call (Expr .Var (s " _mlsValue::never " ), Ls ()) else Expr .Call (Expr .Var (s " _mlsValue::never< $n> " ), Ls ())
57
+ def mlsFnWrapperName (fn : Str ) = s " _mlsFn_ $fn"
58
+ def mlsFnCreateMethod (fn : Str ) = s " static _mlsValue create() { static _mlsFn_ $fn mlsFn alignas(_mlsAlignment); mlsFn.refCount = stickyRefCount; mlsFn.tag = typeTag; return _mlsValue(&mlsFn); } "
59
+ def mlsNeverValue (n : Int ) = if (n <= 1 ) then Expr .Call (Expr .Var (s " _mlsValue::never " ), Ls ()) else Expr .Call (Expr .Var (s " _mlsValue::never< $n> " ), Ls ())
56
60
57
- private case class Ctx (
58
- val defnCtx : Set [Str ],
61
+ case class Ctx (
62
+ defnCtx : Set [Str ],
59
63
)
60
64
61
- private def codegenClassInfo (using ctx : Ctx )(cls : ClassInfo ): (Opt [Def ], Decl ) =
65
+ def codegenClassInfo (using ctx : Ctx )(cls : ClassInfo ): (Opt [Def ], Decl ) =
62
66
val fields = cls.fields.map{x => (x |> mapName, mlsValType)}
63
- val parents = if cls.parents.nonEmpty then cls.parents.toList.map{x => x |> mapName} else mlsObject :: Nil
67
+ val parents = if cls.parents.nonEmpty then cls.parents.toList.map( mapName) else mlsObject :: Nil
64
68
val decl = Decl .StructDecl (cls.name |> mapName)
65
69
if mlsInternalClass.contains(cls.name) then return (None , decl)
66
70
val theDef = Def .StructDef (
@@ -80,28 +84,28 @@ class CppCodeGen:
80
84
)
81
85
(S (theDef), decl)
82
86
83
- private def toExpr (texpr : TrivialExpr , reifyUnit : Bool = false )(using ctx : Ctx ): Opt [Expr ] = texpr match
87
+ def toExpr (texpr : TrivialExpr , reifyUnit : Bool = false )(using ctx : Ctx ): Opt [Expr ] = texpr match
84
88
case IExpr .Ref (name) => S (Expr .Var (name |> mapName))
85
89
case IExpr .Literal (mlscript.IntLit (x)) => S (mlsIntLit(x))
86
90
case IExpr .Literal (mlscript.DecLit (x)) => S (mlsIntLit(x.toBigInt))
87
91
case IExpr .Literal (mlscript.StrLit (x)) => S (mlsStrLit(x))
88
92
case IExpr .Literal (mlscript.UnitLit (_)) => if reifyUnit then S (mlsUnitValue) else None
89
93
90
- private def toExpr (texpr : TrivialExpr )(using ctx : Ctx ): Expr = texpr match
94
+ def toExpr (texpr : TrivialExpr )(using ctx : Ctx ): Expr = texpr match
91
95
case IExpr .Ref (name) => Expr .Var (name |> mapName)
92
96
case IExpr .Literal (mlscript.IntLit (x)) => mlsIntLit(x)
93
97
case IExpr .Literal (mlscript.DecLit (x)) => mlsIntLit(x.toBigInt)
94
98
case IExpr .Literal (mlscript.StrLit (x)) => mlsStrLit(x)
95
99
case IExpr .Literal (mlscript.UnitLit (_)) => mlsUnitValue
96
100
97
101
98
- private def wrapMultiValues (exprs : Ls [TrivialExpr ])(using ctx : Ctx ): Expr = exprs match
102
+ def wrapMultiValues (exprs : Ls [TrivialExpr ])(using ctx : Ctx ): Expr = exprs match
99
103
case x :: Nil => toExpr(x, reifyUnit = true ).get
100
104
case _ =>
101
105
val init = Expr .Initializer (exprs.map{x => toExpr(x)})
102
106
mlsTupleValue(init)
103
107
104
- private def codegenCaseWithIfs (scrut : Name , cases : Ls [(Pat , Node )], default : Opt [Node ], storeInto : Str )(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) =
108
+ def codegenCaseWithIfs (scrut : Name , cases : Ls [(Pat , Node )], default : Opt [Node ], storeInto : Str )(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) =
105
109
val scrutName = mapName(scrut)
106
110
val init : Stmt =
107
111
default.fold(mlsThrowNonExhaustiveMatch)(x => {
@@ -121,12 +125,12 @@ class CppCodeGen:
121
125
}
122
126
(decls, stmt.fold(stmts)(x => stmts :+ x))
123
127
124
- private def codegenJumpWithCall (defn : DefnRef , args : Ls [TrivialExpr ], storeInto : Opt [Str ])(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) =
128
+ def codegenJumpWithCall (defn : DefnRef , args : Ls [TrivialExpr ], storeInto : Opt [Str ])(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) =
125
129
val call = Expr .Call (Expr .Var (defn.name |> mapName), args.map(toExpr))
126
130
val stmts2 = stmts ++ Ls (storeInto.fold(Stmt .Return (call))(x => Stmt .Assign (x, call)))
127
131
(decls, stmts2)
128
132
129
- private def codegenOps (op : Str , args : Ls [TrivialExpr ])(using ctx : Ctx ) = op match
133
+ def codegenOps (op : Str , args : Ls [TrivialExpr ])(using ctx : Ctx ) = op match
130
134
case " +" => Expr .Binary (" +" , toExpr(args(0 )), toExpr(args(1 )))
131
135
case " -" => Expr .Binary (" -" , toExpr(args(0 )), toExpr(args(1 )))
132
136
case " *" => Expr .Binary (" *" , toExpr(args(0 )), toExpr(args(1 )))
@@ -141,22 +145,21 @@ class CppCodeGen:
141
145
case " &&" => Expr .Binary (" &&" , toExpr(args(0 )), toExpr(args(1 )))
142
146
case " ||" => Expr .Binary (" ||" , toExpr(args(0 )), toExpr(args(1 )))
143
147
case " !" => Expr .Unary (" !" , toExpr(args(0 )))
144
- case _ => ???
148
+ case _ => mlscript.utils. TODO ( " codegenOps " )
145
149
146
150
147
- private def codegen (expr : IExpr )(using ctx : Ctx ): Expr = expr match
151
+ def codegen (expr : IExpr )(using ctx : Ctx ): Expr = expr match
148
152
case x @ (IExpr .Ref (_) | IExpr .Literal (_)) => toExpr(x, reifyUnit = true ).get
149
153
case IExpr .CtorApp (cls, args) => mlsNewValue(cls.name |> mapName, args.map(toExpr))
150
154
case IExpr .Select (name, cls, field) => Expr .Member (mlsAsUnchecked(name |> mapName, cls.name |> mapName), field |> mapName)
151
155
case IExpr .BasicOp (name, args) => codegenOps(name, args)
152
- // TODO: Implement this
153
- case IExpr .AssignField (assignee, cls, field, value) => ???
156
+ case IExpr .AssignField (assignee, cls, field, value) => mlscript.utils.TODO (" Assign field in the backend" )
154
157
155
- private def codegenBuiltin (names : Ls [Name ], builtin : Str , args : Ls [TrivialExpr ])(using ctx : Ctx ): Ls [Stmt ] = builtin match
158
+ def codegenBuiltin (names : Ls [Name ], builtin : Str , args : Ls [TrivialExpr ])(using ctx : Ctx ): Ls [Stmt ] = builtin match
156
159
case " error" => Ls (Stmt .Raw (" throw std::runtime_error(\" Error\" );" ), Stmt .AutoBind (names.map(mapName), mlsNeverValue(names.size)))
157
160
case _ => Ls (Stmt .AutoBind (names.map(mapName), Expr .Call (Expr .Var (" _mls_builtin_" + builtin), args.map(toExpr))))
158
161
159
- private def codegen (body : Node , storeInto : Str )(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) = body match
162
+ def codegen (body : Node , storeInto : Str )(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) = body match
160
163
case Node .Result (res) =>
161
164
val expr = wrapMultiValues(res)
162
165
val stmts2 = stmts ++ Ls (Stmt .Assign (storeInto, expr))
@@ -173,15 +176,6 @@ class CppCodeGen:
173
176
val call = mlsMethodCall(cls, method.str |> mapName, args.map(toExpr))
174
177
val stmts2 = stmts ++ Ls (Stmt .AutoBind (names.map(mapName), call))
175
178
codegen(body, storeInto)(using decls, stmts2)
176
- // Use method calls instead of apply
177
- //
178
- // case Node.LetApply(names, fn, args, body) if fn.str == "builtin" =>
179
- // val stmts2 = stmts ++ codegenBuiltin(names, args.head.toString.replace("\"", ""), args.tail)
180
- // codegen(body, storeInto)(using decls, stmts2)
181
- // case Node.LetApply(names, fn, args, body) =>
182
- // val call = mlsCall(fn.str |> mapName, args.map(toExpr))
183
- // val stmts2 = stmts ++ Ls(Stmt.AutoBind(names.map(mapName), call))
184
- // codegen(body, storeInto)(using decls, stmts2)
185
179
case Node .LetCall (names, defn, args, _, body) =>
186
180
val call = Expr .Call (Expr .Var (defn.name |> mapName), args.map(toExpr))
187
181
val stmts2 = stmts ++ Ls (Stmt .AutoBind (names.map(mapName), call))
@@ -199,7 +193,7 @@ class CppCodeGen:
199
193
val decl = Decl .FuncDecl (mlsValType, name |> mapName, params.map(x => mlsValType))
200
194
(theDef, decl)
201
195
202
- private def codegenTopNode (node : Node )(using ctx : Ctx ): (Def , Decl ) =
196
+ def codegenTopNode (node : Node )(using ctx : Ctx ): (Def , Decl ) =
203
197
val decls = Ls (mlsRetValueDecl)
204
198
val stmts = Ls .empty[Stmt ]
205
199
val (decls2, stmts2) = codegen(node, mlsRetValue)(using decls, stmts)
@@ -208,27 +202,28 @@ class CppCodeGen:
208
202
val decl = Decl .FuncDecl (mlsValType, mlsMainName, Ls ())
209
203
(theDef, decl)
210
204
211
- private def sortClasses (prog : Program ): Ls [ClassInfo ] =
205
+ // Topological sort of classes based on inheritance relationships
206
+ def sortClasses (prog : Program ): Ls [ClassInfo ] =
212
207
var depgraph = prog.classes.map(x => (x.name, x.parents)).toMap
213
208
var degree = depgraph.view.mapValues(_.size).toMap
214
209
def removeNode (node : Str ) =
215
210
degree -= node
216
211
depgraph -= node
217
212
depgraph = depgraph.view.mapValues(_.filter(_ != node)).toMap
218
213
degree = depgraph.view.mapValues(_.size).toMap
219
- var sorted = Ls .empty[ClassInfo ]
214
+ val sorted = ListBuffer .empty[ClassInfo ]
220
215
var work = degree.filter(_._2 == 0 ).keys.toSet
221
216
while work.nonEmpty do
222
217
val node = work.head
223
218
work -= node
224
- sorted = sorted :+ prog.classes.find(_.name == node).get
219
+ sorted.addOne( prog.classes.find(_.name == node).get)
225
220
removeNode(node)
226
221
val next = degree.filter(_._2 == 0 ).keys
227
- work = work ++ next
222
+ work ++= next
228
223
if depgraph.nonEmpty then
229
224
val cycle = depgraph.keys.mkString(" , " )
230
225
throw new Exception (s " Cycle detected in class hierarchy: $cycle" )
231
- sorted
226
+ sorted.toList
232
227
233
228
def codegen (prog : Program ): CompilationUnit =
234
229
val sortedClasses = sortClasses(prog)
0 commit comments