Skip to content

Commit ae9d3ca

Browse files
committed
refactor
1 parent b71a092 commit ae9d3ca

14 files changed

+404
-394
lines changed

hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala

Lines changed: 49 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ object HandlerLowering:
4343
e
4444
case Some(v) => v
4545

46-
private case class LinkState(res: Local, cls: Path, uid: StateId, doUnwind: Opt[LazyVal[Path]])
46+
private case class LinkState(res: Local, cls: Path, uid: Path)
4747

4848
type FnOrCls = Either[BlockMemberSymbol, MemberSymbol[? <: ClassLikeDef] & InnerSymbol]
4949

@@ -78,6 +78,9 @@ object HandlerLowering:
7878
def topLevel(debugNme: Str) = DebugInfo(debugNme, Set.empty, N)
7979

8080
type StateId = BigInt
81+
82+
// TODO: move somewhere else, not sure where
83+
def simpleParam(sym: VarSymbol) = Param(FldFlags.empty, sym, N, Modulefulness.none)
8184

8285
import HandlerLowering.*
8386

@@ -98,18 +101,24 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
98101

99102
private def funcLikeHandlerCtx(ctorThis: Option[Path], isHandlerMtd: Bool, contNme: Str, debugNme: Str)(using h: HandlerCtx) =
100103
HandlerCtx(false, false, contNme, ctorThis, h.debugInfo.copy(debugNme), state =>
101-
if state.doUnwind.isEmpty then
102-
lastWords("doUnwind function does not exist for this function")
103-
Return(Call(
104-
state.doUnwind.get.get,
105-
state.res.asPath.asArg :: Value.Lit(Tree.IntLit(state.uid)).asArg :: Nil
106-
)(true, false), false)
107-
)
104+
blockBuilder
105+
.assignFieldN(state.res.asPath.contTrace.last, nextIdent, Instantiate(
106+
state.cls.selN(Tree.Ident("class")),
107+
state.uid :: Nil))
108+
.assignFieldN(state.res.asPath.contTrace, lastIdent, state.res.asPath.contTrace.last.next)
109+
.ret(state.res.asPath))
108110
private def functionHandlerCtx(nme: Str, debugNme: Str)(using HandlerCtx) = funcLikeHandlerCtx(N, false, nme, debugNme)
109-
private def topLevelCtx(nme: Str, debugNme: Str) = HandlerCtx(true, false, nme, N, DebugInfo.topLevel(debugNme), state => Assign(
110-
state.res,
111-
Call(paths.topLevelEffectPath, state.res.asPath.asArg :: Value.Lit(Tree.BoolLit(opt.debug)).asArg :: Nil)(true, false),
112-
End()))
111+
private def topLevelCall(state: LinkState) = Call(
112+
paths.topLevelEffectPath,
113+
state.res.asPath.asArg :: Value.Lit(Tree.BoolLit(opt.debug)).asArg :: Nil
114+
)(true, false)
115+
private def topLevelCtx(nme: Str, debugNme: Str) = HandlerCtx(
116+
true, false, nme, N, DebugInfo.topLevel(debugNme),
117+
state => Assign(
118+
state.res,
119+
topLevelCall(state),
120+
End())
121+
)
113122
private def ctorCtx(ctorThis: Path, nme: Str, debugNme: Str)(using HandlerCtx) = funcLikeHandlerCtx(S(ctorThis), false, nme, debugNme)
114123
private def handlerMtdCtx(nme: Str, debugNme: Str)(using HandlerCtx) = funcLikeHandlerCtx(N, true, nme, debugNme)
115124
private def handlerCtx(using HandlerCtx): HandlerCtx = summon
@@ -258,8 +267,6 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
258267
val rewrittenEntry = rewriteState(entryState)
259268
val rewrittenStates = states.map(rewriteState)
260269

261-
println(finalDests)
262-
263270
(rewrittenEntry, rewrittenStates)
264271

265272
def partitionBlock(blk: Block, inclEntryPoint: Bool, labelIds: Map[Symbol, (StateId, StateId)] = Map.empty): Ls[BlockState] =
@@ -507,17 +514,9 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
507514
doUnwindMap += fnOrCls -> doUnwindSym.asPath
508515
val pcSym = VarSymbol(Tree.Ident("pc"))
509516
val resSym = VarSymbol(Tree.Ident("res"))
510-
val doUnwindBlk = blockBuilder
511-
.assignFieldN(
512-
resSym.asPath.contTrace.last, nextIdent,
513-
Instantiate(
514-
cls.sym.asPath.selN(Tree.Ident("class")),
515-
pcSym.asPath :: Nil
516-
)
517-
)
518-
.assignFieldN(resSym.asPath.contTrace, lastIdent, resSym.asPath.contTrace.last.next)
519-
.ret(resSym.asPath)
520-
517+
val doUnwindBlk = h.linkAndHandle(
518+
LinkState(resSym, cls.sym.asPath, pcSym.asPath)
519+
)
521520
def simpleParam(sym: VarSymbol) = Param(FldFlags.empty, sym, N, Modulefulness.none)
522521
val doUnwindDef = FunDefn(
523522
N, doUnwindSym,
@@ -576,11 +575,6 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
576575
val lbl = freshTmp("handlerBody")
577576
val lblLoop = freshTmp("handlerLoop")
578577

579-
val handlerBody = translateBlock(h.body, Set.empty, L(sym), HandlerCtx(false, true,
580-
s"Cont$$handleBlock$$${symToStr(h.lhs)}$$", N, handlerCtx.debugInfo.copy(debugNme = s"‹handler body of ${h.lhs.nme}"), state => blockBuilder
581-
.assignFieldN(state.res.asPath.contTrace.last, nextIdent, PureCall(state.cls, Value.Lit(Tree.IntLit(state.uid)) :: Nil))
582-
.ret(PureCall(paths.handleBlockImplPath, state.res.asPath :: h.lhs.asPath :: Nil))))
583-
584578
val handlerMtds = h.handlers.map: handler =>
585579
val sym = BlockMemberSymbol("hdlrFun", Nil, true)
586580
val mtdBdy = translateBlock(handler.body,
@@ -605,16 +599,30 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
605599
// NOTE: the super call is inside the preCtor
606600
// during resumption we need to resume both the this.x = x bindings done in JSBuilder and the ctor
607601

608-
val body = blockBuilder
609-
.define(clsDefn)
610-
.assign(h.lhs, Instantiate(Value.Ref(clsDefn.sym), Nil))
611-
.rest(handlerBody)
602+
val handlerBody = translateBlock(
603+
h.body, Set.empty, L(sym),
604+
HandlerCtx(
605+
false, true,
606+
s"Cont$$handleBlock$$${symToStr(h.lhs)}$$", N,
607+
handlerCtx.debugInfo.copy(debugNme = s"‹handler body of ${h.lhs.nme}"),
608+
state => blockBuilder
609+
.assignFieldN(state.res.asPath.contTrace.last, nextIdent, PureCall(state.cls, state.uid :: Nil))
610+
.ret(PureCall(paths.handleBlockImplPath, state.res.asPath :: h.lhs.asPath :: Nil))
611+
)
612+
)
612613

613614
val defn = FunDefn(
614615
N, // no owner
615-
sym, PlainParamList(Nil) :: Nil, body)
616+
sym, PlainParamList(Nil) :: Nil, handlerBody)
616617

617-
val result = Define(defn, ResultPlaceholder(h.res, freshId(), Call(sym.asPath, Nil)(true, true), h.rest))
618+
// moved all defns outside
619+
val result = blockBuilder
620+
.define(defn)
621+
.define(clsDefn)
622+
.assign(h.lhs, Instantiate(Value.Ref(clsDefn.sym), Nil))
623+
.rest(
624+
ResultPlaceholder(h.res, freshId(), Call(sym.asPath, Nil)(true, true), h.rest)
625+
)
618626
result
619627

620628
private def genContClass(b: Block)(using h: HandlerCtx): Opt[ClsLikeDefn] =
@@ -644,7 +652,6 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
644652
.assignFieldN(resSym.asPath.contTrace.last, nextIdent, clsSym.asPath)
645653
.assignFieldN(resSym.asPath.contTrace, lastIdent, clsSym.asPath)
646654
.ret(resSym.asPath)
647-
def simpleParam(sym: VarSymbol) = Param(FldFlags.empty, sym, N, Modulefulness.none)
648655
val doUnwindDef = FunDefn(
649656
S(clsSym), doUnwindSym,
650657
PlainParamList(simpleParam(resSym) :: simpleParam(newPcSym) :: Nil) :: Nil,
@@ -823,13 +830,16 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
823830
private def genNormalBody(b: Block, clsSym: BlockMemberSymbol, doUnwind: Opt[LazyVal[Path]])(using HandlerCtx): Block =
824831
val transform = new BlockTransformerShallow(SymbolSubst()):
825832
override def applyBlock(b: Block): Block = b match
826-
case ResultPlaceholder(res, uid, c, rest) =>
833+
case ResultPlaceholder(res, uid, c, rest) =>
834+
val (doUnwindCall, implct) = doUnwind match
835+
case None => (topLevelCall(LinkState(res, clsSym.asPath, Value.Lit(Tree.IntLit(uid)))), true)
836+
case Some(doUnwind) => (PureCall(doUnwind.get, res.asPath :: Value.Lit(Tree.IntLit(uid)) :: Nil), false)
827837
blockBuilder
828838
.assign(res, c)
829839
.ifthen(
830840
res.asPath,
831841
Case.Cls(paths.effectSigSym, paths.effectSigPath),
832-
handlerCtx.linkAndHandle(LinkState(res, clsSym.asPath, uid, doUnwind))
842+
Return(doUnwindCall, implct)
833843
)
834844
.rest(applyBlock(rest))
835845
case _ => super.applyBlock(b)

hkmc2/shared/src/test/mlscript-compile/Runtime.mjs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ let Runtime1;
225225
};
226226
this.stackLimit = 0;
227227
this.stackDepth = 0;
228+
this.stackOffset = 0;
228229
this.stackHandler = null;
229230
this.stackResume = null;
230231
const StackDelayHandler$class = class StackDelayHandler {
@@ -777,24 +778,34 @@ let Runtime1;
777778
return tmp4
778779
}
779780
static checkDepth() {
780-
let scrut, tmp, tmp1;
781-
tmp = Runtime.stackDepth >= Runtime.stackLimit;
782-
tmp1 = Runtime.stackHandler !== null;
783-
scrut = tmp && tmp1;
781+
let scrut, tmp, tmp1, tmp2;
782+
tmp = Runtime.stackDepth - Runtime.stackOffset;
783+
tmp1 = tmp >= Runtime.stackLimit;
784+
tmp2 = Runtime.stackHandler !== null;
785+
scrut = tmp1 && tmp2;
784786
if (scrut === true) {
785787
return runtime.safeCall(Runtime.stackHandler.delay())
786788
} else {
787789
return runtime.Unit
788790
}
789791
}
790792
static resetDepth(tmp, curDepth) {
793+
let scrut, tmp1;
791794
Runtime.stackDepth = curDepth;
795+
scrut = curDepth < Runtime.stackOffset;
796+
if (scrut === true) {
797+
Runtime.stackOffset = curDepth;
798+
tmp1 = runtime.Unit;
799+
} else {
800+
tmp1 = runtime.Unit;
801+
}
792802
return tmp
793803
}
794804
static runStackSafe(limit, f1) {
795805
let result, scrut, saved, tmp1, tmp2, tmp3;
796806
Runtime.stackLimit = limit;
797807
Runtime.stackDepth = 1;
808+
Runtime.stackOffset = 0;
798809
Runtime.stackHandler = Runtime.StackDelayHandler;
799810
tmp1 = Runtime.enterHandleBlock(Runtime.StackDelayHandler, f1);
800811
result = tmp1;
@@ -803,6 +814,7 @@ let Runtime1;
803814
if (scrut === true) {
804815
saved = Runtime.stackResume;
805816
Runtime.stackResume = null;
817+
Runtime.stackOffset = Runtime.stackDepth;
806818
tmp2 = runtime.safeCall(saved());
807819
result = tmp2;
808820
tmp3 = runtime.Unit;
@@ -814,6 +826,7 @@ let Runtime1;
814826
}
815827
Runtime.stackLimit = 0;
816828
Runtime.stackDepth = 0;
829+
Runtime.stackOffset = 0;
817830
Runtime.stackHandler = null;
818831
return result
819832
}

hkmc2/shared/src/test/mlscript-compile/Runtime.mls

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ fun resumeContTrace(contTrace, value) =
345345
// stack safety
346346
mut val stackLimit = 0 // How deep the stack can go before heapifying the stack
347347
mut val stackDepth = 0 // Tracks the virtual + real stack depth
348+
mut val stackOffset = 0 // How much to offset stackDepth by to get the true stack depth (i.e. the virtual depth)
348349
mut val stackHandler = null
349350
mut val stackResume = null
350351

@@ -353,29 +354,34 @@ object StackDelayHandler with
353354
set stackResume = k
354355

355356
fun checkDepth() =
356-
if stackDepth >= stackLimit && stackHandler !== null then
357+
if stackDepth - stackOffset >= stackLimit && stackHandler !== null then
357358
// this is a tail call to effectful function
358359
stackHandler.delay()
359360
else
360361
()
361362

362363
fun resetDepth(tmp, curDepth) =
363364
set stackDepth = curDepth
365+
if curDepth < stackOffset do
366+
set stackOffset = curDepth
364367
tmp
365368

366369
fun runStackSafe(limit, f) =
367370
set
368371
stackLimit = limit
369372
stackDepth = 1
373+
stackOffset = 0
370374
stackHandler = StackDelayHandler
371375
let result = enterHandleBlock(StackDelayHandler, f)
372376
while stackResume !== null do
373377
let saved = stackResume
374378
set
375379
stackResume = null
380+
stackOffset = stackDepth
376381
result = saved()
377382
set
378383
stackLimit = 0
379384
stackDepth = 0
385+
stackOffset = 0
380386
stackHandler = null
381387
result

0 commit comments

Comments
 (0)