@@ -132,7 +132,8 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
132
132
* @param ignoredDefns The definitions which must not be lifted.
133
133
* @param inScopeDefns Definitions which are in scope to another definition (excluding itself and its nested definitions).
134
134
* @param modLocals A map from the modules and objects to the local to which it is instantiated after lifting.
135
- * @param localCaptureSyms The symbols in a capture corresponding to a particular local
135
+ * @param localCaptureSyms The symbols in a capture corresponding to a particular local.
136
+ * The `VarSymbol` is the parameter in the capture class, and the `BlockMemberSymbol` is the field in the class.
136
137
* @param prevFnLocals Locals belonging to function definitions that have already been traversed
137
138
* @param prevClsDefns Class definitions that have already been traversed, excluding modules
138
139
* @param curModules Modules that that we are currently nested in (cleared if we are lifted out)
@@ -152,7 +153,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
152
153
val ignoredDefns : Set [BlockMemberSymbol ] = Set .empty,
153
154
val inScopeDefns : Map [BlockMemberSymbol , Set [BlockMemberSymbol ]] = Map .empty,
154
155
val modLocals : Map [BlockMemberSymbol , Local ] = Map .empty,
155
- val localCaptureSyms : Map [Local , LocalSymbol & NamedSymbol ] = Map .empty,
156
+ val localCaptureSyms : Map [Local , ( VarSymbol , BlockMemberSymbol ) ] = Map .empty,
156
157
val prevFnLocals : FreeVars = FreeVars .empty,
157
158
val prevClsDefns : List [ClsLikeDefn ] = Nil ,
158
159
val curModules : List [ClsLikeDefn ] = Nil ,
@@ -186,7 +187,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
186
187
def withInScopes (mp : Map [BlockMemberSymbol , Set [BlockMemberSymbol ]]) = copy(inScopeDefns = mp)
187
188
def addFnLocals (f : FreeVars ) = copy(prevFnLocals = prevFnLocals ++ f)
188
189
def addClsDefn (c : ClsLikeDefn ) = copy(prevClsDefns = c :: prevClsDefns)
189
- def addLocalCaptureSyms (m : Map [Local , LocalSymbol & NamedSymbol ]) = copy(localCaptureSyms = localCaptureSyms ++ m)
190
+ def addLocalCaptureSyms (m : Map [Local , ( VarSymbol , BlockMemberSymbol ) ]) = copy(localCaptureSyms = localCaptureSyms ++ m)
190
191
def getBmsReqdInfo (sym : BlockMemberSymbol ) = bmsReqdInfo.get(sym)
191
192
def replCapturePaths (paths : Map [BlockMemberSymbol , Path ]) = copy(capturePaths = paths)
192
193
def addCapturePath (src : BlockMemberSymbol , path : Path ) = copy(capturePaths = capturePaths + (src -> path))
@@ -218,7 +219,8 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
218
219
* @param f The function to create the capture class for.
219
220
* @param ctx The lifter context. Determines which variables will be captured.
220
221
* @return The triple (defn, varsMap, varsList), where `defn` is the capture class's definition,
221
- * `varsMap` maps the function's locals to the correpsonding `VarSymbol` in the class, and
222
+ * `varsMap` maps the function's locals to the correpsonding `VarSymbol` (for the class parameters)
223
+ * and `BlockLocalSymbol` (for the class fields) in the class, and
222
224
* `varsList` specifies the order of these variables in the class's constructor.
223
225
*/
224
226
def createCaptureCls (f : FunDefn , ctx : LifterCtx ) =
@@ -233,21 +235,54 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
233
235
234
236
val fresh = FreshInt ()
235
237
236
- val varsMap : Map [ Local , TermSymbol ] = cap.map: s =>
238
+ val varsMap = cap.map: s =>
237
239
val id = fresh.make
238
- s -> TermSymbol (syntax.ParamBind , S (clsSym), Tree .Ident (s.nme + id + " $" ))
240
+ val nme = s.nme + id + " $"
241
+ val varSym = VarSymbol (Tree .Ident (nme))
242
+ val fldSym = BlockMemberSymbol (nme, Nil )
243
+ val fldDef = TermDefinition (
244
+ S (clsSym),
245
+ syntax.ImmutVal ,
246
+ fldSym,
247
+ Nil , N , N ,
248
+ S (Term .Ref (s)(Tree .Ident (s.nme), 666 )), // FIXME: 666 is a dummy value
249
+ FlowSymbol (" ‹class-param-res›" ),
250
+ TermDefFlags .empty,
251
+ Nil
252
+ )
253
+ fldSym.defn = S (fldDef)
254
+ s -> (
255
+ varSym,
256
+ fldDef,
257
+ )
239
258
.toMap
240
259
241
260
val varsList = cap.toList
242
261
243
262
val defn = ClsLikeDefn (
244
263
None , clsSym, BlockMemberSymbol (nme, Nil ),
245
264
syntax.Cls ,
246
- S (PlainParamList (varsList.map(s => Param (FldFlags .empty, varsMap(s), None )))),
247
- Nil , None , Nil , Nil , Nil , End (), End ()
265
+ S (PlainParamList (varsList.map: s =>
266
+ val sym = varsMap(s)._1
267
+ val p = Param (FldFlags .empty.copy(value = true ), sym, None )
268
+ sym.decl = S (p)
269
+ p
270
+ )),
271
+ Nil , None , Nil , Nil ,
272
+ varsList.map(varsMap(_)._2),
273
+ varsList.map(varsMap(_)).foldLeft[Block ](End ()):
274
+ case (acc, (varSym, fldDef)) =>
275
+ AssignField (
276
+ clsSym.asPath,
277
+ Tree .Ident (fldDef.sym.nme),
278
+ Value .Ref (varSym),
279
+ acc
280
+ )(S (fldDef.sym))
281
+ ,
282
+ End ()
248
283
)
249
284
250
- (defn, varsMap, varsList)
285
+ (defn, varsMap.view.mapValues(_.mapSecond(_.sym)).toMap , varsList)
251
286
252
287
private val innerSymCache : MutMap [Local , Set [Local ]] = MutMap .empty
253
288
@@ -597,7 +632,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
597
632
case _ => super .applyBlock(rewritten)
598
633
599
634
case Assign (lhs, rhs, rest) => ctx.getLocalCaptureSym(lhs) match
600
- case Some (captureSym) =>
635
+ case Some (( captureSym, _) ) =>
601
636
AssignField (ctx.getLocalClosPath(lhs).get, captureSym.id, applyResult(rhs), applyBlock(rest))(N )
602
637
case None => ctx.getLocalPath(lhs) match
603
638
case None => super .applyBlock(rewritten)
@@ -655,7 +690,7 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
655
690
// This rewrites naked references to locals. If a function is in a capture, then we select that value
656
691
// from the capture; otherwise, we see if that local is passed directly as a parameter to this defn.
657
692
case Value .Ref (l) => ctx.getLocalCaptureSym(l) match
658
- case Some (captureSym) =>
693
+ case Some (( captureSym, _) ) =>
659
694
Select (ctx.getLocalClosPath(l).get, captureSym.id)(N )
660
695
case None => ctx.getLocalPath(l) match
661
696
case Some (value) => Value .Ref (value)
@@ -690,8 +725,13 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
690
725
val fresh = FreshInt ()
691
726
(nme : String ) =>
692
727
val id = fresh.make
693
- TermSymbol (syntax.ParamBind , S (d.isym), Tree .Ident (nme + " $" + id))
694
- case _ => ((nme : String ) => VarSymbol (Tree .Ident (nme)))
728
+ (
729
+ VarSymbol (Tree .Ident (nme + " $" + id)),
730
+ TermSymbol (syntax.ParamBind , S (d.isym), Tree .Ident (nme + " $" + id))
731
+ )
732
+ case _ => (nme : String ) =>
733
+ val vsym = VarSymbol (Tree .Ident (nme))
734
+ (vsym, vsym)
695
735
696
736
val capturesSymbols = includedCaptures.map: sym =>
697
737
(sym, createSym(sym.nme + " $capture" ))
@@ -706,27 +746,27 @@ class Lifter(handlerPaths: Opt[HandlerPaths])(using State, Raise):
706
746
(sym, createSym(sym.nme + " $member" ))
707
747
708
748
val extraParamsCaptures = capturesSymbols.map: // parameter list
709
- case (d, sym) => Param (FldFlags .empty, sym, None )
749
+ case (d, ( sym, _) ) => Param (FldFlags .empty, sym, None )
710
750
val newCapturePaths = capturesSymbols.map: // mapping from sym to param symbol
711
- case (d, sym) => d -> sym.asPath
751
+ case (d, (_, sym) ) => d -> sym.asPath
712
752
.toMap
713
753
714
754
val extraParamsLocals = localsSymbols.map: // parameter list
715
- case (d, sym) => Param (FldFlags .empty, sym, None )
755
+ case (d, ( sym, _) ) => Param (FldFlags .empty, sym, None )
716
756
val newLocalsPaths = localsSymbols.map: // mapping from sym to param symbol
717
- case (d, sym) => d -> sym
757
+ case (d, (_, sym) ) => d -> sym
718
758
.toMap
719
759
720
760
val extraParamsIsyms = isymSymbols.map: // parameter list
721
- case (d, sym) => Param (FldFlags .empty, sym, None )
761
+ case (d, ( sym, _) ) => Param (FldFlags .empty, sym, None )
722
762
val newIsymPaths = isymSymbols.map: // mapping from sym to param symbol
723
- case (d, sym) => d -> sym
763
+ case (d, (_, sym) ) => d -> sym
724
764
.toMap
725
765
726
766
val extraParamsBms = bmsSymbols.map: // parameter list
727
- case (d, sym) => Param (FldFlags .empty, sym, None )
767
+ case (d, ( sym, _) ) => Param (FldFlags .empty, sym, None )
728
768
val newBmsPaths = bmsSymbols.map: // mapping from sym to param symbol
729
- case (d, sym) => d -> sym.asPath
769
+ case (d, (_, sym) ) => d -> sym.asPath
730
770
.toMap
731
771
732
772
val extraParams = extraParamsBms ++ extraParamsIsyms ++ extraParamsLocals ++ extraParamsCaptures
0 commit comments