Tail Recursion Optimization #218

Merged
merged 62 commits into from
Jun 20, 2024
Changes from 6 commits
Commits
8b73a24
move changes from tailrec-opt to new branch
CAG2Mark Mar 13, 2024
93bf78f
Optimize strongly connected components
CAG2Mark Mar 23, 2024
9b65640
update map braces
CAG2Mark Mar 24, 2024
915891d
small refactor
CAG2Mark Mar 24, 2024
6f2ca3d
refactor
CAG2Mark Mar 24, 2024
f987e02
Update test infrastructure, fix code, add basic test for tailrec.
CAG2Mark Mar 29, 2024
184a5de
Update test
CAG2Mark Mar 29, 2024
3066c7f
remove todos
CAG2Mark Mar 29, 2024
d7b4bb3
Add field assignment to IR
CAG2Mark Apr 8, 2024
058a929
Prevent unnecessary inlining for mutually tail recursive funcs, handl…
CAG2Mark Apr 14, 2024
0e3a875
Update test
CAG2Mark Apr 14, 2024
dd50c7a
Improved tests
CAG2Mark Apr 14, 2024
fc282e7
Fix bugs and cases
CAG2Mark Apr 14, 2024
edf7c89
Interpret field assignment
CAG2Mark Apr 20, 2024
c08f31b
add class info to field assignment
CAG2Mark Apr 25, 2024
518ffb4
propagate tailrec, fix tailrec parsing issue
CAG2Mark Apr 27, 2024
d00d072
progress, fix bug
CAG2Mark Apr 29, 2024
e45412b
Detect mod cons tail calls
CAG2Mark Apr 29, 2024
fc473a4
Refactor tail call discovery
CAG2Mark Apr 29, 2024
22dd518
change test
CAG2Mark Apr 29, 2024
c1d90d9
add test, verify mod cons call discovery works
CAG2Mark Apr 29, 2024
16134c2
add tests, improve formatting
CAG2Mark Apr 29, 2024
e9d6c54
remove newlines
CAG2Mark Apr 29, 2024
dcd616b
add tostring, improve formatting
CAG2Mark Apr 30, 2024
8b91f91
remove println
CAG2Mark Apr 30, 2024
8d6e14f
actually handle single tail recursive
CAG2Mark Apr 30, 2024
6213d51
properly attach tags
CAG2Mark Apr 30, 2024
e778810
progress, handle join points properly
CAG2Mark May 1, 2024
934c324
fix test
CAG2Mark May 1, 2024
b27cf3a
Done
CAG2Mark May 1, 2024
31bd2f5
re-enable normal tailrec optimization, clean up code
CAG2Mark May 1, 2024
0e64e67
Merge branch 'mlscript' into tailrec
CAG2Mark May 1, 2024
715ef73
Update tests
CAG2Mark May 1, 2024
d8d05be
fix some tests
CAG2Mark May 12, 2024
5887014
Propagate tailrec, fix join point infinite recursion bug
CAG2Mark May 12, 2024
78da910
refactor and check @tailrec for function definitions
CAG2Mark May 12, 2024
42b264f
update
CAG2Mark May 12, 2024
2199e37
make AssignField an Expr instead of a Node
CAG2Mark Jun 1, 2024
eacf538
Fix unsafe partial destruction
CAG2Mark Jun 1, 2024
214f9d3
rename @tailrec to @tailcall for call-level annotations
CAG2Mark Jun 1, 2024
51af35f
Propagate positions of calls and @tailrec annotations
CAG2Mark Jun 1, 2024
8ce1b35
Add error reporting to IR diff tests, report tailrec IR errors, fix bug
CAG2Mark Jun 1, 2024
ee15969
fix tests
CAG2Mark Jun 1, 2024
5af3621
fix tests
CAG2Mark Jun 1, 2024
487279e
Fix grammar
CAG2Mark Jun 1, 2024
a4efc7a
Purity check, fix join point bug
CAG2Mark Jun 5, 2024
0c8dec1
Build in true/false class in IR
CAG2Mark Jun 5, 2024
1ea98f1
restore nuscratch changes
CAG2Mark Jun 5, 2024
abcadfe
Ensure no function name clashes
CAG2Mark Jun 5, 2024
ea8d442
Document, add undefined literal and address remaining todos
CAG2Mark Jun 5, 2024
0125b1e
use unitlit instead of a new literal type for undefined
CAG2Mark Jun 5, 2024
dd8f2ae
Document @tailcall and @tailrec
CAG2Mark Jun 10, 2024
398bd9e
properly raise errors in IR builder
CAG2Mark Jun 17, 2024
4f93b76
fix some documentation
CAG2Mark Jun 17, 2024
f1b5fcb
Merge master
CAG2Mark Jun 17, 2024
f479ecc
remove bad copy paste
CAG2Mark Jun 17, 2024
49a4993
fix typo
CAG2Mark Jun 17, 2024
875484c
Improve todo comment
CAG2Mark Jun 17, 2024
92039ea
respond to code review
CAG2Mark Jun 18, 2024
22d4f47
Improve test format
CAG2Mark Jun 18, 2024
263b82a
Merge master
CAG2Mark Jun 18, 2024
d0923c3
re-run tests to fix line numbers
CAG2Mark Jun 18, 2024
10 changes: 4 additions & 6 deletions compiler/shared/main/scala/mlscript/compiler/ir/Builder.scala
@@ -307,7 +307,10 @@ final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: Fres

import scala.collection.mutable.{ HashSet => MutHSet }

val cls = grouped.getOrElse(0, Nil).map(buildClassInfo)
val cls = ClassInfo(classUid.make, "True", List()) // TODO: add "True" and "False" at some point
:: ClassInfo(classUid.make, "False", List())
:: grouped.getOrElse(0, Nil).map(buildClassInfo)

cls.foldLeft(Set.empty)(checkDuplicateField(_, _))

val clsinfo = cls.toSet
@@ -342,10 +345,5 @@ final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: Fres

resolveDefnRef(main, defs, true)
validate(main, defs)

// TODO: should properly import built-in types
val clsWithBool = clsinfo
+ ClassInfo(classUid.make, "True", List())
+ ClassInfo(classUid.make, "False", List())

Program(clsinfo, defs, main)
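
The Builder change above prepends the built-in `True` and `False` classes before the user classes, which is why every `ClassInfo` id in the diff tests further down shifts by two. A minimal sketch of that numbering, assuming simplified stand-ins for `FreshInt` and `ClassInfo` (the real definitions in the repository may differ) and a hypothetical helper `buildClasses`:

```scala
// Simplified stand-ins for the IR types in this PR (assumed shapes, not the real definitions).
final class FreshInt:
  private var counter = 0
  def make: Int =
    val next = counter
    counter += 1
    next

final case class ClassInfo(id: Int, ident: String, fields: List[String])

// Hypothetical helper: prepend the built-in boolean classes, then number the user classes.
// Operands of `::` are evaluated left to right, so "True" receives the first id.
def buildClasses(classUid: FreshInt, userClasses: List[(String, List[String])]): List[ClassInfo] =
  ClassInfo(classUid.make, "True", List())
    :: ClassInfo(classUid.make, "False", List())
    :: userClasses.map((name, fields) => ClassInfo(classUid.make, name, fields))
```

With a single user class `Pair(x, y)`, this yields ids 0, 1 and 2 for `True`, `False` and `Pair`, matching the updated expected output in `IR.mls`.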
4 changes: 2 additions & 2 deletions compiler/shared/main/scala/mlscript/compiler/ir/IR.scala
@@ -112,7 +112,7 @@ enum Expr:
case Literal(IntLit(lit)) => s"$lit" |> raw
case Literal(DecLit(lit)) => s"$lit" |> raw
case Literal(StrLit(lit)) => s"$lit" |> raw
case Literal(UnitLit(lit)) => s"$lit" |> raw
case Literal(UnitLit(lit)) => (if lit then "undefined" else "null") |> raw
case CtorApp(ClassInfo(_, name, _), args) =>
raw(name) <#> raw("(") <#> raw(args |> show_args) <#> raw(")")
case Select(s, _, fld) =>
@@ -297,7 +297,7 @@ enum LocMarker:
case MLit(IntLit(lit)) => s"$lit" |> raw
case MLit(DecLit(lit)) => s"$lit" |> raw
case MLit(StrLit(lit)) => s"$lit" |> raw
case MLit(UnitLit(lit)) => s"$lit" |> raw
case MLit(UnitLit(lit)) => (if lit then "undefined" else "null") |> raw
case _ => raw("...")

def show = s"$tag-" + toDocument.print
@@ -9,6 +9,62 @@ import Message.MessageContext
import compiler.ir._
import compiler.ir.Node._

/*

DOCUMENTATION OF SEMANTICS OF @tailcall and @tailrec

@tailcall: Used to annotate specific function calls. Calls annotated with @tailcall
must be tail calls or tail modulo-cons calls. These calls may be optimized to not
consume additional stack space. If such an optimization is not possible, then the
compiler will throw an error.

If there are multiple possible candidates for tail modulo-cons calls in a single
branch of an expression, then @tailcall can be used to indicate which one will be
optimized. For instance in

fun foo() =
A(foo(), bar())

we can use @tailcall to annotate the call foo() or bar(). If a call other than the
last call is annotated with @tailcall, then the remaining functions must be pure
to ensure that reordering computations does not change the result.

If bar() is impure but you still want to optimize the call foo(), then you can do

fun foo() =
let b = bar()
let a = @tailcall foo()
A(a, b)

because here, you are taking responsibility for the reordering of the computations.

@tailrec: Used to annotate functions. When this annotation is used on a function, say
@tailrec fun foo(), the compiler will ensure that no sequence of recursive calls back to foo()
consumes stack space, i.e. they are all tail calls. Note that a call to foo() may
consume an arbitrary amount of stack space as long as foo() is only consuming finite
stack space. For example,

@tailrec fun foo() = bar()
fun bar() =
bar()
bar()

is valid. However,

@tailrec fun foo() = bar()
fun bar() =
foo()
bar()

is invalid. If we swap the position of foo() and bar() in the body of bar, it is still invalid.

Equivalently, if fun foo() is annotated with @tailrec, let S be the strongly
connected component of the program's call graph that contains foo. Then an error
will be thrown unless all edges (calls) connecting the nodes of S are tail calls
or tail modulo-cons calls.

*/
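
The call-graph formulation of `@tailrec` in the comment above can be checked on a toy model. The sketch below is an illustration only: `Call`, `reachable`, `sccOf` and `tailrecViolations` are hypothetical names, functions are plain strings, and the component is computed by naive reachability rather than by whatever algorithm the optimizer actually uses.

```scala
import scala.annotation.tailrec

// Hypothetical call-graph edge: `isTail` is true for tail calls and tail modulo-cons calls.
final case class Call(caller: String, callee: String, isTail: Boolean)

// Every function reachable from `start` by following calls; `start` itself is included.
def reachable(start: String, calls: List[Call]): Set[String] =
  @tailrec
  def go(frontier: List[String], seen: Set[String]): Set[String] = frontier match
    case Nil => seen
    case f :: rest =>
      val next = calls.collect { case Call(`f`, callee, _) if !seen(callee) => callee }.distinct
      go(next ::: rest, seen ++ next)
  go(List(start), Set(start))

// The strongly connected component containing `fn`: everything that `fn` can reach
// and that can reach `fn` back.
def sccOf(fn: String, calls: List[Call]): Set[String] =
  reachable(fn, calls).filter(g => reachable(g, calls).contains(fn))

// Calls that would make `@tailrec fun fn` an error: non-tail calls inside fn's component.
def tailrecViolations(fn: String, calls: List[Call]): List[Call] =
  val scc = sccOf(fn, calls)
  calls.filter(c => scc(c.caller) && scc(c.callee) && !c.isTail)
```

In the first example from the comment, `bar` never calls `foo`, so the component of `foo` is just `foo` and the non-tail `bar()` call is irrelevant. In the second example, `foo` and `bar` share a component, so the non-tail call inside `bar` is reported whichever of the two calls it is.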

// fnUid should be the same FreshInt that was used to build the graph being passed into this class
class TailRecOpt(fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diagnostic => Unit):
case class LetCtorNodeInfo(node: LetExpr, ctor: Expr.CtorApp, cls: ClassInfo, ctorValName: Name, fieldName: String, idx: Int)
@@ -583,12 +639,14 @@ class TailRecOpt(fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diag
val trueClass = classes.find(c => c.ident == "True").get
val falseClass = classes.find(c => c.ident == "False").get

// CONOTEXT APPLICATION

// CONTEXT APPLICATION
val mergedNames = defns.foldLeft("")(_ + "_" + _.name)

val ctxAppName = mergedNames + "_ctx_app"
val ctxCompName = mergedNames + "_ctx_comp"
val ctxAppId = fnUid.make
val ctxAppName = mergedNames + "_ctx_app$" + ctxAppId
val ctxCompId = fnUid.make
val ctxCompName = mergedNames + "_ctx_comp$" + ctxCompId
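
The naming change in this hunk allocates the definition id first and appends it to the generated name after a `$`, so that two generated definitions with the same base name cannot clash. A small sketch of the pattern, assuming a simplified `FreshInt` and two hypothetical helpers (`mergedNames` over plain strings and `freshDefnName`):

```scala
// Stand-in for the PR's FreshInt (assumed behaviour: hands out 0, 1, 2, ...).
final class FreshInt:
  private var counter = 0
  def make: Int =
    val next = counter
    counter += 1
    next

// Same fold as in the diff, over plain names: List("f", "g") becomes "_f_g".
def mergedNames(names: List[String]): String =
  names.foldLeft("")(_ + "_" + _)

// Hypothetical helper: allocate the id first, then reuse it in the name and as the definition's id.
def freshDefnName(fnUid: FreshInt, base: String, suffix: String): (Int, String) =
  val id = fnUid.make
  (id, base + suffix + "$" + id)
```

Because the id comes from the same counter that numbers every other definition, the suffix is unique among generated names even when two components merge to the same base string.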

// map integers to classes and fields which will be assigned to
val classIdMap = classes.map(c => c.id -> c).toMap
@@ -649,7 +707,7 @@ class TailRecOpt(fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diag
)
).attachTag(tag)

val appDefn = Defn(fnUid.make, ctxAppName, List(appCtxName, appValName), 1, appNode, false)
val appDefn = Defn(ctxAppId, ctxAppName, List(appCtxName, appValName), 1, appNode, false)

// CONTEXT COMPOSITION
val cmpCtx1Name = Name("ctx1")
@@ -690,7 +748,7 @@ class TailRecOpt(fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diag
).attachTag(tag)
).attachTag(tag)

val cmpDefn = Defn(fnUid.make, ctxCompName, List(cmpCtx1Name, cmpCtx2Name), 1, cmpNode, false)
val cmpDefn = Defn(ctxCompId, ctxCompName, List(cmpCtx1Name, cmpCtx2Name), 1, cmpNode, false)

// We use tags to identify nodes
// a bit hacky but it's the most elegant way
@@ -790,7 +848,8 @@ class TailRecOpt(fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diag

def rewriteDefn(d: Defn): Defn =
val transformed = transformNode(d.body)
Defn(fnUid.make, d.name + "_modcons", Name("ctx") :: d.params, d.resultNum, transformed, d.isTailRec)
val id = fnUid.make
Defn(id, d.name + "_modcons$" + id, Name("ctx") :: d.params, d.resultNum, transformed, d.isTailRec)

// returns (new defn, mod cons defn)
// where new defn has the same signature and ids as the original, but immediately calls the mod cons defn
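
The wrapper-plus-`_modcons` split described in the comment above is the usual shape of a tail modulo-cons rewrite: the original function keeps its signature and immediately delegates to a variant that takes an extra context argument. The sketch below illustrates the idea in plain Scala with a mutable cons cell as the context. It is not the IR transformation from this PR, whose context representation (`ctx_app`, `ctx_comp`) is more general, and all names are hypothetical.

```scala
import scala.annotation.tailrec

// A mutable cons list, so that a "hole" in a partially built list can be filled in later.
sealed trait MList
case object MNil extends MList
final class MCons(val head: Int, var tail: MList) extends MList

// Direct version: the recursive call sits under a constructor, so it is not a tail call.
def incAll(xs: List[Int]): MList = xs match
  case Nil => MNil
  case x :: rest => MCons(x + 1, incAll(rest))

// Mod-cons version: `hole` is the cell whose tail is still missing (the context).
@tailrec
def incAllModCons(hole: MCons, xs: List[Int]): Unit = xs match
  case Nil => hole.tail = MNil
  case x :: rest =>
    val cell = MCons(x + 1, MNil)
    hole.tail = cell // field assignment fills the hole left by the previous step
    incAllModCons(cell, rest)

// Wrapper with the original signature that immediately calls the mod-cons version.
def incAllOpt(xs: List[Int]): MList =
  val root = MCons(0, MNil) // dummy cell; only its tail is returned
  incAllModCons(root, xs)
  root.tail

// Helper for inspecting results.
@tailrec
def toList(m: MList, acc: List[Int] = Nil): List[Int] = m match
  case MNil => acc.reverse
  case c: MCons => toList(c.tail, c.head :: acc)
```

The direct version uses one stack frame per element, while the mod-cons version runs in constant stack space because the only recursive call is in tail position.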
@@ -843,6 +902,8 @@ class TailRecOpt(fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diag
// To build the case block, we need to compare integers and check if the result is "True"
val trueClass = classes.find(c => c.ident == "True").get
val falseClass = classes.find(c => c.ident == "False").get
// undefined for dummy values
val dummyVal = Expr.Literal(UnitLit(true))

// join points need to be rewritten. For now, just combine them into the rest of the function. They will be inlined anyways
val defns = component.nodes ++ component.joinPoints
@@ -913,10 +974,10 @@ class TailRecOpt(fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diag

val stackFrame = trName :: defnsList.flatMap(d => d.params.map(n => nameMaps(d.id)(n))) // take union of stack frames

// TODO: This works fine for now, but ideally should find a way to guarantee the new
// name is unique
val newName = defns.foldLeft("")(_ + "_" + _.name) + "_opt"
val jpName = defns.foldLeft("")(_ + "_" + _.name) + "_opt_jp"
val newId = fnUid.make
val newName = defns.foldLeft("")(_ + "_" + _.name) + "_opt$" + newId
val jpId = fnUid.make
val jpName = defns.foldLeft("")(_ + "_" + _.name) + "_opt_jp$" + jpId

val newDefnRef = DefnRef(Right(newName))
val jpDefnRef = DefnRef(Right(jpName))
@@ -941,15 +1002,16 @@ class TailRecOpt(fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diag
Jump(jpDefnRef, transformStackFrame(args, defnInfoMap(defn.expectDefn.id))).attachTag(tag)
else LetCall(names, defn, args, isTailRec, transformNode(body))().attachTag(tag)

// Tail calls to another function in the component will be replaced with a tail call
// Tail calls to another function in the component will be replaced with a call
// to the merged function
// i.e. for mutually tailrec functions f(a, b) and g(c, d),
// f's body will be replaced with a call f_g(a, b, *, *), where * is a dummy value
def transformDefn(defn: Defn): Defn =
// TODO: Figure out how to substitute variables with dummy variables.
val info = defnInfoMap(defn.id)

val start =
stackFrame.take(info.stackFrameIdx).drop(1).map { _ => Expr.Literal(IntLit(0)) } // we drop tailrecBranch and replace it with the defn id
val end = stackFrame.drop(info.stackFrameIdx + defn.params.size).map { _ => Expr.Literal(IntLit(0)) }
stackFrame.take(info.stackFrameIdx).drop(1).map { _ => dummyVal } // we drop tailrecBranch and replace it with the defn id
val end = stackFrame.drop(info.stackFrameIdx + defn.params.size).map { _ => dummyVal }
val args = asLit(info.defn.id) :: start ::: defn.params.map(Expr.Ref(_)) ::: end
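
The argument list built here places each function's own parameters in its slice of the merged stack frame and pads every other slot with a dummy value. A self-contained sketch of the same slicing, assuming a simplified model in which names are strings and `Arg`, `FnInfo` and `mergedCallArgs` are hypothetical stand-ins for the IR types:

```scala
// Hypothetical simplified arguments: a parameter reference, an integer literal, or a dummy value.
enum Arg:
  case Ref(name: String)
  case IntLit(value: Int)
  case Dummy

// `stackFrameIdx` is the position of the function's first parameter in the merged frame,
// where slot 0 holds the branch selector.
final case class FnInfo(id: Int, params: List[String], stackFrameIdx: Int)

def mergedCallArgs(stackFrame: List[String], info: FnInfo): List[Arg] =
  // Slots before this function's parameters, minus the branch selector at index 0.
  val start = stackFrame.take(info.stackFrameIdx).drop(1).map(_ => Arg.Dummy)
  // Slots after this function's parameters.
  val end = stackFrame.drop(info.stackFrameIdx + info.params.size).map(_ => Arg.Dummy)
  // The function id replaces the branch selector.
  Arg.IntLit(info.id) :: start ::: info.params.map(Arg.Ref(_)) ::: end
```

For mutually recursive `f(a, b)` and `g(c, d)`, the call built for `f` carries `f`'s id, then `a`, `b` and two dummies, while the call for `g` carries `g`'s id, two dummies, then `c` and `d`.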

// We use a let call instead of a jump to avoid newDefn from being turned into a join point,
@@ -978,10 +1040,10 @@ class TailRecOpt(fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diag

val newNode = makeSwitch(trName, valsAndNodes.tail, valsAndNodes.head._2)(trueClass, falseClass)

val jpDefn = Defn(fnUid.make, jpName, stackFrame, resultNum, newNode, false)
val jpDefn = Defn(jpId, jpName, stackFrame, resultNum, newNode, false)

val jmp = Jump(jpDefnRef, stackFrame.map(Expr.Ref(_))).attachTag(tag)
val newDefn = Defn(fnUid.make, newName, stackFrame, resultNum, jmp, defnsNoJp.find { _.isTailRec }.isDefined )
val newDefn = Defn(newId, newName, stackFrame, resultNum, jmp, defnsNoJp.find { _.isTailRec }.isDefined )

jpDefnRef.defn = Left(jpDefn)
newDefnRef.defn = Left(newDefn)
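
The overall effect of this hunk is that a component of mutually tail-recursive functions becomes a single definition whose first parameter selects which original body to run. A hand-written analogue in plain Scala, offered only as an illustration of the idea (the names are hypothetical and the real pass works on the IR with join points):

```scala
import scala.annotation.tailrec

// Original mutually recursive pair: every call is a tail call, but each still uses a stack frame.
def isEven(n: Int): Boolean = if n == 0 then true else isOdd(n - 1)
def isOdd(n: Int): Boolean = if n == 0 then false else isEven(n - 1)

// Merged function: `branch` selects the original body, and the remaining parameters are the
// union of both parameter lists. Slots that the selected body does not use receive a dummy 0.
@tailrec
def isEvenIsOddOpt(branch: Int, evenN: Int, oddN: Int): Boolean =
  if branch == 0 then
    if evenN == 0 then true else isEvenIsOddOpt(1, 0, evenN - 1)
  else
    if oddN == 0 then false else isEvenIsOddOpt(0, oddN - 1, 0)

// The original entry points keep their signatures and call the merged function.
def isEvenOpt(n: Int): Boolean = isEvenIsOddOpt(0, n, 0)
def isOddOpt(n: Int): Boolean = isEvenIsOddOpt(1, 0, n)
```

Since the merged function only calls itself in tail position, it compiles to a loop, so `isEvenOpt(1000000)` runs in constant stack space.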
26 changes: 12 additions & 14 deletions compiler/shared/test/diff-ir/IR.mls
@@ -17,7 +17,7 @@ foo()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, Pair, [x,y])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, Pair, [x,y])}, {
//│ Def(0, mktup2, [x$0,y$0],
//│ 1,
//│ let* (x$1) = mktup(x$0,y$0) in -- #7
@@ -55,7 +55,7 @@ bar()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, Pair, [x,y])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, Pair, [x,y])}, {
//│ Def(0, foo, [pair$0],
//│ 1,
//│ case pair$0 of -- #16
@@ -105,7 +105,7 @@ foo()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, Pair, [x,y])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, Pair, [x,y])}, {
//│ Def(0, silly, [pair$0],
//│ 1,
//│ let x$0 = 0 in -- #29
@@ -163,7 +163,7 @@ foo()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, Pair, [x,y])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, Pair, [x,y])}, {
//│ Def(0, inc_fst, [pair$0],
//│ 1,
//│ let x$0 = 2 in -- #15
@@ -208,7 +208,7 @@ foo()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, Pair, [x,y])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, Pair, [x,y])}, {
//│ Def(0, inc_fst, [pair$0],
//│ 1,
//│ let x$0 = 0 in -- #15
@@ -256,7 +256,7 @@ bar()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, Left, [x]),ClassInfo(1, Right, [y])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, Left, [x]),ClassInfo(3, Right, [y])}, {
//│ Def(0, foo, [a$0,b$0],
//│ 1,
//│ case a$0 of -- #36
@@ -298,15 +298,13 @@ bar()
//│ 2

:interpIR
class True
class False
class Pair(x, y)
fun foo(a) = a.x + a.y
fun bar() =
foo(Pair(1, 0))
bar()
//│ |#class| |True|↵|#class| |False|↵|#class| |Pair|(|x|,| |y|)|↵|#fun| |foo|(|a|)| |#=| |a|.x| |+| |a|.y|↵|#fun| |bar|(||)| |#=|→|foo|(|Pair|(|1|,| |0|)|)|←|↵|bar|(||)|
//│ Parsed: {class True {}; class False {}; class Pair(x, y,) {}; fun foo = (a,) => +((a).x,)((a).y,); fun bar = () => {foo(Pair(1, 0,),)}; bar()}
//│ |#class| |Pair|(|x|,| |y|)|↵|#fun| |foo|(|a|)| |#=| |a|.x| |+| |a|.y|↵|#fun| |bar|(||)| |#=|→|foo|(|Pair|(|1|,| |0|)|)|←|↵|bar|(||)|
//│ Parsed: {class Pair(x, y,) {}; fun foo = (a,) => +((a).x,)((a).y,); fun bar = () => {foo(Pair(1, 0,),)}; bar()}
//│
//│
//│ IR:
@@ -350,7 +348,7 @@ bar()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, C1, [x,y]),ClassInfo(1, C2, [z])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, C1, [x,y]),ClassInfo(3, C2, [z])}, {
//│ Def(0, foo, [a$0],
//│ 1,
//│ case a$0 of -- #15
@@ -401,7 +399,7 @@ baz()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, Pair, [x,y])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, Pair, [x,y])}, {
//│ Def(0, foo, [a$0,b$0],
//│ 1,
//│ let x$0 = a$0.x in -- #21
@@ -451,7 +449,7 @@ foo()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, Pair, [x,y])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, Pair, [x,y])}, {
//│ Def(0, foo, [],
//│ 1,
//│ let x$0 = Pair(0,1) in -- #10
@@ -484,7 +482,7 @@ foo()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, S, [s]),ClassInfo(1, O, [])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, S, [s]),ClassInfo(3, O, [])}, {
//│ Def(0, foo, [],
//│ 1,
//│ let x$0 = O() in -- #10
6 changes: 3 additions & 3 deletions compiler/shared/test/diff-ir/IRComplex.mls
@@ -26,7 +26,7 @@ bar()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, A, [x,y,z]),ClassInfo(1, B, [m,n])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, A, [x,y,z]),ClassInfo(3, B, [m,n])}, {
//│ Def(0, complex_foo, [t$0],
//│ 1,
//│ case t$0 of -- #63
@@ -117,7 +117,7 @@ bar()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, A, [w,x]),ClassInfo(1, B, [y]),ClassInfo(2, C, [z])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, A, [w,x]),ClassInfo(3, B, [y]),ClassInfo(4, C, [z])}, {
//│ Def(0, complex_foo, [t$0],
//│ 1,
//│ let x$0 = +(1,2) in -- #140
@@ -256,7 +256,7 @@ bar()
//│ IR:
//│
//│ Promoted:
//│ Program({ClassInfo(0, A, [w,x]),ClassInfo(1, B, [y]),ClassInfo(2, C, [z])}, {
//│ Program({ClassInfo(0, True, []),ClassInfo(1, False, []),ClassInfo(2, A, [w,x]),ClassInfo(3, B, [y]),ClassInfo(4, C, [z])}, {
//│ Def(0, complex_foo, [t$0],
//│ 1,
//│ let x$0 = +(1,2) in -- #150