Skip to content

Tail Recursion Optimization #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 62 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
8b73a24
move changes from tailrec-opt to new branch
CAG2Mark Mar 13, 2024
93bf78f
Optimize strongly connected components
CAG2Mark Mar 23, 2024
9b65640
update map braces
CAG2Mark Mar 24, 2024
915891d
small refactor
CAG2Mark Mar 24, 2024
6f2ca3d
refactor
CAG2Mark Mar 24, 2024
f987e02
Update test infrastructure, fix code, add basic test for tailrec.
CAG2Mark Mar 29, 2024
184a5de
Update test
CAG2Mark Mar 29, 2024
3066c7f
remove todos
CAG2Mark Mar 29, 2024
d7b4bb3
Add field assignment to IR
CAG2Mark Apr 8, 2024
058a929
Prevent unnecessary inlining for mutually tail recursive funcs, handl…
CAG2Mark Apr 14, 2024
0e3a875
Update test
CAG2Mark Apr 14, 2024
dd50c7a
Improved tests
CAG2Mark Apr 14, 2024
fc282e7
Fix bugs and cases
CAG2Mark Apr 14, 2024
edf7c89
Interpret field assignment
CAG2Mark Apr 20, 2024
c08f31b
add class info to field assignment
CAG2Mark Apr 25, 2024
518ffb4
propagate tailrec, fix tailrec parsing issue
CAG2Mark Apr 27, 2024
d00d072
progress, fix bug
CAG2Mark Apr 29, 2024
e45412b
Detect mod cons tail calls
CAG2Mark Apr 29, 2024
fc473a4
Refactor tail call discovery
CAG2Mark Apr 29, 2024
22dd518
change test
CAG2Mark Apr 29, 2024
c1d90d9
add test, verify mod cons call discovery works
CAG2Mark Apr 29, 2024
16134c2
add tests, improve formatting
CAG2Mark Apr 29, 2024
e9d6c54
remove newlines
CAG2Mark Apr 29, 2024
dcd616b
add tostring, improve formatting
CAG2Mark Apr 30, 2024
8b91f91
remove println
CAG2Mark Apr 30, 2024
8d6e14f
actually handle single tail recursive
CAG2Mark Apr 30, 2024
6213d51
properly attach tags
CAG2Mark Apr 30, 2024
e778810
progress, handle join points properly
CAG2Mark May 1, 2024
934c324
fix test
CAG2Mark May 1, 2024
b27cf3a
Done
CAG2Mark May 1, 2024
31bd2f5
re-enable normal tailrec optimization, clean up code
CAG2Mark May 1, 2024
0e64e67
Merge branch 'mlscript' into tailrec
CAG2Mark May 1, 2024
715ef73
Update tests
CAG2Mark May 1, 2024
d8d05be
fix some tests
CAG2Mark May 12, 2024
5887014
Propagate tailrec, fix join point infinite recursion bug
CAG2Mark May 12, 2024
78da910
refactor and check @tailrec for function definitions
CAG2Mark May 12, 2024
42b264f
update
CAG2Mark May 12, 2024
2199e37
make AssiggnField an Expr instead of a Node
CAG2Mark Jun 1, 2024
eacf538
Fix unsafe partial destruction
CAG2Mark Jun 1, 2024
214f9d3
rename @tailrec to @tailcall for call-level annotations
CAG2Mark Jun 1, 2024
51af35f
Propagate positions of calls and @tailrec annotations
CAG2Mark Jun 1, 2024
8ce1b35
Add error reporting to IR diff tests, report tailrec IR errors, fix bug
CAG2Mark Jun 1, 2024
ee15969
fix tests
CAG2Mark Jun 1, 2024
5af3621
fix tests
CAG2Mark Jun 1, 2024
487279e
Fix grammar
CAG2Mark Jun 1, 2024
a4efc7a
Purity check, fix join point bug
CAG2Mark Jun 5, 2024
0c8dec1
Build in true/false class in IR
CAG2Mark Jun 5, 2024
1ea98f1
restore nuscratch changes
CAG2Mark Jun 5, 2024
abcadfe
Ensure no function name clashes
CAG2Mark Jun 5, 2024
ea8d442
Document, add undefined literal and address remaining todos
CAG2Mark Jun 5, 2024
0125b1e
use unitlit instead of a new literal type for undefined
CAG2Mark Jun 5, 2024
dd8f2ae
Document @tailcall and @tailrec
CAG2Mark Jun 10, 2024
398bd9e
properly raise errors in IR builder
CAG2Mark Jun 17, 2024
4f93b76
fix some documentation
CAG2Mark Jun 17, 2024
f1b5fcb
Merge master
CAG2Mark Jun 17, 2024
f479ecc
remove bad copy paste
CAG2Mark Jun 17, 2024
49a4993
fix typo
CAG2Mark Jun 17, 2024
875484c
Improve todo comment
CAG2Mark Jun 17, 2024
92039ea
respond to code review
CAG2Mark Jun 18, 2024
22d4f47
Improve test format
CAG2Mark Jun 18, 2024
263b82a
Merge master
CAG2Mark Jun 18, 2024
d0923c3
re-run tests to fix line numbers
CAG2Mark Jun 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,33 @@ import mlscript.IntLit
import mlscript.utils.shorthands.Bool

// fnUid should be the same FreshInt that was used to build the graph being passed into this class
class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
class TailRecOpt(fnUid: FreshInt, tag: FreshInt):
case class LetCtorNodeInfo(node: LetExpr, ctor: Expr.CtorApp, ctorValName: Name, fieldName: String)

enum CallInfo:
case TailCallInfo(src: Defn, defn: Defn, letCallNode: LetCall) extends CallInfo
case ModConsCallInfo(src: Defn, defn: Defn, letCallNode: LetCall, letCtorNode: LetCtorNodeInfo) extends CallInfo
case ModConsCallInfo(src: Defn, startNode: Node, defn: Defn, letCallNode: LetCall, letCtorNode: LetCtorNodeInfo) extends CallInfo

def getSrc = this match
case TailCallInfo(src, _, _) => src
case ModConsCallInfo(src, _, _, _) => src
case ModConsCallInfo(src, _, _, _, _) => src

def getDefn = this match
case TailCallInfo(_, defn, _) => defn
case ModConsCallInfo(_, defn, _, _) => defn
case ModConsCallInfo(_, _, defn, _, _) => defn


private class DefnGraph(val nodes: Set[DefnNode], val edges: Set[CallInfo]) {
private class DefnGraph(val nodes: Set[DefnNode], val edges: Set[CallInfo]):
def removeMetadata: ScComponent = ScComponent(nodes.map(_.defn), edges)
}


private class ScComponent(val nodes: Set[Defn], val edges: Set[CallInfo])

import CallInfo._

@tailrec
private def getOptimizableCalls(node: Node)(implicit
src: Defn,
start: Node,
calledDefn: Option[Defn],
letCallNode: Option[LetCall],
letCtorNode: Option[LetCtorNodeInfo],
Expand All @@ -48,7 +48,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
(calledDefn, letCallNode, letCtorNode, candReturnName) match
case (Some(defn), Some(letCallNode), Some(letCtorName), Some(candReturnName)) =>
if argsListEqual(List(candReturnName), res) then
Left(ModConsCallInfo(src, defn, letCallNode, letCtorName))
Left(ModConsCallInfo(src, start, defn, letCallNode, letCtorName))
else
returnFailure
case _ => returnFailure
Expand All @@ -57,7 +57,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
(calledDefn, letCallNode, letCtorNode, candReturnName) match
case (Some(defn), Some(letCallNode), Some(letCtorName), Some(candReturnName)) =>
if argsListEqual(List(candReturnName), args) && isIdentityJp(jp.expectDefn) then
Left(ModConsCallInfo(src, defn, letCallNode, letCtorName))
Left(ModConsCallInfo(src, start, defn, letCallNode, letCtorName))
else
returnFailure
case _ => returnFailure
Expand All @@ -76,7 +76,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
// if the is marked as tail recursive, we must use that call as the mod cons call, so error. otherwise,
// invalidate the discovered call and continue
if isTailRec then throw IRError("not a mod cons call")
else getOptimizableCalls(body)(src, None, None, None, None) // invalidate everything that's been discovered
else getOptimizableCalls(body)(src, start, None, None, None, None) // invalidate everything that's been discovered
else
getOptimizableCalls(body) // OK

Expand All @@ -97,7 +97,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
// Now check if the constructor uses the previous ctor.
candReturnName match
case None => getOptimizableCalls(body) // no previous ctor, just continue
case Some(value) => getOptimizableCalls(body)(src, calledDefn, letCallNode, letCtorNode, Some(name))
case Some(value) => getOptimizableCalls(body)(src, start, calledDefn, letCallNode, letCtorNode, Some(name))
else
// it does use it, further analyse
letCtorNode match
Expand All @@ -116,14 +116,14 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
val fieldName = clsInfo.fields(ctorArgIndex)

// populate required values
getOptimizableCalls(body)(src, calledDefn, letCallNode, Some(LetCtorNodeInfo(x, y, name, fieldName)), Some(name))
getOptimizableCalls(body)(src, start, calledDefn, letCallNode, Some(LetCtorNodeInfo(x, y, name, fieldName)), Some(name))
case Some(_) =>
// another constructor is already using the call. Not OK

// if the is marked as tail recursive, we must use that call as the mod cons call, so error. otherwise,
// invalidate the discovered call and continue
if isTailRec then throw IRError("not a mod cons call")
else getOptimizableCalls(body)(src, None, None, None, None) // invalidate everything that's been discovered
else getOptimizableCalls(body)(src, start, None, None, None, None) // invalidate everything that's been discovered

case Expr.Select(name, cls, field) =>
letCallNode match
Expand All @@ -134,7 +134,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
// if the is marked as tail recursive, we must use that call as the mod cons call, so error. otherwise,
// invalidate the discovered call and continue
if isTailRec then throw IRError("not a mod cons call")
else getOptimizableCalls(body)(src, None, None, None, None) // invalidate everything that's been discovered
else getOptimizableCalls(body)(src, start, None, None, None, None) // invalidate everything that's been discovered
else
getOptimizableCalls(body) // OK
case Expr.BasicOp(name, args) =>
Expand All @@ -152,7 +152,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
// if the is marked as tail recursive, we must use that call as the mod cons call, so error. otherwise,
// invalidate the discovered call and continue
if isTailRec then throw IRError("not a mod cons call")
else getOptimizableCalls(body)(src, None, None, None, None) // invalidate everything that's been discovered
else getOptimizableCalls(body)(src, start, None, None, None, None) // invalidate everything that's been discovered
case x: LetCall =>
val LetCall(names, defn, args, body, isTailRec) = x

Expand All @@ -161,7 +161,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
else
letCallNode match
case None => // OK, use this LetCall as the mod cons
getOptimizableCalls(body)(src, Some(defn.expectDefn), Some(x), None, None)
getOptimizableCalls(body)(src, start, Some(defn.expectDefn), Some(x), None, None)
case Some(LetCall(namesOld, defnOld, argsOld, bodyOld, isTailRecOld)) =>
if isTailRecOld && isTailRec then
// 1. If both the old and newly discovered call are marked with tailrec, error
Expand All @@ -179,7 +179,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
// old call is not tailrec, so we can override it however we want
// we take a lucky guess and mark this as the mod cons call, but the
// user really should mark which calls should be tailrec
getOptimizableCalls(body)(src, Some(defn.expectDefn), Some(x), None, None)
getOptimizableCalls(body)(src, start, Some(defn.expectDefn), Some(x), None, None)

case AssignField(assignee, clsInfo, assignmentFieldName, value, body) =>
// make sure `value` is not the mod cons call
Expand All @@ -189,14 +189,14 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
value match
case Expr.Ref(name) =>
if names.contains(name) && isTailRec then throw IRError("not a mod cons call")
else getOptimizableCalls(body)(src, None, None, None, None) // invalidate everything that's been discovered
else getOptimizableCalls(body)(src, start, None, None, None, None) // invalidate everything that's been discovered
case _ =>
letCtorNode match
case None => getOptimizableCalls(body) // OK
case Some(LetCtorNodeInfo(_, ctor, name, fieldName)) =>
// If this assignment overwrites the mod cons value, forget it
if fieldName == assignmentFieldName && isTailRec then throw IRError("not a mod cons call")
else getOptimizableCalls(body)(src, None, None, None, None) // invalidate everything that's been discovered
else getOptimizableCalls(body)(src, start, None, None, None, None) // invalidate everything that's been discovered

// checks whether a list of names is equal to a list of trivial expressions referencing those names
private def argsListEqual(names: List[Name], exprs: List[TrivialExpr]) =
Expand All @@ -217,7 +217,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
case _ => false

private def findTailCalls(node: Node)(implicit src: Defn): Set[CallInfo] =
getOptimizableCalls(node)(src, None, None, None, None) match
getOptimizableCalls(node)(src, node, None, None, None, None) match
case Left(callInfo) => Set(callInfo)
case Right(nodes) => nodes.foldLeft(Set())((calls, node) => calls ++ findTailCalls(node))

Expand All @@ -228,23 +228,22 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
// Wikipedia: https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
// Implementation Reference: https://www.baeldung.com/cs/scc-tarjans-algorithm

private class DefnNode(val defn: Defn) {
private class DefnNode(val defn: Defn):
override def hashCode(): Int = defn.hashCode

var num: Int = Int.MaxValue
var lowest: Int = Int.MaxValue
var visited: Boolean = false
var processed: Boolean = false
}

private def partitionNodes(implicit nodeMap: Map[Int, DefnNode]): List[DefnGraph] = {
private def partitionNodes(implicit nodeMap: Map[Int, DefnNode]): List[DefnGraph] =
val defns = nodeMap.values.toSet

var ctr = 0
var stack: List[(DefnNode, Set[CallInfo])] = Nil
var sccs: List[DefnGraph] = Nil

def dfs(src: DefnNode): Unit = {
def dfs(src: DefnNode): Unit =
src.num = ctr
src.lowest = ctr
ctr += 1
Expand All @@ -264,15 +263,15 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {

src.processed = true

if (src.num == src.lowest) {
if (src.num == src.lowest) then
var scc: Set[DefnNode] = Set()
var sccEdges: Set[CallInfo] = Set()

def pop(): (DefnNode, Set[CallInfo]) = {
def pop(): (DefnNode, Set[CallInfo]) =
val ret = stack.head
stack = stack.tail
ret
}


var (vertex, edges) = pop()

Expand All @@ -290,26 +289,31 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
sccEdges = sccEdges.filter { c => sccIds.contains(c.getDefn.id)}

sccs = DefnGraph(scc, sccEdges) :: sccs
}
}



for (v <- defns) {
for (v <- defns)
if (!v.visited)
dfs(v)
}


sccs
}


private case class DefnInfo(defn: Defn, stackFrameIdx: Int)

// Given a strongly connected component `defns`,
// returns a set containing the optimized function and the
// original functions pointing to an optimized function.
private def optimize(component: ScComponent, classes: Set[ClassInfo]): Set[Defn] = {
def asLit(x: Int) = Expr.Literal(IntLit(x))

// Given a strongly connected component `defns` of mutually mod cons functions,
// returns a set containing mutually tail recursive versions of them and
// the original functions pointing to the optimized ones.
private def optimizeModCons(component: ScComponent, classes: Set[ClassInfo]): Set[Defn] = ???

def asLit(x: Int) = Expr.Literal(IntLit(x))

// Given a strongly connected component `defns` of mutually
// tail recursive functions, returns a set containing the optimized function and the
// original functions pointing to an optimized function.
private def optimizeTailRec(component: ScComponent, classes: Set[ClassInfo]): Set[Defn] =

// To build the case block, we need to compare integers and check if the result is "True"
val trueClass = classes.find(c => c.ident == "True").get
Expand Down Expand Up @@ -368,7 +372,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {

// Tail calls to another function in the component will be replaced with a tail call
// to the merged function
def transformDefn(defn: Defn): Defn = {
def transformDefn(defn: Defn): Defn =
// TODO: Figure out how to substitute variables with dummy variables.
val info = defnInfoMap(defn.id)

Expand All @@ -384,7 +388,7 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
val res = Result(namesExpr).attachTag(tag)
val call = LetCall(names, newDefnRef, args, res, false).attachTag(tag)
Defn(defn.id, defn.name, defn.params, defn.resultNum, call)
}


// given expressions value, e1, e2, transform it into
// let scrut = tailrecBranch == value
Expand Down Expand Up @@ -429,28 +433,24 @@ class TailRecOpt(fnUid: FreshInt, tag: FreshInt) {
newDefnRef.defn = Left(newDefn)

defns.map { d => transformDefn(d) } + newDefn + jpDefn
}


private def partition(defns: Set[Defn]): List[ScComponent] = {
private def partition(defns: Set[Defn]): List[ScComponent] =
val nodeMap: Map[Int, DefnNode] = defns.foldLeft(Map.empty)((m, d) => m + (d.id -> DefnNode(d)))
partitionNodes(nodeMap).map(_.removeMetadata)
}


def apply(p: Program) = run(p)

def run_debug(p: Program): (Program, List[Set[String]]) = {
def run_debug(p: Program): (Program, List[Set[String]]) =
// val rewritten = p.defs.map(d => Defn(d.id, d.name, d.params, d.resultNum, rewriteTailCalls(d.body)))
val partitions = partition(p.defs)
val newDefs: Set[Defn] = partitions.flatMap { optimize(_, p.classes) }.toSet
val newDefs: Set[Defn] = partitions.flatMap { optimizeTailRec(_, p.classes) }.toSet

// update the definition refs
newDefs.foreach { defn => resolveDefnRef(defn.body, newDefs, true) }
resolveDefnRef(p.main, newDefs, true)

(Program(p.classes, newDefs, p.main), partitions.map(t => t.nodes.map(f => f.name)))
}

def run(p: Program): Program = {
run_debug(p)._1
}
}
def run(p: Program): Program = run_debug(p)._1
10 changes: 5 additions & 5 deletions compiler/shared/test/diff-ir/IRTailRec.mls
Original file line number Diff line number Diff line change
Expand Up @@ -467,14 +467,14 @@ hello()
//│ mlscript.compiler.optimizer.TailRecOpt.returnFailure$1(TailRecOpt.scala:43)
//│ mlscript.compiler.optimizer.TailRecOpt.getOptimizableCalls(TailRecOpt.scala:54)
//│ mlscript.compiler.optimizer.TailRecOpt.findTailCalls(TailRecOpt.scala:220)
//│ mlscript.compiler.optimizer.TailRecOpt.dfs$1(TailRecOpt.scala:253)
//│ mlscript.compiler.optimizer.TailRecOpt.partitionNodes$$anonfun$1(TailRecOpt.scala:298)
//│ mlscript.compiler.optimizer.TailRecOpt.dfs$1(TailRecOpt.scala:252)
//│ mlscript.compiler.optimizer.TailRecOpt.partitionNodes$$anonfun$1(TailRecOpt.scala:297)
//│ scala.runtime.function.JProcedure1.apply(JProcedure1.java:15)
//│ scala.runtime.function.JProcedure1.apply(JProcedure1.java:10)
//│ scala.collection.immutable.Set$Set1.foreach(Set.scala:168)
//│ mlscript.compiler.optimizer.TailRecOpt.partitionNodes(TailRecOpt.scala:299)
//│ mlscript.compiler.optimizer.TailRecOpt.partition(TailRecOpt.scala:436)
//│ mlscript.compiler.optimizer.TailRecOpt.run_debug(TailRecOpt.scala:443)
//│ mlscript.compiler.optimizer.TailRecOpt.partitionNodes(TailRecOpt.scala:297)
//│ mlscript.compiler.optimizer.TailRecOpt.partition(TailRecOpt.scala:440)
//│ mlscript.compiler.optimizer.TailRecOpt.run_debug(TailRecOpt.scala:447)
//│ mlscript.compiler.IRDiffTestCompiler.postProcess(TestIR.scala:28)
//│ mlscript.DiffTests.rec$1(DiffTests.scala:470)
//│ mlscript.DiffTests.$anonfun$new$3(DiffTests.scala:1076)
Expand Down