@@ -9,19 +9,19 @@ import hkmc2.semantics.*
9
9
import hkmc2 .syntax .Tree
10
10
11
11
class StackSafeTransform (depthLimit : Int )(using State ):
12
- private val STACK_LIMIT_IDENT : Tree .Ident = Tree .Ident (" __stackLimit " )
13
- private val STACK_DEPTH_IDENT : Tree .Ident = Tree .Ident (" __stackDepth " )
14
- private val STACK_OFFSET_IDENT : Tree .Ident = Tree .Ident (" __stackOffset " )
15
- private val STACK_HANDLER_IDENT : Tree .Ident = Tree .Ident (" __stackHandler " )
16
-
17
- private val predefPath : Path = State .globalThisSymbol .asPath.selN( Tree . Ident ( " Predef " ))
18
- private val checkDepthPath : Path = predefPath .selN(Tree .Ident (" checkDepth" ))
19
- private val resetDepthPath : Path = predefPath .selN(Tree .Ident (" resetDepth" ))
20
- private val stackDelayClsPath : Path = predefPath .selN(Tree .Ident (" __StackDelay " ))
21
- private val stackLimitPath : Path = predefPath .selN(STACK_LIMIT_IDENT )
22
- private val stackDepthPath : Path = predefPath .selN(STACK_DEPTH_IDENT )
23
- private val stackOffsetPath : Path = predefPath .selN(STACK_OFFSET_IDENT )
24
- private val stackHandlerPath : Path = predefPath .selN(STACK_HANDLER_IDENT )
12
+ private val STACK_LIMIT_IDENT : Tree .Ident = Tree .Ident (" stackLimit " )
13
+ private val STACK_DEPTH_IDENT : Tree .Ident = Tree .Ident (" stackDepth " )
14
+ private val STACK_OFFSET_IDENT : Tree .Ident = Tree .Ident (" stackOffset " )
15
+ private val STACK_HANDLER_IDENT : Tree .Ident = Tree .Ident (" stackHandler " )
16
+
17
+ private val runtimePath : Path = State .runtimeSymbol .asPath
18
+ private val checkDepthPath : Path = runtimePath .selN(Tree .Ident (" checkDepth" ))
19
+ private val resetDepthPath : Path = runtimePath .selN(Tree .Ident (" resetDepth" ))
20
+ private val stackDelayClsPath : Path = runtimePath .selN(Tree .Ident (" StackDelay " ))
21
+ private val stackLimitPath : Path = runtimePath .selN(STACK_LIMIT_IDENT )
22
+ private val stackDepthPath : Path = runtimePath .selN(STACK_DEPTH_IDENT )
23
+ private val stackOffsetPath : Path = runtimePath .selN(STACK_OFFSET_IDENT )
24
+ private val stackHandlerPath : Path = runtimePath .selN(STACK_HANDLER_IDENT )
25
25
26
26
private def intLit (n : BigInt ) = Value .Lit (Tree .IntLit (n))
27
27
@@ -33,22 +33,20 @@ class StackSafeTransform(depthLimit: Int)(using State):
33
33
def extractRes (res : Result , isTailCall : Bool , f : Result => Block , sym : Option [Symbol ], curDepth : => Symbol ) =
34
34
if isTailCall then
35
35
blockBuilder
36
- .assignFieldN(predefPath , STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
36
+ .assignFieldN(runtimePath , STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
37
37
.ret(res)
38
38
else
39
39
val tmp = sym getOrElse TempSymbol (None , " tmp" )
40
40
val offsetGtDepth = TempSymbol (None , " offsetGtDepth" )
41
41
blockBuilder
42
- .assignFieldN(predefPath , STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
42
+ .assignFieldN(runtimePath , STACK_DEPTH_IDENT , op(" +" , stackDepthPath, intLit(1 )))
43
43
.assign(tmp, res)
44
44
.assign(tmp, Call (resetDepthPath, tmp.asPath.asArg :: curDepth.asPath.asArg :: Nil )(true , false ))
45
45
.rest(f(tmp.asPath))
46
-
47
- def extractResTopLevel ( res : Result , isTailCall : Bool , f : Result => Block , sym : Option [ Symbol ], curDepth : => Symbol ) =
46
+
47
+ def wrapStackSafe ( body : Block , resSym : Local , rest : Block ) =
48
48
val resumeSym = VarSymbol (Tree .Ident (" resume" ))
49
49
val handlerSym = TempSymbol (None , " stackHandler" )
50
- val resSym = sym getOrElse TempSymbol (None , " res" )
51
- val handlerRes = TempSymbol (None , " res" )
52
50
53
51
val clsSym = ClassSymbol (
54
52
Tree .TypeDef (syntax.Cls , Tree .Error (), N , N ),
@@ -64,26 +62,28 @@ class StackSafeTransform(depthLimit: Int)(using State):
64
62
/*
65
63
fun perform() =
66
64
stackOffset = stackDepth
67
- let ret = resume()
68
- ret
65
+ resume()
69
66
*/
70
67
blockBuilder
71
- .assignFieldN(predefPath, STACK_OFFSET_IDENT , stackDepthPath)
72
- .assign(handlerRes, Call (Value .Ref (resumeSym), Nil )(true , true ))
73
- .ret(handlerRes.asPath)
68
+ .assignFieldN(runtimePath, STACK_OFFSET_IDENT , stackDepthPath)
69
+ .ret(Call (Value .Ref (resumeSym), Nil )(true , true ))
74
70
) :: Nil ,
75
71
blockBuilder
76
- .assignFieldN(predefPath , STACK_LIMIT_IDENT , intLit(depthLimit)) // set stackLimit before call
77
- .assignFieldN(predefPath , STACK_OFFSET_IDENT , intLit(0 )) // set stackOffset = 0 before call
78
- .assignFieldN(predefPath , STACK_DEPTH_IDENT , intLit(1 )) // set stackDepth = 1 before call
79
- .assignFieldN(predefPath , STACK_HANDLER_IDENT , handlerSym.asPath) // assign stack handler
80
- .rest(HandleBlockReturn (res) ),
72
+ .assignFieldN(runtimePath , STACK_LIMIT_IDENT , intLit(depthLimit)) // set stackLimit before call
73
+ .assignFieldN(runtimePath , STACK_OFFSET_IDENT , intLit(0 )) // set stackOffset = 0 before call
74
+ .assignFieldN(runtimePath , STACK_DEPTH_IDENT , intLit(1 )) // set stackDepth = 1 before call
75
+ .assignFieldN(runtimePath , STACK_HANDLER_IDENT , handlerSym.asPath) // assign stack handler
76
+ .rest(body ),
81
77
blockBuilder // reset the stack safety values
82
- .assignFieldN(predefPath , STACK_DEPTH_IDENT , intLit(0 )) // set stackDepth = 0 after call
83
- .assignFieldN(predefPath , STACK_HANDLER_IDENT , Value .Lit (Tree .UnitLit (true ))) // set stackHandler = null
84
- .rest(f(resSym.asPath) )
78
+ .assignFieldN(runtimePath , STACK_DEPTH_IDENT , intLit(0 )) // set stackDepth = 0 after call
79
+ .assignFieldN(runtimePath , STACK_HANDLER_IDENT , Value .Lit (Tree .UnitLit (true ))) // set stackHandler = null
80
+ .rest(rest )
85
81
)
86
82
83
+ def extractResTopLevel (res : Result , isTailCall : Bool , f : Result => Block , sym : Option [Symbol ], curDepth : => Symbol ) =
84
+ val resSym = sym getOrElse TempSymbol (None , " res" )
85
+ wrapStackSafe(HandleBlockReturn (res), resSym, f(resSym.asPath))
86
+
87
87
// Rewrites anything that can contain a Call to increase the stack depth
88
88
def transform (b : Block , curDepth : => Symbol , isTopLevel : Bool = false ): Block =
89
89
def usesStack (r : Result ) = r match
@@ -119,8 +119,21 @@ class StackSafeTransform(depthLimit: Int)(using State):
119
119
val hdr2 = hdr.mapConserve(applyHandler)
120
120
val bod2 = rewriteBlk(bod)
121
121
val rst2 = applyBlock(rst)
122
- HandleBlock (l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
122
+ if isTopLevel then
123
+ val newRes = TempSymbol (N , " res" )
124
+ val newHandler = HandleBlock (l2, newRes, par2, args2, cls2, hdr2, bod2, HandleBlockReturn (newRes.asPath))
125
+ wrapStackSafe(newHandler, res2, rst2)
126
+ else
127
+ HandleBlock (l2, res2, par2, args2, cls2, hdr2, bod2, rst2)
128
+
123
129
case _ => super .applyBlock(b)
130
+
131
+ override def applyHandler (hdr : Handler ): Handler =
132
+ val sym2 = hdr.sym.subst
133
+ val resumeSym2 = hdr.resumeSym.subst
134
+ val params2 = hdr.params.mapConserve(applyParamList)
135
+ val body2 = rewriteBlk(hdr.body)
136
+ Handler (sym2, resumeSym2, params2, body2)
124
137
125
138
override def applyResult2 (r : Result )(k : Result => Block ): Block =
126
139
if usesStack(r) then
0 commit comments