Skip to content

Commit 4e1d3cf

Browse files
AnsonYeungLPTK
andauthored
Non local return (#288)
* Remove old implementation of non local return and add detection for non local return * Add OuterCtx and removed some dummy code * Add tests for non local return --------- Co-authored-by: Lionel Parreaux <lionel.parreaux@gmail.com>
1 parent 92b837e commit 4e1d3cf

19 files changed

+332
-170
lines changed

hkmc2/shared/src/main/scala/hkmc2/MLsCompiler.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,11 @@ class MLsCompiler(preludeFile: os.Path, mkOutput: ((Str => Unit) => Unit) => Uni
7272

7373
val elab = Elaborator(etl, wd, Ctx.empty)
7474

75-
val initState = State.init.nest(N)
75+
val initState = State.init.nestLocal
7676

7777
val (pblk, newCtx) = elab.importFrom(preludeParse.resultBlk)(using initState)
7878

79-
newCtx.nest(N).givenIn:
79+
newCtx.nestLocal.givenIn:
8080
val elab = Elaborator(etl, wd, newCtx)
8181
val parsed = mainParse.resultBlk
8282
val (blk0, _) = elab.importFrom(parsed)

hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,11 @@ sealed abstract class Block extends Product with AutoLocated:
4242
val rest = rst.definedVars
4343
if defn.isOwned then rest else rest + defn.sym
4444
case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) => bod.definedVars ++ rst.definedVars + lhs
45-
case HandleBlockReturn(_) => Set.empty
4645
case TryBlock(sub, fin, rst) => sub.definedVars ++ fin.definedVars ++ rst.definedVars
4746
case Label(lbl, bod, rst) => bod.definedVars ++ rst.definedVars
4847

4948
lazy val size: Int = this match
50-
case _: Return | _: Throw | _: End | _: Break | _: Continue | _: HandleBlockReturn => 1
49+
case _: Return | _: Throw | _: End | _: Break | _: Continue => 1
5150
case Begin(sub, rst) => sub.size + rst.size
5251
case Assign(_, _, rst) => 1 + rst.size
5352
case AssignField(_, _, _, rst) => 1 + rst.size
@@ -97,7 +96,6 @@ sealed abstract class Block extends Product with AutoLocated:
9796
case Define(defn, rest) => defn.freeVars ++ rest.freeVars
9897
case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) =>
9998
(bod.freeVars - lhs) ++ rst.freeVars ++ hdr.flatMap(_.freeVars)
100-
case HandleBlockReturn(res) => res.freeVars
10199
case End(msg) => Set.empty
102100

103101
// TODO: freeVarsLLIR skips `fun` and `cls` in `Call` and `Instantiate` respectively, which is needed in some
@@ -121,7 +119,6 @@ sealed abstract class Block extends Product with AutoLocated:
121119
case Define(defn, rest) => defn.freeVarsLLIR ++ rest.freeVarsLLIR
122120
case HandleBlock(lhs, res, par, args, cls, hdr, bod, rst) =>
123121
(bod.freeVarsLLIR - lhs) ++ rst.freeVarsLLIR ++ hdr.flatMap(_.freeVars)
124-
case HandleBlockReturn(res) => res.freeVarsLLIR
125122
case End(msg) => Set.empty
126123

127124
lazy val subBlocks: Ls[Block] = this match
@@ -137,10 +134,9 @@ sealed abstract class Block extends Product with AutoLocated:
137134

138135
// TODO rm Lam from values and thus the need for these cases
139136
case Return(r, _) => r.subBlocks
140-
case HandleBlockReturn(r) => r.subBlocks
141137
case Throw(r) => r.subBlocks
142138

143-
case _: Return | _: Throw | _: Break | _: Continue | _: End | _: HandleBlockReturn => Nil
139+
case _: Return | _: Throw | _: Break | _: Continue | _: End => Nil
144140

145141
// Moves definitions in a block to the top. Only scans the top-level definitions of the block;
146142
// i.e, definitions inside other definitions are not moved out. Definitions inside `match`/`if`
@@ -295,8 +291,6 @@ case class HandleBlock(
295291
rest: Block
296292
) extends Block with ProductWithTail
297293

298-
case class HandleBlockReturn(res: Result) extends BlockTail
299-
300294
sealed abstract class Defn:
301295
val innerSym: Opt[MemberSymbol[?]]
302296
val sym: BlockMemberSymbol

hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTransformer.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ class BlockTransformer(subst: SymbolSubst):
2929
case Throw(exc) =>
3030
applyResult2(exc): exc2 =>
3131
if exc2 is exc then b else Throw(exc2)
32-
case HandleBlockReturn(res) =>
33-
applyResult2(res): res2 =>
34-
if res2 is res then b else HandleBlockReturn(res2)
3532
case Match(scrut, arms, dflt, rst) =>
3633
val scrut2 = applyPath(scrut)
3734
val arms2 = arms.mapConserve: arm =>

hkmc2/shared/src/main/scala/hkmc2/codegen/BlockTraverser.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ class BlockTraverser(subst: SymbolSubst):
2222
case Continue(lbl) => applyLocal(lbl)
2323
case Return(res, implct) => applyResult(res)
2424
case Throw(exc) => applyResult(exc)
25-
case HandleBlockReturn(res) => applyResult(res)
2625
case Match(scrut, arms, dflt, rst) =>
2726
val scrut2 = applyPath(scrut)
2827
arms.foreach: arm =>

hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala

Lines changed: 16 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,17 @@ object HandlerLowering:
2929

3030
private case class LinkState(res: Path, cls: Path, uid: StateId)
3131

32-
// shouldUnwrapRet: whether the current block should unwrap the runtime.Return if it encounter one
3332
// isTopLevel:
3433
// whether the current block is the top level block, as we do not emit code for continuation class on the top level
3534
// since we cannot return an effect signature on the top level (we are not in a function so return statement are invalid)
3635
// and we do not have any `return` statement in the top level block so we do not need the `runtime.Return` workarounds.
37-
// isHandler: whether the current block is the body of handler method
3836
// contName: the name of the continuation class
3937
// ctorThis: the path to `this` in the constructor, this is used to insert `return this;` at the end of constructor.
4038
// linkAndHandle:
4139
// a function that takes a LinkState and returns a block that links the continuation class and handles the effect
4240
// this is a convenience function which initializes the continuation class in function context or throw an error in top level
4341
private case class HandlerCtx(
44-
shouldUnwrapRet: Bool,
4542
isTopLevel: Bool,
46-
isHandlerMtd: Bool,
4743
isHandlerBody: Bool,
4844
contName: Str,
4945
ctorThis: Option[Path],
@@ -59,27 +55,25 @@ class HandlerPaths(using Elaborator.State):
5955
val effectSigPath: Path = runtimePath.selN(Tree.Ident("EffectSig")).selN(Tree.Ident("class"))
6056
val effectSigSym: ClassSymbol = State.effectSigSymbol
6157
val contClsPath: Path = runtimePath.selN(Tree.Ident("FunctionContFrame")).selN(Tree.Ident("class"))
62-
val retClsPath: Path = runtimePath.selN(Tree.Ident("Return")).selN(Tree.Ident("class"))
63-
val retClsSym: ClassSymbol = State.returnClsSymbol
6458
val mkEffectPath: Path = runtimePath.selN(Tree.Ident("mkEffect"))
6559
val handleBlockImplPath: Path = runtimePath.selN(Tree.Ident("handleBlockImpl"))
6660
val stackDelayClsPath: Path = runtimePath.selN(Tree.Ident("StackDelay"))
6761

6862
def isHandlerClsPath(p: Path) =
69-
(p eq contClsPath) || (p eq stackDelayClsPath) || (p eq effectSigPath) || (p eq retClsPath)
63+
(p eq contClsPath) || (p eq stackDelayClsPath) || (p eq effectSigPath)
7064

7165
class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, Elaborator.Ctx):
7266

7367
private def funcLikeHandlerCtx(ctorThis: Option[Path], isHandlerMtd: Bool, nme: Str) =
74-
HandlerCtx(!isHandlerMtd, false, isHandlerMtd, false, nme, ctorThis, state =>
68+
HandlerCtx(false, false, nme, ctorThis, state =>
7569
blockBuilder
7670
.assignFieldN(state.res.contTrace.last, nextIdent, Instantiate(
7771
state.cls.selN(Tree.Ident("class")),
7872
Value.Lit(Tree.IntLit(state.uid)) :: Nil))
7973
.assignFieldN(state.res.contTrace, lastIdent, state.res.contTrace.last.next)
8074
.ret(state.res))
8175
private def functionHandlerCtx(nme: Str) = funcLikeHandlerCtx(N, false, nme)
82-
private def topLevelCtx(nme: Str) = HandlerCtx(false, true, false, false, nme, N, _ => rtThrowMsg("Unhandled effects"))
76+
private def topLevelCtx(nme: Str) = HandlerCtx(true, false, nme, N, _ => rtThrowMsg("Unhandled effects"))
8377
private def ctorCtx(ctorThis: Path, nme: Str) = funcLikeHandlerCtx(S(ctorThis), false, nme)
8478
private def handlerMtdCtx(nme: Str) = funcLikeHandlerCtx(N, true, nme)
8579
private def handlerCtx(using HandlerCtx): HandlerCtx = summon
@@ -116,25 +110,25 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
116110
def apply(res: Local, uid: StateId) =
117111
Assign(res, PureCall(Value.Ref(returnContSymbol), List(Value.Lit(Tree.IntLit(uid)))), End(""))
118112
def unapply(blk: Block) = blk match
119-
case Assign(res, PureCall(Value.Ref(`returnContSymbol`), List(Value.Lit(Tree.IntLit(uid)))), _) =>
113+
case Assign(res, PureCall(Value.Ref(`returnContSymbol`), List(Value.Lit(Tree.IntLit(uid)))), _) =>
120114
Some(res, uid)
121115
case _ => None
122116

123117
// placeholder for effectful Results, such as Call, Instantiate and anything else that could
124118
// return a continuation
125119
object ResultPlaceholder:
126120
private val callSymbol = freshTmp("resultPlaceholder")
127-
def apply(res: Local, uid: StateId, canRet: Bool, r: Result, rest: Block) =
121+
def apply(res: Local, uid: StateId, r: Result, rest: Block) =
128122
Assign(
129123
res,
130-
PureCall(Value.Ref(callSymbol), List(Value.Lit(Tree.IntLit(uid)), Value.Lit(Tree.BoolLit(canRet)))),
124+
PureCall(Value.Ref(callSymbol), List(Value.Lit(Tree.IntLit(uid)))),
131125
Assign(res, r, rest))
132126
def unapply(blk: Block) = blk match
133127
case Assign(
134128
res,
135-
PureCall(Value.Ref(`callSymbol`), List(Value.Lit(Tree.IntLit(uid)), Value.Lit(Tree.BoolLit(canRet)))),
129+
PureCall(Value.Ref(`callSymbol`), List(Value.Lit(Tree.IntLit(uid)))),
136130
Assign(_, c, rest)) =>
137-
Some(res, uid, canRet, c, rest)
131+
Some(res, uid, c, rest)
138132
case _ => None
139133

140134
object StateTransition:
@@ -287,7 +281,6 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
287281
case TryBlock(sub, finallyDo, rest) => ??? // ignore
288282
case Throw(_) => PartRet(blk, Nil)
289283
case _: HandleBlock => lastWords("unexpected handleBlock") // already translated at this point
290-
case _: HandleBlockReturn => lastWords("unexpected handleBlockReturn") // already translated at this point
291284

292285
val result = go(blk)(using labelIds, N)
293286
result.states
@@ -327,26 +320,26 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
327320
val fun2 = applyPath(fun)
328321
val args2 = args.mapConserve(applyArg)
329322
val c2 = if (fun2 is fun) && (args2 is args) then c else Call(fun2, args2)(c.isMlsFun, c.mayRaiseEffects)
330-
ResultPlaceholder(lhs, freshId(), handlerCtx.isHandlerMtd, c2, applyBlock(rest))
323+
ResultPlaceholder(lhs, freshId(), c2, applyBlock(rest))
331324
case Assign(lhs, c @ Instantiate(cls, args), rest) =>
332325
val cls2 = applyPath(cls)
333326
val args2 = args.mapConserve(applyPath)
334327
val c2 = if (cls2 is cls) && (args2 is args) then c else Instantiate(cls2, args2)
335-
ResultPlaceholder(lhs, freshId(), handlerCtx.isHandlerMtd, c2, applyBlock(rest))
328+
ResultPlaceholder(lhs, freshId(), c2, applyBlock(rest))
336329
case _ => super.applyBlock(b)
337330
override def applyResult2(r: Result)(k: Result => Block): Block = r match
338331
case c @ Call(fun, args) if c.mayRaiseEffects =>
339332
val res = freshTmp("res")
340333
val fun2 = applyPath(fun)
341334
val args2 = args.mapConserve(applyArg)
342335
val c2 = if (fun2 is fun) && (args2 is args) then c else Call(fun2, args2)(c.isMlsFun, c.mayRaiseEffects)
343-
ResultPlaceholder(res, freshId(), handlerCtx.isHandlerMtd, c2, k(Value.Ref(res)))
336+
ResultPlaceholder(res, freshId(), c2, k(Value.Ref(res)))
344337
case c @ Instantiate(cls, args) =>
345338
val res = freshTmp("res")
346339
val cls2 = applyPath(cls)
347340
val args2 = args.mapConserve(applyPath)
348341
val c2 = if (cls2 is cls) && (args2 is args) then c else Instantiate(cls2, args2)
349-
ResultPlaceholder(res, freshId(), handlerCtx.isHandlerMtd, c2, k(Value.Ref(res)))
342+
ResultPlaceholder(res, freshId(), c2, k(Value.Ref(res)))
350343
case r => super.applyResult2(r)(k)
351344
override def applyLam(lam: Value.Lam): Value.Lam =
352345
Value.Lam(lam.params, translateBlock(lam.body, functionHandlerCtx(s"Cont$$lambda$$")))
@@ -389,21 +382,8 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
389382
val sym = BlockMemberSymbol(s"handleBlock$$", Nil)
390383
val lbl = freshTmp("handlerBody")
391384
val lblLoop = freshTmp("handlerLoop")
392-
val tmp = freshTmp("retCont")
393-
def prepareBody(b: Block): Block =
394-
395-
val transform = new BlockTransformerShallow(SymbolSubst()):
396-
override def applyBlock(b: Block): Block =
397-
b match
398-
case Return(res, implct) =>
399-
// In case res is effectful, it will be handled in translateBlock
400-
Assign(tmp, res, Return(Instantiate(paths.retClsPath, tmp.asPath :: Nil), implct))
401-
case HandleBlockReturn(res) =>
402-
Return(res, false)
403-
case _ => super.applyBlock(b)
404-
transform.applyBlock(b)
405385

406-
val handlerBody = translateBlock(prepareBody(h.body), HandlerCtx(false, false, false, true,
386+
val handlerBody = translateBlock(h.body, HandlerCtx(false, true,
407387
s"Cont$$handleBlock$$${h.lhs.nme}$$", N, state => blockBuilder
408388
.assignFieldN(state.res.contTrace.last, nextIdent, PureCall(state.cls, Value.Lit(Tree.IntLit(state.uid)) :: Nil))
409389
.ret(PureCall(paths.handleBlockImplPath, state.res :: h.lhs.asPath :: Nil))))
@@ -437,7 +417,7 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
437417
N, // no owner
438418
sym, PlainParamList(Nil) :: Nil, body)
439419

440-
val result = Define(defn, ResultPlaceholder(h.res, freshId(), !handlerCtx.isTopLevel, Call(sym.asPath, Nil)(true, true), h.rest))
420+
val result = Define(defn, ResultPlaceholder(h.res, freshId(), Call(sym.asPath, Nil)(true, true), h.rest))
441421
result
442422

443423
private def genContClass(b: Block)(using HandlerCtx): Opt[ClsLikeDefn] =
@@ -453,7 +433,7 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
453433
val transform = new BlockTransformerShallow(SymbolSubst()):
454434
override def applyBlock(b: Block): Block = b match
455435
case Define(_: (ClsLikeDefn | FunDefn), rst) => applyBlock(rst)
456-
case ResultPlaceholder(res, uid, canRet, c, rest) =>
436+
case ResultPlaceholder(res, uid, c, rest) =>
457437
trivial = false
458438
blockBuilder
459439
.assign(res, c)
@@ -463,12 +443,6 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
463443
ReturnCont(res, uid)
464444
)
465445
.chain(ResumptionPoint(res, uid, _))
466-
.staticif(canRet, _.ifthen(
467-
res.asPath,
468-
Case.Cls(paths.retClsSym, paths.retClsPath),
469-
blockBuilder
470-
.ret(if handlerCtx.shouldUnwrapRet then res.asPath.value else res.asPath)
471-
))
472446
.rest(applyBlock(rest))
473447
case _ => super.applyBlock(b)
474448
transform.applyBlock(b)
@@ -558,19 +532,14 @@ class HandlerLowering(paths: HandlerPaths)(using TL, Raise, Elaborator.State, El
558532
private def genNormalBody(b: Block, clsSym: BlockMemberSymbol)(using HandlerCtx): Block =
559533
val transform = new BlockTransformerShallow(SymbolSubst()):
560534
override def applyBlock(b: Block): Block = b match
561-
case ResultPlaceholder(res, uid, canRet, c, rest) =>
535+
case ResultPlaceholder(res, uid, c, rest) =>
562536
blockBuilder
563537
.assign(res, c)
564538
.ifthen(
565539
res.asPath,
566540
Case.Cls(paths.effectSigSym, paths.effectSigPath),
567541
handlerCtx.linkAndHandle(LinkState(res.asPath, clsSym.asPath, uid))
568542
)
569-
.staticif(canRet, _.ifthen(
570-
res.asPath,
571-
Case.Cls(paths.retClsSym, paths.retClsPath),
572-
blockBuilder.ret(if handlerCtx.shouldUnwrapRet then res.asPath.value else res.asPath)
573-
))
574543
.rest(applyBlock(rest))
575544
case _ => super.applyBlock(b)
576545

hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
321321
subTerm(rhs): par =>
322322
subTerms(as): asr =>
323323
HandleBlock(lhs, resSym, par, asr, cls, handlers,
324-
term_nonTail(bod)(HandleBlockReturn(_)),
324+
term_nonTail(bod)(Ret),
325325
k(Value.Ref(resSym)))
326326
case st.Blk(sts, res) => block(sts, R(res))(k)
327327
case Assgn(lhs, rhs) =>

hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
8181

8282
def extractResTopLevel(res: Result, isTailCall: Bool, f: Result => Block, sym: Option[Symbol], curDepth: => Symbol) =
8383
val resSym = sym getOrElse TempSymbol(None, "res")
84-
wrapStackSafe(HandleBlockReturn(res), resSym, f(resSym.asPath))
84+
wrapStackSafe(Ret(res), resSym, f(resSym.asPath))
8585

8686
// Rewrites anything that can contain a Call to increase the stack depth
8787
def transform(b: Block, curDepth: => Symbol, isTopLevel: Bool = false): Block =
@@ -120,7 +120,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths)(using State):
120120
val rst2 = applyBlock(rst)
121121
if isTopLevel then
122122
val newRes = TempSymbol(N, "res")
123-
val newHandler = HandleBlock(l2, newRes, par2, args2, cls2, hdr2, bod2, HandleBlockReturn(newRes.asPath))
123+
val newHandler = HandleBlock(l2, newRes, par2, args2, cls2, hdr2, bod2, Ret(newRes.asPath))
124124
wrapStackSafe(newHandler, res2, rst2)
125125
else
126126
HandleBlock(l2, res2, par2, args2, cls2, hdr2, bod2, rst2)

hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder:
156156
def returningTerm(t: Block, endSemi: Bool)(using Raise, Scope): Document =
157157
def mkSemi = if endSemi then ";" else ""
158158
t match
159-
case _: (HandleBlockReturn | HandleBlock) =>
159+
case _: HandleBlock =>
160160
errStmt(msg"This code requires effect handler instrumentation but was compiled without it.")
161161
case Assign(l, r, rst) =>
162162
doc" # ${getVar(l)} = ${result(r)};${returningTerm(rst, endSemi)}"

0 commit comments

Comments
 (0)