@@ -4,67 +4,68 @@ import mlscript.compiler.ir.{Expr => IExpr, _}
4
4
import mlscript .compiler .utils ._
5
5
import mlscript .utils ._
6
6
import mlscript .utils .shorthands ._
7
+ import scala .collection .mutable .ListBuffer
7
8
9
+ def codegen (prog : Program ): CompilationUnit =
10
+ val codegen = CppCodeGen ()
11
+ codegen.codegen(prog)
8
12
9
- class CppCodeGen :
10
- private def mapName (name : Name ): Str = " _mls_" + name.str.replace('$' , '_' ).replace('\' ' , '_' )
11
- private def mapName (name : Str ): Str = " _mls_" + name.replace('$' , '_' ).replace('\' ' , '_' )
12
- private val freshName = Fresh (div = '_' );
13
- private val mlsValType = Type .Prim (" _mlsValue" )
14
- private val mlsUnitValue = Expr .Call (Expr .Var (" _mlsValue::create<_mls_Unit>" ), Ls ());
15
- private val mlsRetValue = " _mls_retval"
16
- private val mlsRetValueDecl = Decl .VarDecl (mlsRetValue, mlsValType)
17
- private val mlsMainName = " _mlsMain"
18
- private val mlsPrelude = " #include \" mlsprelude.h\" "
19
- private val mlsPreludeImpl = " #include \" mlsprelude.cpp\" "
20
- private val mlsInternalClass = Set (" True" , " False" , " Boolean" , " Callable" )
21
- private val mlsObject = " _mlsObject"
22
- private val mlsBuiltin = " builtin"
23
- private val mlsEntryPoint = s " int main() { return _mlsLargeStack(_mlsMainWrapper); } " ;
24
- private def mlsIntLit (x : BigInt ) = Expr .Call (Expr .Var (" _mlsValue::fromIntLit" ), Ls (Expr .IntLit (x)))
25
- private def mlsStrLit (x : Str ) = Expr .Call (Expr .Var (" _mlsValue::fromStrLit" ), Ls (Expr .StrLit (x)))
26
- private def mlsCharLit (x : Char ) = Expr .Call (Expr .Var (" _mlsValue::fromIntLit" ), Ls (Expr .CharLit (x)))
27
- private def mlsNewValue (cls : Str , args : Ls [Expr ]) = Expr .Call (Expr .Var (s " _mlsValue::create< $cls> " ), args)
28
- private def mlsIsValueOf (cls : Str , scrut : Expr ) = Expr .Call (Expr .Var (s " _mlsValue::isValueOf< $cls> " ), Ls (scrut))
29
- private def mlsIsIntLit (scrut : Expr , lit : mlscript.IntLit ) = Expr .Call (Expr .Var (" _mlsValue::isIntLit" ), Ls (scrut, Expr .IntLit (lit.value)))
30
- private def mlsIsCharLit (scrut : Expr , lit : mlscript.CharLit ) = Expr .Call (Expr .Var (" _mlsValue::isCharLit" ), Ls (scrut, Expr .CharLit (lit.value)))
31
- private def mlsDebugPrint (x : Expr ) = Expr .Call (Expr .Var (" _mlsValue::print" ), Ls (x))
32
- private def mlsTupleValue (init : Expr ) = Expr .Constructor (" _mlsValue::tuple" , init)
33
- private def mlsAs (name : Str , cls : Str ) = Expr .Var (s " _mlsValue::as< $cls>( $name) " )
34
- private def mlsAsUnchecked (name : Str , cls : Str ) = Expr .Var (s " _mlsValue::cast< $cls>( $name) " )
35
- private def mlsObjectNameMethod (name : Str ) = s " constexpr static inline const char *typeName = \" ${name}\" ; "
36
- private def mlsTypeTag () = s " constexpr static inline uint32_t typeTag = nextTypeTag(); "
37
- private def mlsTypeTag (n : Int ) = s " constexpr static inline uint32_t typeTag = $n; "
38
- private def mlsCommonCreateMethod (cls : Str , fields : Ls [Str ], id : Int ) =
13
+ private class CppCodeGen :
14
+ def mapName (name : Name ): Str = " _mls_" + name.str.replace('$' , '_' ).replace('\' ' , '_' )
15
+ def mapName (name : Str ): Str = " _mls_" + name.replace('$' , '_' ).replace('\' ' , '_' )
16
+ val freshName = Fresh (div = '_' );
17
+ val mlsValType = Type .Prim (" _mlsValue" )
18
+ val mlsUnitValue = Expr .Call (Expr .Var (" _mlsValue::create<_mls_Unit>" ), Ls ());
19
+ val mlsRetValue = " _mls_retval"
20
+ val mlsRetValueDecl = Decl .VarDecl (mlsRetValue, mlsValType)
21
+ val mlsMainName = " _mlsMain"
22
+ val mlsPrelude = " #include \" mlsprelude.h\" "
23
+ val mlsPreludeImpl = " #include \" mlsprelude.cpp\" "
24
+ val mlsInternalClass = Set (" True" , " False" , " Boolean" , " Callable" )
25
+ val mlsObject = " _mlsObject"
26
+ val mlsBuiltin = " builtin"
27
+ val mlsEntryPoint = s " int main() { return _mlsLargeStack(_mlsMainWrapper); } " ;
28
+ def mlsIntLit (x : BigInt ) = Expr .Call (Expr .Var (" _mlsValue::fromIntLit" ), Ls (Expr .IntLit (x)))
29
+ def mlsStrLit (x : Str ) = Expr .Call (Expr .Var (" _mlsValue::fromStrLit" ), Ls (Expr .StrLit (x)))
30
+ def mlsCharLit (x : Char ) = Expr .Call (Expr .Var (" _mlsValue::fromIntLit" ), Ls (Expr .CharLit (x)))
31
+ def mlsNewValue (cls : Str , args : Ls [Expr ]) = Expr .Call (Expr .Var (s " _mlsValue::create< $cls> " ), args)
32
+ def mlsIsValueOf (cls : Str , scrut : Expr ) = Expr .Call (Expr .Var (s " _mlsValue::isValueOf< $cls> " ), Ls (scrut))
33
+ def mlsIsIntLit (scrut : Expr , lit : mlscript.IntLit ) = Expr .Call (Expr .Var (" _mlsValue::isIntLit" ), Ls (scrut, Expr .IntLit (lit.value)))
34
+ def mlsIsCharLit (scrut : Expr , lit : mlscript.CharLit ) = Expr .Call (Expr .Var (" _mlsValue::isCharLit" ), Ls (scrut, Expr .CharLit (lit.value)))
35
+ def mlsDebugPrint (x : Expr ) = Expr .Call (Expr .Var (" _mlsValue::print" ), Ls (x))
36
+ def mlsTupleValue (init : Expr ) = Expr .Constructor (" _mlsValue::tuple" , init)
37
+ def mlsAs (name : Str , cls : Str ) = Expr .Var (s " _mlsValue::as< $cls>( $name) " )
38
+ def mlsAsUnchecked (name : Str , cls : Str ) = Expr .Var (s " _mlsValue::cast< $cls>( $name) " )
39
+ def mlsObjectNameMethod (name : Str ) = s " constexpr static inline const char *typeName = \" ${name}\" ; "
40
+ def mlsTypeTag () = s " constexpr static inline uint32_t typeTag = nextTypeTag(); "
41
+ def mlsTypeTag (n : Int ) = s " constexpr static inline uint32_t typeTag = $n; "
42
+ def mlsCommonCreateMethod (cls : Str , fields : Ls [Str ], id : Int ) =
39
43
val parameters = fields.map{x => s " _mlsValue $x" }.mkString(" , " )
40
44
val fieldsAssignment = fields.map{x => s " _mlsVal-> $x = $x; " }.mkString
41
45
s " static _mlsValue create( $parameters) { auto _mlsVal = new (std::align_val_t(_mlsAlignment)) $cls; _mlsVal->refCount = 1; _mlsVal->tag = typeTag; $fieldsAssignment return _mlsValue(_mlsVal); } "
42
- private def mlsCommonPrintMethod (fields : Ls [Str ]) =
46
+ def mlsCommonPrintMethod (fields : Ls [Str ]) =
43
47
if fields.isEmpty then s " virtual void print() const override { std::printf( \" %s \" , typeName); } "
44
48
else
45
49
val fieldsPrint = fields.map{x => s " this-> $x.print(); " }.mkString(" std::printf(\" , \" ); " )
46
50
s " virtual void print() const override { std::printf( \" %s \" , typeName); std::printf( \" ( \" ); $fieldsPrint std::printf( \" ) \" ); } "
47
- private def mlsCommonDestructorMethod (cls : Str , fields : Ls [Str ]) =
51
+ def mlsCommonDestructorMethod (cls : Str , fields : Ls [Str ]) =
48
52
val fieldsDeletion = fields.map{x => s " _mlsValue::destroy(this-> $x); " }.mkString
49
53
s " virtual void destroy() override { $fieldsDeletion operator delete (this, std::align_val_t(_mlsAlignment)); } "
50
- private def mlsThrowNonExhaustiveMatch = Stmt .Raw (" _mlsNonExhaustiveMatch();" );
51
- private def mlsCall (fn : Str , args : Ls [Expr ]) = Expr .Call (Expr .Var (" _mlsCall" ), Expr .Var (fn) :: args)
52
- private def mlsMethodCall (cls : ClassRef , method : Str , args : Ls [Expr ]) =
54
+ def mlsThrowNonExhaustiveMatch = Stmt .Raw (" _mlsNonExhaustiveMatch();" );
55
+ def mlsCall (fn : Str , args : Ls [Expr ]) = Expr .Call (Expr .Var (" _mlsCall" ), Expr .Var (fn) :: args)
56
+ def mlsMethodCall (cls : ClassRef , method : Str , args : Ls [Expr ]) =
53
57
Expr .Call (Expr .Member (Expr .Call (Expr .Var (s " _mlsMethodCall< ${cls.name |> mapName}> " ), Ls (args.head)), method), args.tail)
54
- private def mlsFnWrapperName (fn : Str ) = s " _mlsFn_ $fn"
55
- private def mlsFnCreateMethod (fn : Str ) = s " static _mlsValue create() { static _mlsFn_ $fn mlsFn alignas(_mlsAlignment); mlsFn.refCount = stickyRefCount; mlsFn.tag = typeTag; return _mlsValue(&mlsFn); } "
56
- private def mlsFnApplyNMethod (fn : Str , n : Int ) =
57
- Def .FuncDef (Type .Qualifier (mlsValType, " virtual" ), s " apply $n" , (0 until n).map{x => (s " arg $x" , mlsValType)}.toList,
58
- Stmt .Block (Ls (), Ls (Stmt .Return (Expr .Call (Expr .Var (fn), (0 until n).map{x => Expr .Var (s " arg $x" )}.toList)))), true )
59
- private def mlsNeverValue (n : Int ) = if (n <= 1 ) then Expr .Call (Expr .Var (s " _mlsValue::never " ), Ls ()) else Expr .Call (Expr .Var (s " _mlsValue::never< $n> " ), Ls ())
58
+ def mlsFnWrapperName (fn : Str ) = s " _mlsFn_ $fn"
59
+ def mlsFnCreateMethod (fn : Str ) = s " static _mlsValue create() { static _mlsFn_ $fn mlsFn alignas(_mlsAlignment); mlsFn.refCount = stickyRefCount; mlsFn.tag = typeTag; return _mlsValue(&mlsFn); } "
60
+ def mlsNeverValue (n : Int ) = if (n <= 1 ) then Expr .Call (Expr .Var (s " _mlsValue::never " ), Ls ()) else Expr .Call (Expr .Var (s " _mlsValue::never< $n> " ), Ls ())
60
61
61
- private case class Ctx (
62
- val defnCtx : Set [Str ],
62
+ case class Ctx (
63
+ defnCtx : Set [Str ],
63
64
)
64
65
65
- private def codegenClassInfo (using ctx : Ctx )(cls : ClassInfo ): (Opt [Def ], Decl ) =
66
+ def codegenClassInfo (using ctx : Ctx )(cls : ClassInfo ): (Opt [Def ], Decl ) =
66
67
val fields = cls.fields.map{x => (x |> mapName, mlsValType)}
67
- val parents = if cls.parents.nonEmpty then cls.parents.toList.map{x => x |> mapName} else mlsObject :: Nil
68
+ val parents = if cls.parents.nonEmpty then cls.parents.toList.map( mapName) else mlsObject :: Nil
68
69
val decl = Decl .StructDecl (cls.name |> mapName)
69
70
if mlsInternalClass.contains(cls.name) then return (None , decl)
70
71
val theDef = Def .StructDef (
@@ -84,15 +85,15 @@ class CppCodeGen:
84
85
)
85
86
(S (theDef), decl)
86
87
87
- private def toExpr (texpr : TrivialExpr , reifyUnit : Bool = false )(using ctx : Ctx ): Opt [Expr ] = texpr match
88
+ def toExpr (texpr : TrivialExpr , reifyUnit : Bool = false )(using ctx : Ctx ): Opt [Expr ] = texpr match
88
89
case IExpr .Ref (name) => S (Expr .Var (name |> mapName))
89
90
case IExpr .Literal (mlscript.IntLit (x)) => S (mlsIntLit(x))
90
91
case IExpr .Literal (mlscript.DecLit (x)) => S (mlsIntLit(x.toBigInt))
91
92
case IExpr .Literal (mlscript.StrLit (x)) => S (mlsStrLit(x))
92
93
case IExpr .Literal (mlscript.CharLit (x)) => S (mlsCharLit(x))
93
94
case IExpr .Literal (mlscript.UnitLit (_)) => if reifyUnit then S (mlsUnitValue) else None
94
95
95
- private def toExpr (texpr : TrivialExpr )(using ctx : Ctx ): Expr = texpr match
96
+ def toExpr (texpr : TrivialExpr )(using ctx : Ctx ): Expr = texpr match
96
97
case IExpr .Ref (name) => Expr .Var (name |> mapName)
97
98
case IExpr .Literal (mlscript.IntLit (x)) => mlsIntLit(x)
98
99
case IExpr .Literal (mlscript.DecLit (x)) => mlsIntLit(x.toBigInt)
@@ -101,13 +102,13 @@ class CppCodeGen:
101
102
case IExpr .Literal (mlscript.UnitLit (_)) => mlsUnitValue
102
103
103
104
104
- private def wrapMultiValues (exprs : Ls [TrivialExpr ])(using ctx : Ctx ): Expr = exprs match
105
+ def wrapMultiValues (exprs : Ls [TrivialExpr ])(using ctx : Ctx ): Expr = exprs match
105
106
case x :: Nil => toExpr(x, reifyUnit = true ).get
106
107
case _ =>
107
108
val init = Expr .Initializer (exprs.map{x => toExpr(x)})
108
109
mlsTupleValue(init)
109
110
110
- private def codegenCaseWithIfs (scrut : Name , cases : Ls [(Pat , Node )], default : Opt [Node ], storeInto : Str )(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) =
111
+ def codegenCaseWithIfs (scrut : Name , cases : Ls [(Pat , Node )], default : Opt [Node ], storeInto : Str )(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) =
111
112
val scrutName = mapName(scrut)
112
113
val init : Stmt =
113
114
default.fold(mlsThrowNonExhaustiveMatch)(x => {
@@ -131,12 +132,12 @@ class CppCodeGen:
131
132
}
132
133
(decls, stmt.fold(stmts)(x => stmts :+ x))
133
134
134
- private def codegenJumpWithCall (defn : DefnRef , args : Ls [TrivialExpr ], storeInto : Opt [Str ])(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) =
135
+ def codegenJumpWithCall (defn : DefnRef , args : Ls [TrivialExpr ], storeInto : Opt [Str ])(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) =
135
136
val call = Expr .Call (Expr .Var (defn.name |> mapName), args.map(toExpr))
136
137
val stmts2 = stmts ++ Ls (storeInto.fold(Stmt .Return (call))(x => Stmt .Assign (x, call)))
137
138
(decls, stmts2)
138
139
139
- private def codegenOps (op : Str , args : Ls [TrivialExpr ])(using ctx : Ctx ) = op match
140
+ def codegenOps (op : Str , args : Ls [TrivialExpr ])(using ctx : Ctx ) = op match
140
141
case " +" => Expr .Binary (" +" , toExpr(args(0 )), toExpr(args(1 )))
141
142
case " -" => Expr .Binary (" -" , toExpr(args(0 )), toExpr(args(1 )))
142
143
case " *" => Expr .Binary (" *" , toExpr(args(0 )), toExpr(args(1 )))
@@ -151,20 +152,20 @@ class CppCodeGen:
151
152
case " &&" => Expr .Binary (" &&" , toExpr(args(0 )), toExpr(args(1 )))
152
153
case " ||" => Expr .Binary (" ||" , toExpr(args(0 )), toExpr(args(1 )))
153
154
case " !" => Expr .Unary (" !" , toExpr(args(0 )))
154
- case _ => ???
155
+ case _ => mlscript.utils. TODO ( " codegenOps " )
155
156
156
157
157
- private def codegen (expr : IExpr )(using ctx : Ctx ): Expr = expr match
158
+ def codegen (expr : IExpr )(using ctx : Ctx ): Expr = expr match
158
159
case x @ (IExpr .Ref (_) | IExpr .Literal (_)) => toExpr(x, reifyUnit = true ).get
159
160
case IExpr .CtorApp (cls, args) => mlsNewValue(cls.name |> mapName, args.map(toExpr))
160
161
case IExpr .Select (name, cls, field) => Expr .Member (mlsAsUnchecked(name |> mapName, cls.name |> mapName), field |> mapName)
161
162
case IExpr .BasicOp (name, args) => codegenOps(name, args)
162
163
163
- private def codegenBuiltin (names : Ls [Name ], builtin : Str , args : Ls [TrivialExpr ])(using ctx : Ctx ): Ls [Stmt ] = builtin match
164
+ def codegenBuiltin (names : Ls [Name ], builtin : Str , args : Ls [TrivialExpr ])(using ctx : Ctx ): Ls [Stmt ] = builtin match
164
165
case " error" => Ls (Stmt .Raw (" throw std::runtime_error(\" Error\" );" ), Stmt .AutoBind (names.map(mapName), mlsNeverValue(names.size)))
165
166
case _ => Ls (Stmt .AutoBind (names.map(mapName), Expr .Call (Expr .Var (" _mls_builtin_" + builtin), args.map(toExpr))))
166
167
167
- private def codegen (body : Node , storeInto : Str )(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) = body match
168
+ def codegen (body : Node , storeInto : Str )(using decls : Ls [Decl ], stmts : Ls [Stmt ])(using ctx : Ctx ): (Ls [Decl ], Ls [Stmt ]) = body match
168
169
case Node .Result (res) =>
169
170
val expr = wrapMultiValues(res)
170
171
val stmts2 = stmts ++ Ls (Stmt .Assign (storeInto, expr))
@@ -197,7 +198,7 @@ class CppCodeGen:
197
198
case Node .Case (scrut, cases, default) =>
198
199
codegenCaseWithIfs(scrut, cases, default, storeInto)
199
200
200
- private def codegenDefn (using ctx : Ctx )(defn : Defn ): (Def , Decl ) = defn match
201
+ def codegenDefn (using ctx : Ctx )(defn : Defn ): (Def , Decl ) = defn match
201
202
case Defn (id, name, params, resultNum, specialized, body) =>
202
203
val decls = Ls (mlsRetValueDecl)
203
204
val stmts = Ls .empty[Stmt ]
@@ -207,7 +208,7 @@ class CppCodeGen:
207
208
val decl = Decl .FuncDecl (mlsValType, name |> mapName, params.map(x => mlsValType))
208
209
(theDef, decl)
209
210
210
- private def codegenTopNode (node : Node )(using ctx : Ctx ): (Def , Decl ) =
211
+ def codegenTopNode (node : Node )(using ctx : Ctx ): (Def , Decl ) =
211
212
val decls = Ls (mlsRetValueDecl)
212
213
val stmts = Ls .empty[Stmt ]
213
214
val (decls2, stmts2) = codegen(node, mlsRetValue)(using decls, stmts)
@@ -216,27 +217,28 @@ class CppCodeGen:
216
217
val decl = Decl .FuncDecl (mlsValType, mlsMainName, Ls ())
217
218
(theDef, decl)
218
219
219
- private def sortClasses (prog : Program ): Ls [ClassInfo ] =
220
+ // Topological sort of classes based on inheritance relationships
221
+ def sortClasses (prog : Program ): Ls [ClassInfo ] =
220
222
var depgraph = prog.classes.map(x => (x.name, x.parents)).toMap
221
223
var degree = depgraph.view.mapValues(_.size).toMap
222
224
def removeNode (node : Str ) =
223
225
degree -= node
224
226
depgraph -= node
225
227
depgraph = depgraph.view.mapValues(_.filter(_ != node)).toMap
226
228
degree = depgraph.view.mapValues(_.size).toMap
227
- var sorted = Ls .empty[ClassInfo ]
229
+ val sorted = ListBuffer .empty[ClassInfo ]
228
230
var work = degree.filter(_._2 == 0 ).keys.toSet
229
231
while work.nonEmpty do
230
232
val node = work.head
231
233
work -= node
232
- sorted = sorted :+ prog.classes.find(_.name == node).get
234
+ sorted.addOne( prog.classes.find(_.name == node).get)
233
235
removeNode(node)
234
236
val next = degree.filter(_._2 == 0 ).keys
235
- work = work ++ next
237
+ work ++= next
236
238
if depgraph.nonEmpty then
237
239
val cycle = depgraph.keys.mkString(" , " )
238
240
throw new Exception (s " Cycle detected in class hierarchy: $cycle" )
239
- sorted
241
+ sorted.toList
240
242
241
243
def codegen (prog : Program ): CompilationUnit =
242
244
val sortedClasses = sortClasses(prog)
0 commit comments