Skip to content

Commit 63d1c8d

Browse files
AnsonYeungLPTKCAG2Mark
authored
Handler runtime improvements (hkust-taco#282)
Co-authored-by: Lionel Parreaux <lionel.parreaux@gmail.com> Co-authored-by: CAG2Mark <git@markng.com>
1 parent 26a4a14 commit 63d1c8d

24 files changed

+1295
-988
lines changed

hkmc2/shared/src/main/scala/hkmc2/Diagnostic.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,6 @@ object Loc:
9191
def apply(xs: IterableOnce[Located]): Opt[Loc] =
9292
xs.iterator.foldLeft(none[Loc])((acc, l) => acc.fold(l.toLoc)(_ ++ l.toLoc |> some))
9393

94-
final case class Origin(fileName: Str, startLineNum: Int, fph: FastParseHelpers):
95-
override def toString = s"$fileName:+$startLineNum"
94+
final case class Origin(fileName: os.Path, startLineNum: Int, fph: FastParseHelpers):
95+
override def toString = s"${fileName.last}:+$startLineNum"
9696

hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class ParserSetup(file: os.Path, dbgParsing: Bool)(using Elaborator.State, Raise
1616

1717
val block = os.read(file)
1818
val fph = new FastParseHelpers(block)
19-
val origin = Origin(file.toString, 0, fph)
19+
val origin = Origin(file, 0, fph)
2020

2121
val lexer = new syntax.Lexer(origin, dbg = dbgParsing)
2222
val tokens = lexer.bracketedTokens

hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala

Lines changed: 68 additions & 52 deletions
Large diffs are not rendered by default.

hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,19 @@ import hkmc2.semantics.*
99
import hkmc2.syntax.Tree
1010

1111
class StackSafeTransform(depthLimit: Int)(using State):
12-
private val STACK_LIMIT_IDENT: Tree.Ident = Tree.Ident("__stackLimit")
13-
private val STACK_DEPTH_IDENT: Tree.Ident = Tree.Ident("__stackDepth")
14-
private val STACK_OFFSET_IDENT: Tree.Ident = Tree.Ident("__stackOffset")
15-
private val STACK_HANDLER_IDENT: Tree.Ident = Tree.Ident("__stackHandler")
16-
17-
private val predefPath: Path = State.globalThisSymbol.asPath.selN(Tree.Ident("Predef"))
18-
private val checkDepthPath: Path = predefPath.selN(Tree.Ident("checkDepth"))
19-
private val resetDepthPath: Path = predefPath.selN(Tree.Ident("resetDepth"))
20-
private val stackDelayClsPath: Path = predefPath.selN(Tree.Ident("__StackDelay"))
21-
private val stackLimitPath: Path = predefPath.selN(STACK_LIMIT_IDENT)
22-
private val stackDepthPath: Path = predefPath.selN(STACK_DEPTH_IDENT)
23-
private val stackOffsetPath: Path = predefPath.selN(STACK_OFFSET_IDENT)
24-
private val stackHandlerPath: Path = predefPath.selN(STACK_HANDLER_IDENT)
12+
private val STACK_LIMIT_IDENT: Tree.Ident = Tree.Ident("stackLimit")
13+
private val STACK_DEPTH_IDENT: Tree.Ident = Tree.Ident("stackDepth")
14+
private val STACK_OFFSET_IDENT: Tree.Ident = Tree.Ident("stackOffset")
15+
private val STACK_HANDLER_IDENT: Tree.Ident = Tree.Ident("stackHandler")
16+
17+
private val runtimePath: Path = State.runtimeSymbol.asPath
18+
private val checkDepthPath: Path = runtimePath.selN(Tree.Ident("checkDepth"))
19+
private val resetDepthPath: Path = runtimePath.selN(Tree.Ident("resetDepth"))
20+
private val stackDelayClsPath: Path = runtimePath.selN(Tree.Ident("StackDelay"))
21+
private val stackLimitPath: Path = runtimePath.selN(STACK_LIMIT_IDENT)
22+
private val stackDepthPath: Path = runtimePath.selN(STACK_DEPTH_IDENT)
23+
private val stackOffsetPath: Path = runtimePath.selN(STACK_OFFSET_IDENT)
24+
private val stackHandlerPath: Path = runtimePath.selN(STACK_HANDLER_IDENT)
2525

2626
private def intLit(n: BigInt) = Value.Lit(Tree.IntLit(n))
2727

@@ -33,22 +33,20 @@ class StackSafeTransform(depthLimit: Int)(using State):
3333
def extractRes(res: Result, isTailCall: Bool, f: Result => Block, sym: Option[Symbol], curDepth: => Symbol) =
3434
if isTailCall then
3535
blockBuilder
36-
.assignFieldN(predefPath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(1)))
36+
.assignFieldN(runtimePath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(1)))
3737
.ret(res)
3838
else
3939
val tmp = sym getOrElse TempSymbol(None, "tmp")
4040
val offsetGtDepth = TempSymbol(None, "offsetGtDepth")
4141
blockBuilder
42-
.assignFieldN(predefPath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(1)))
42+
.assignFieldN(runtimePath, STACK_DEPTH_IDENT, op("+", stackDepthPath, intLit(1)))
4343
.assign(tmp, res)
4444
.assign(tmp, Call(resetDepthPath, tmp.asPath.asArg :: curDepth.asPath.asArg :: Nil)(true, false))
4545
.rest(f(tmp.asPath))
46-
47-
def extractResTopLevel(res: Result, isTailCall: Bool, f: Result => Block, sym: Option[Symbol], curDepth: => Symbol) =
46+
47+
def wrapStackSafe(body: Block, resSym: Local, rest: Block) =
4848
val resumeSym = VarSymbol(Tree.Ident("resume"))
4949
val handlerSym = TempSymbol(None, "stackHandler")
50-
val resSym = sym getOrElse TempSymbol(None, "res")
51-
val handlerRes = TempSymbol(None, "res")
5250

5351
val clsSym = ClassSymbol(
5452
Tree.TypeDef(syntax.Cls, Tree.Error(), N, N),
@@ -64,26 +62,28 @@ class StackSafeTransform(depthLimit: Int)(using State):
6462
/*
6563
fun perform() =
6664
stackOffset = stackDepth
67-
let ret = resume()
68-
ret
65+
resume()
6966
*/
7067
blockBuilder
71-
.assignFieldN(predefPath, STACK_OFFSET_IDENT, stackDepthPath)
72-
.assign(handlerRes, Call(Value.Ref(resumeSym), Nil)(true, true))
73-
.ret(handlerRes.asPath)
68+
.assignFieldN(runtimePath, STACK_OFFSET_IDENT, stackDepthPath)
69+
.ret(Call(Value.Ref(resumeSym), Nil)(true, true))
7470
) :: Nil,
7571
blockBuilder
76-
.assignFieldN(predefPath, STACK_LIMIT_IDENT, intLit(depthLimit)) // set stackLimit before call
77-
.assignFieldN(predefPath, STACK_OFFSET_IDENT, intLit(0)) // set stackOffset = 0 before call
78-
.assignFieldN(predefPath, STACK_DEPTH_IDENT, intLit(1)) // set stackDepth = 1 before call
79-
.assignFieldN(predefPath, STACK_HANDLER_IDENT, handlerSym.asPath) // assign stack handler
80-
.rest(HandleBlockReturn(res)),
72+
.assignFieldN(runtimePath, STACK_LIMIT_IDENT, intLit(depthLimit)) // set stackLimit before call
73+
.assignFieldN(runtimePath, STACK_OFFSET_IDENT, intLit(0)) // set stackOffset = 0 before call
74+
.assignFieldN(runtimePath, STACK_DEPTH_IDENT, intLit(1)) // set stackDepth = 1 before call
75+
.assignFieldN(runtimePath, STACK_HANDLER_IDENT, handlerSym.asPath) // assign stack handler
76+
.rest(body),
8177
blockBuilder // reset the stack safety values
82-
.assignFieldN(predefPath, STACK_DEPTH_IDENT, intLit(0)) // set stackDepth = 0 after call
83-
.assignFieldN(predefPath, STACK_HANDLER_IDENT, Value.Lit(Tree.UnitLit(true))) // set stackHandler = null
84-
.rest(f(resSym.asPath))
78+
.assignFieldN(runtimePath, STACK_DEPTH_IDENT, intLit(0)) // set stackDepth = 0 after call
79+
.assignFieldN(runtimePath, STACK_HANDLER_IDENT, Value.Lit(Tree.UnitLit(true))) // set stackHandler = null
80+
.rest(rest)
8581
)
8682

83+
def extractResTopLevel(res: Result, isTailCall: Bool, f: Result => Block, sym: Option[Symbol], curDepth: => Symbol) =
84+
val resSym = sym getOrElse TempSymbol(None, "res")
85+
wrapStackSafe(HandleBlockReturn(res), resSym, f(resSym.asPath))
86+
8787
// Rewrites anything that can contain a Call to increase the stack depth
8888
def transform(b: Block, curDepth: => Symbol, isTopLevel: Bool = false): Block =
8989
def usesStack(r: Result) = r match
@@ -119,8 +119,21 @@ class StackSafeTransform(depthLimit: Int)(using State):
119119
val hdr2 = hdr.mapConserve(applyHandler)
120120
val bod2 = rewriteBlk(bod)
121121
val rst2 = applyBlock(rst)
122-
HandleBlock(l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
122+
if isTopLevel then
123+
val newRes = TempSymbol(N, "res")
124+
val newHandler = HandleBlock(l2, newRes, par2, args2, cls2, hdr2, bod2, HandleBlockReturn(newRes.asPath))
125+
wrapStackSafe(newHandler, res2, rst2)
126+
else
127+
HandleBlock(l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
128+
123129
case _ => super.applyBlock(b)
130+
131+
override def applyHandler(hdr: Handler): Handler =
132+
val sym2 = hdr.sym.subst
133+
val resumeSym2 = hdr.resumeSym.subst
134+
val params2 = hdr.params.mapConserve(applyParamList)
135+
val body2 = rewriteBlk(hdr.body)
136+
Handler(sym2, resumeSym2, params2, body2)
124137

125138
override def applyResult2(r: Result)(k: Result => Block): Block =
126139
if usesStack(r) then

hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ object Elaborator:
142142
given State = this
143143
val globalThisSymbol = TopLevelSymbol("globalThis")
144144
val runtimeSymbol = TempSymbol(N, "runtime")
145+
val effectSigSymbol = ClassSymbol(Tree.TypeDef(syntax.Cls, Tree.Error(), N, N), Tree.Ident("EffectSig"))
146+
val returnClsSymbol = ClassSymbol(Tree.TypeDef(syntax.Cls, Tree.Error(), N, N), Tree.Ident("Return"))
145147
val builtinOpsMap =
146148
val baseBuiltins = builtins.map: op =>
147149
op -> BuiltinSymbol(op,

hkmc2/shared/src/main/scala/hkmc2/semantics/Importer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class Importer:
4444

4545
val block = os.read(file)
4646
val fph = new FastParseHelpers(block)
47-
val origin = Origin(file.toString, 0, fph)
47+
val origin = Origin(file, 0, fph)
4848

4949
val sym = tl.trace(s">>> Importing $file"):
5050

0 commit comments

Comments
 (0)