Skip to content

Tail Recursion Optimization #218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 62 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from 58 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
8b73a24
move changes from tailrec-opt to new branch
CAG2Mark Mar 13, 2024
93bf78f
Optimize strongly connected components
CAG2Mark Mar 23, 2024
9b65640
update map braces
CAG2Mark Mar 24, 2024
915891d
small refactor
CAG2Mark Mar 24, 2024
6f2ca3d
refactor
CAG2Mark Mar 24, 2024
f987e02
Update test infrastructure, fix code, add basic test for tailrec.
CAG2Mark Mar 29, 2024
184a5de
Update test
CAG2Mark Mar 29, 2024
3066c7f
remove todos
CAG2Mark Mar 29, 2024
d7b4bb3
Add field assignment to IR
CAG2Mark Apr 8, 2024
058a929
Prevent unnecessary inlining for mutually tail recursive funcs, handl…
CAG2Mark Apr 14, 2024
0e3a875
Update test
CAG2Mark Apr 14, 2024
dd50c7a
Improved tests
CAG2Mark Apr 14, 2024
fc282e7
Fix bugs and cases
CAG2Mark Apr 14, 2024
edf7c89
Interpret field assignment
CAG2Mark Apr 20, 2024
c08f31b
add class info to field assignment
CAG2Mark Apr 25, 2024
518ffb4
propagate tailrec, fix tailrec parsing issue
CAG2Mark Apr 27, 2024
d00d072
progress, fix bug
CAG2Mark Apr 29, 2024
e45412b
Detect mod cons tail calls
CAG2Mark Apr 29, 2024
fc473a4
Refactor tail call discovery
CAG2Mark Apr 29, 2024
22dd518
change test
CAG2Mark Apr 29, 2024
c1d90d9
add test, verify mod cons call discovery works
CAG2Mark Apr 29, 2024
16134c2
add tests, improve formatting
CAG2Mark Apr 29, 2024
e9d6c54
remove newlines
CAG2Mark Apr 29, 2024
dcd616b
add tostring, improve formatting
CAG2Mark Apr 30, 2024
8b91f91
remove println
CAG2Mark Apr 30, 2024
8d6e14f
actually handle single tail recursive
CAG2Mark Apr 30, 2024
6213d51
properly attach tags
CAG2Mark Apr 30, 2024
e778810
progress, handle join points properly
CAG2Mark May 1, 2024
934c324
fix test
CAG2Mark May 1, 2024
b27cf3a
Done
CAG2Mark May 1, 2024
31bd2f5
re-enable normal tailrec optimization, clean up code
CAG2Mark May 1, 2024
0e64e67
Merge branch 'mlscript' into tailrec
CAG2Mark May 1, 2024
715ef73
Update tests
CAG2Mark May 1, 2024
d8d05be
fix some tests
CAG2Mark May 12, 2024
5887014
Propagate tailrec, fix join point infinite recursion bug
CAG2Mark May 12, 2024
78da910
refactor and check @tailrec for function definitions
CAG2Mark May 12, 2024
42b264f
update
CAG2Mark May 12, 2024
2199e37
make AssiggnField an Expr instead of a Node
CAG2Mark Jun 1, 2024
eacf538
Fix unsafe partial destruction
CAG2Mark Jun 1, 2024
214f9d3
rename @tailrec to @tailcall for call-level annotations
CAG2Mark Jun 1, 2024
51af35f
Propagate positions of calls and @tailrec annotations
CAG2Mark Jun 1, 2024
8ce1b35
Add error reporting to IR diff tests, report tailrec IR errors, fix bug
CAG2Mark Jun 1, 2024
ee15969
fix tests
CAG2Mark Jun 1, 2024
5af3621
fix tests
CAG2Mark Jun 1, 2024
487279e
Fix grammar
CAG2Mark Jun 1, 2024
a4efc7a
Purity check, fix join point bug
CAG2Mark Jun 5, 2024
0c8dec1
Build in true/false class in IR
CAG2Mark Jun 5, 2024
1ea98f1
restore nuscratch changes
CAG2Mark Jun 5, 2024
abcadfe
Ensure no function name clashes
CAG2Mark Jun 5, 2024
ea8d442
Document, add undefined literal and address remaining todos
CAG2Mark Jun 5, 2024
0125b1e
use unitlit instead of a new literal type for undefined
CAG2Mark Jun 5, 2024
dd8f2ae
Document @tailcall and @tailrec
CAG2Mark Jun 10, 2024
398bd9e
properly raise errors in IR builder
CAG2Mark Jun 17, 2024
4f93b76
fix some documentation
CAG2Mark Jun 17, 2024
f1b5fcb
Merge master
CAG2Mark Jun 17, 2024
f479ecc
remove bad copy paste
CAG2Mark Jun 17, 2024
49a4993
fix typo
CAG2Mark Jun 17, 2024
875484c
Improve todo comment
CAG2Mark Jun 17, 2024
92039ea
respond to code review
CAG2Mark Jun 18, 2024
22d4f47
Improve test format
CAG2Mark Jun 18, 2024
263b82a
Merge master
CAG2Mark Jun 18, 2024
d0923c3
re-run tests to fix line numbers
CAG2Mark Jun 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 74 additions & 32 deletions compiler/shared/main/scala/mlscript/compiler/ir/Builder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ import mlscript.compiler.optimizer.FreeVarAnalysis
import mlscript.utils.shorthands._
import mlscript.utils._
import mlscript._
import mlscript.Message._
import collection.mutable.ListBuffer

final val ops = Set("+", "-", "*", "/", ">", "<", ">=", "<=", "!=", "==")

final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: FreshInt):
final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: FreshInt, raise: Diagnostic => Unit):
import Node._
import Expr._

Expand Down Expand Up @@ -72,6 +73,33 @@ final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: Fres
tm

private def buildResultFromTerm(using ctx: Ctx)(tm: Term)(k: Node => Node): Node =
def buildLetCall(f: Term, xs: Tup, ann: Option[Term]) =
buildResultFromTerm(f) { node => node match
case Result(Ref(g) :: Nil) if ctx.fnCtx.contains(g.str) => buildResultFromTerm(xs) {
case Result(args) =>
val v = fresh.make

ann match
case Some(ann @ Var(nme)) =>
if nme == "tailcall" then
LetCall(List(v), DefnRef(Right(g.str)), args, true, v |> ref |> sresult |> k)(f.toLoc).attachTag(tag)
else
if nme == "tailrec" then
raise(ErrorReport(List(msg"@tailrec is for annotating functions; try @tailcall instead" -> ann.toLoc), true, Diagnostic.Compilation))
LetCall(List(v), DefnRef(Right(g.str)), args, false, v |> ref |> sresult |> k)(f.toLoc).attachTag(tag)
case Some(_) => node |> unexpectedNode
case None => LetCall(List(v), DefnRef(Right(g.str)), args, false, v |> ref |> sresult |> k)(f.toLoc).attachTag(tag)

case node @ _ => node |> unexpectedNode
}
case Result(Ref(f) :: Nil) => buildResultFromTerm(xs) {
case Result(args) =>
throw IRError(s"not supported: apply")
case node @ _ => node |> unexpectedNode
}
case node @ _ => node |> unexpectedNode
}

val res = tm match
case lit: Lit => Literal(lit) |> sresult |> k
case v @ Var(name) =>
Expand Down Expand Up @@ -114,21 +142,16 @@ final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: Fres
case node @ _ => node |> unexpectedNode
}

case App(f, xs @ Tup(_)) =>
buildResultFromTerm(f) {
case Result(Ref(f) :: Nil) if ctx.fnCtx.contains(f.str) => buildResultFromTerm(xs) {
case Result(args) =>
val v = fresh.make
LetCall(List(v), DefnRef(Right(f.str)), args, v |> ref |> sresult |> k).attachTag(tag)
case node @ _ => node |> unexpectedNode
}
case Result(Ref(f) :: Nil) => buildResultFromTerm(xs) {
case Result(args) =>
throw IRError(s"not supported: apply")
case node @ _ => node |> unexpectedNode
}
case node @ _ => node |> unexpectedNode
}
case App(f, xs @ Tup(_)) => buildLetCall(f, xs, None)
case Ann(ann, App(f, xs @ Tup(_))) => buildLetCall(f, xs, Some(ann))

case Ann(ann @ Var(name), recv) =>
if name == "tailcall" then
raise(ErrorReport(List(msg"@tailcall may only be used to annotate function calls" -> ann.toLoc), true, Diagnostic.Compilation))
else if name == "tailrec" then
raise(ErrorReport(List(msg"@tailrec may only be used to annotate functions" -> ann.toLoc), true, Diagnostic.Compilation))

buildResultFromTerm(recv)(k)

case Let(false, Var(name), rhs, body) =>
buildBinding(name, rhs, body)(k)
Expand All @@ -147,7 +170,9 @@ final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: Fres
jp.str,
params = res :: fvs.map(x => Name(x)),
resultNum = 1,
jpbody
jpbody,
false,
None
)
ctx.jpAcc.addOne(jpdef)
val tru2 = buildResultFromTerm(tru) {
Expand Down Expand Up @@ -180,6 +205,8 @@ final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: Fres
params = res :: fvs.map(x => Name(x)),
resultNum = 1,
jpbody,
false,
None
)
ctx.jpAcc.addOne(jpdef)
val cases: Ls[(ClassInfo, Node)] = lines map {
Expand Down Expand Up @@ -234,20 +261,31 @@ final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: Fres
res

private def buildDefFromNuFunDef(using ctx: Ctx)(nfd: Statement): Defn = nfd match
case NuFunDef(_, Var(name), None, Nil, L(Lam(Tup(fields), body))) =>
val strs = fields map {
case N -> Fld(FldFlags.empty, Var(x)) => x
case _ => throw IRError("unsupported field")
}
val names = strs map (fresh.make(_))
given Ctx = ctx.copy(nameCtx = ctx.nameCtx ++ (strs zip names))
Defn(
fnUid.make,
name,
params = names,
resultNum = 1,
buildResultFromTerm(body) { x => x }
)
case nfd: NuFunDef => nfd match
case NuFunDef(_, Var(name), None, Nil, L(Lam(Tup(fields), body))) =>
val strs = fields map {
case N -> Fld(FldFlags.empty, Var(x)) => x
case _ => throw IRError("unsupported field")
}
val names = strs map (fresh.make(_))
given Ctx = ctx.copy(nameCtx = ctx.nameCtx ++ (strs zip names))
val trAnn = nfd.annotations.find {
case Var("tailrec") => true
case ann @ Var("tailcall") =>
raise(ErrorReport(List(msg"@tailcall is for annotating function calls; try @tailrec instead" -> ann.toLoc), true, Diagnostic.Compilation))
false
case _ => false }

Defn(
fnUid.make,
name,
params = names,
resultNum = 1,
buildResultFromTerm(body) { x => x },
trAnn.isDefined,
trAnn.flatMap(_.toLoc)
)
case _ => throw IRError("unsupported NuFunDef")
case _ => throw IRError("unsupported NuFunDef")

private def buildClassInfo(ntd: Statement): ClassInfo = ntd match
Expand Down Expand Up @@ -288,7 +326,11 @@ final class Builder(fresh: Fresh, fnUid: FreshInt, classUid: FreshInt, tag: Fres

import scala.collection.mutable.{ HashSet => MutHSet }

val cls = grouped.getOrElse(0, Nil).map(buildClassInfo)
// TODO: properly add prelude classes such as "True" and "False" rather than this hacky method
val cls = ClassInfo(classUid.make, "True", List())
:: ClassInfo(classUid.make, "False", List())
:: grouped.getOrElse(0, Nil).map(buildClassInfo)

cls.foldLeft(Set.empty)(checkDuplicateField(_, _))

val clsinfo = cls.toSet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ private final class DefnRefResolver(defs: Set[Defn], allowInlineJp: Bool):
case Result(res) =>
case Case(scrut, cases) => cases map { (_, body) => f(body) }
case LetExpr(name, expr, body) => f(body)
case LetCall(resultNames, defnref, args, body) =>
case LetCall(resultNames, defnref, args, _, body) =>
defs.find{_.getName == defnref.getName} match
case Some(defn) => defnref.defn = Left(defn)
case None => throw IRError(f"unknown function ${defnref.getName} in ${defs.map{_.getName}.mkString(",")}")
Expand Down
40 changes: 28 additions & 12 deletions compiler/shared/main/scala/mlscript/compiler/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import mlscript.utils.shorthands._
import mlscript.compiler.ir._
import mlscript.compiler.optimizer._

import mlscript.Loc

import collection.mutable.{Map as MutMap, Set as MutSet, HashMap, ListBuffer}
import annotation.unused
import util.Sorting
Expand Down Expand Up @@ -67,7 +69,9 @@ case class Defn(
val name: Str,
val params: Ls[Name],
val resultNum: Int,
val body: Node
val body: Node,
val isTailRec: Bool,
val loc: Opt[Loc] = None
):
override def hashCode: Int = id
def getName: String = name
Expand Down Expand Up @@ -96,6 +100,7 @@ enum Expr:
case CtorApp(name: ClassInfo, args: Ls[TrivialExpr])
case Select(name: Name, cls: ClassInfo, field: Str)
case BasicOp(name: Str, args: Ls[TrivialExpr])
case AssignField(assignee: Name, clsInfo: ClassInfo, fieldName: Str, value: TrivialExpr)

override def toString: String = show

Expand All @@ -107,28 +112,36 @@ enum Expr:
case Literal(IntLit(lit)) => s"$lit" |> raw
case Literal(DecLit(lit)) => s"$lit" |> raw
case Literal(StrLit(lit)) => s"$lit" |> raw
case Literal(UnitLit(lit)) => s"$lit" |> raw
case Literal(UnitLit(lit)) => (if lit then "undefined" else "null") |> raw
case CtorApp(ClassInfo(_, name, _), args) =>
raw(name) <#> raw("(") <#> raw(args |> show_args) <#> raw(")")
case Select(s, _, fld) =>
raw(s.toString) <#> raw(".") <#> raw(fld)
case BasicOp(name: Str, args) =>
raw(name) <#> raw("(") <#> raw(args |> show_args) <#> raw(")")
case AssignField(assignee, clsInfo, fieldName, value) =>
stack(
raw("assign")
<:> raw(assignee.toString + "." + fieldName)
<:> raw(":=")
<:> value.toDocument
)

def mapName(f: Name => Name): Expr = this match
case Ref(name) => Ref(f(name))
case Literal(lit) => Literal(lit)
case CtorApp(cls, args) => CtorApp(cls, args.map(_.mapNameOfTrivialExpr(f)))
case Select(x, cls, field) => Select(f(x), cls, field)
case BasicOp(name, args) => BasicOp(name, args.map(_.mapNameOfTrivialExpr(f)))
case AssignField(assignee, clsInfo, fieldName, value) => AssignField(f(assignee), clsInfo, fieldName, value.mapNameOfTrivialExpr(f))

def locMarker: LocMarker = this match
case Ref(name) => LocMarker.MRef(name.str)
case Literal(lit) => LocMarker.MLit(lit)
case CtorApp(name, args) => LocMarker.MCtorApp(name, args.map(_.toExpr.locMarker))
case Select(name, cls, field) => LocMarker.MSelect(name.str, cls, field)
case BasicOp(name, args) => LocMarker.MBasicOp(name, args.map(_.toExpr.locMarker))

case AssignField(assignee, clsInfo, fieldName, value) => LocMarker.MAssignField(assignee.str, fieldName, value.toExpr.locMarker)

enum Node:
// Terminal forms:
Expand All @@ -137,7 +150,7 @@ enum Node:
case Case(scrut: Name, cases: Ls[(ClassInfo, Node)])
// Intermediate forms:
case LetExpr(name: Name, expr: Expr, body: Node)
case LetCall(names: Ls[Name], defn: DefnRef, args: Ls[TrivialExpr], body: Node)
case LetCall(names: Ls[Name], defn: DefnRef, args: Ls[TrivialExpr], isTailRec: Bool, body: Node)(val loc: Opt[Loc] = None)

var tag = DefnTag(-1)

Expand All @@ -160,7 +173,9 @@ enum Node:
case Jump(defn, args) => Jump(defn, args.map(_.mapNameOfTrivialExpr(f)))
case Case(scrut, cases) => Case(f(scrut), cases.map { (cls, arm) => (cls, arm.mapName(f)) })
case LetExpr(name, expr, body) => LetExpr(f(name), expr.mapName(f), body.mapName(f))
case LetCall(names, defn, args, body) => LetCall(names.map(f), defn, args.map(_.mapNameOfTrivialExpr(f)), body.mapName(f))
case x: LetCall =>
val LetCall(names, defn, args, isTailRec, body) = x
LetCall(names.map(f), defn, args.map(_.mapNameOfTrivialExpr(f)), isTailRec, body.mapName(f))(x.loc)

def copy(ctx: Map[Str, Name]): Node = this match
case Result(res) => Result(res.map(_.mapNameOfTrivialExpr(_.trySubst(ctx))))
Expand All @@ -169,9 +184,10 @@ enum Node:
case LetExpr(name, expr, body) =>
val name_copy = name.copy
LetExpr(name_copy, expr.mapName(_.trySubst(ctx)), body.copy(ctx + (name_copy.str -> name_copy)))
case LetCall(names, defn, args, body) =>
case x: LetCall =>
val LetCall(names, defn, args, isTailRec, body) = x
val names_copy = names.map(_.copy)
LetCall(names_copy, defn, args.map(_.mapNameOfTrivialExpr(_.trySubst(ctx))), body.copy(ctx ++ names_copy.map(x => x.str -> x)))
LetCall(names_copy, defn, args.map(_.mapNameOfTrivialExpr(_.trySubst(ctx))), isTailRec, body.copy(ctx ++ names_copy.map(x => x.str -> x)))(x.loc)

private def toDocument: Document = this match
case Result(res) => raw(res |> show_args) <:> raw(s"-- $tag")
Expand Down Expand Up @@ -203,28 +219,27 @@ enum Node:
<:> raw("in")
<:> raw(s"-- $tag"),
body.toDocument)
case LetCall(xs, defn, args, body) =>
case LetCall(xs, defn, args, isTailRec, body) =>
stack(
raw("let*")
<:> raw("(")
<#> raw(xs.map(_.toString).mkString(","))
<#> raw(")")
<:> raw("=")
<:> raw(defn.getName)
<:> raw((if isTailRec then "@tailcall " else "") + defn.getName)
<#> raw("(")
<#> raw(args.map{ x => x.toString }.mkString(","))
<#> raw(")")
<:> raw("in")
<:> raw(s"-- $tag"),
body.toDocument)

def locMarker: LocMarker =
val marker = this match
case Result(res) => LocMarker.MResult(res.map(_.toExpr.locMarker))
case Jump(defn, args) => LocMarker.MJump(defn.getName, args.map(_.toExpr.locMarker))
case Case(scrut, cases) => LocMarker.MCase(scrut.str, cases.map(_._1))
case LetExpr(name, expr, _) => LocMarker.MLetExpr(name.str, expr.locMarker)
case LetCall(names, defn, args, _) => LocMarker.MLetCall(names.map(_.str), defn.getName, args.map(_.toExpr.locMarker))
case LetCall(names, defn, args, _, _) => LocMarker.MLetCall(names.map(_.str), defn.getName, args.map(_.toExpr.locMarker))
marker.tag = this.tag
marker

Expand Down Expand Up @@ -252,6 +267,7 @@ enum LocMarker:
case MCase(scrut: Str, cases: Ls[ClassInfo])
case MLetExpr(name: Str, expr: LocMarker)
case MLetCall(names: Ls[Str], defn: Str, args: Ls[LocMarker])
case MAssignField(assignee: Str, field: Str, value: LocMarker)
var tag = DefnTag(-1)

def toDocument: Document = this match
Expand Down Expand Up @@ -281,7 +297,7 @@ enum LocMarker:
case MLit(IntLit(lit)) => s"$lit" |> raw
case MLit(DecLit(lit)) => s"$lit" |> raw
case MLit(StrLit(lit)) => s"$lit" |> raw
case MLit(UnitLit(lit)) => s"$lit" |> raw
case MLit(UnitLit(lit)) => (if lit then "undefined" else "null") |> raw
case _ => raw("...")

def show = s"$tag-" + toDocument.print
Expand Down
29 changes: 26 additions & 3 deletions compiler/shared/main/scala/mlscript/compiler/ir/Interp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,10 @@ class Interpreter(verbose: Bool):
private enum Expr:
case Ref(name: Name)
case Literal(lit: Lit)
case CtorApp(name: ClassInfo, args: Ls[Expr])
case CtorApp(name: ClassInfo, var args: Ls[Expr])
case Select(name: Name, cls: ClassInfo, field: Str)
case BasicOp(name: Str, args: Ls[Expr])
case AssignField(assignee: Name, clsInfo: ClassInfo, fieldName: Str, value: Expr)

def show: Str =
document.print
Expand All @@ -60,6 +61,15 @@ class Interpreter(verbose: Bool):
raw(s) <#> raw(".") <#> raw(fld)
case BasicOp(name: Str, args) =>
raw(name) <#> raw("(") <#> raw(args |> show_args) <#> raw(")")
case AssignField(Name(assignee), clsInfo, fieldName, value) =>
stack(
raw("assign")
<:> raw(assignee)
<#> raw(".")
<#> raw(fieldName)
<:> raw("=")
<:> value.document,
)

private enum Node:
case Result(res: Ls[Expr])
Expand Down Expand Up @@ -147,13 +157,14 @@ class Interpreter(verbose: Bool):
case IExpr.CtorApp(name, args) => CtorApp(name, args |> convertArgs)
case IExpr.Select(name, cls, field) => Select(name, cls, field)
case IExpr.BasicOp(name, args) => BasicOp(name, args |> convertArgs)
case IExpr.AssignField(assignee, clsInfo, fieldName, value) => AssignField(assignee, clsInfo, fieldName, value |> convert)

private def convert(node: INode): Node = node match
case INode.Result(xs) => Result(xs |> convertArgs)
case INode.Jump(defnref, args) => Jump(DefnRef(Right(defnref.getName)), args |> convertArgs)
case INode.Case(scrut, cases) => Case(scrut, cases.map{(cls, node) => (cls, node |> convert)})
case INode.LetExpr(name, expr, body) => LetExpr(name, expr |> convert, body |> convert)
case INode.LetCall(xs, defnref, args, body) =>
case INode.LetCall(xs, defnref, args, _, body) =>
LetCall(xs, DefnRef(Right(defnref.getName)), args |> convertArgs, body |> convert)

private def convert(defn: IDefn): Defn =
Expand Down Expand Up @@ -210,6 +221,7 @@ class Interpreter(verbose: Bool):

private def evalArgs(using ctx: Ctx, clsctx: ClassCtx)(exprs: Ls[Expr]): Either[Ls[Expr], Ls[Expr]] =
var changed = false

val xs = exprs.map {
arg => eval(arg) match
case Left(expr) => changed = true; expr
Expand All @@ -230,7 +242,7 @@ class Interpreter(verbose: Bool):
case CtorApp(name, args) =>
evalArgs(args) match
case Left(xs) => Left(CtorApp(name, xs))
case _ => Right(expr)
case Right(xs) => Right(CtorApp(name, xs)) // TODO: This makes recursion modulo cons work, but should be investigated further.
case Select(name, cls, field) =>
ctx.get(name.str).map {
case CtorApp(cls2, xs) if cls == cls2 =>
Expand All @@ -246,6 +258,17 @@ class Interpreter(verbose: Bool):
eval(using ctx, clsctx)(name, xs.head, xs.tail.head)
case _ => throw IRInterpreterError("unexpected basic operation")
x.toLeft(expr)
case AssignField(assignee, clsInfo, fieldName, expr) =>
val value = evalMayNotProgress(expr)
ctx.get(assignee.str) match
case Some(x: CtorApp) =>
val CtorApp(cls, args) = x
val idx = cls.fields.indexOf(fieldName)
val newArgs = args.updated(idx, value)
x.args = newArgs
Left(x)
case Some(_) => throw IRInterpreterError("tried to assign a field of a non-ctor")
case None => throw IRInterpreterError("could not find value " + assignee)

private def expectDefn(r: DefnRef) = r.defn match
case Left(value) => value
Expand Down
Loading
Loading