@@ -135,44 +135,87 @@ sealed abstract class Block extends Product with AutoLocated:
135
135
136
136
(transformer.applyBlock(this ), defns)
137
137
138
- lazy val flatten : Block =
139
- // traverses a Block like a list, flatten `Begin`s using an accumulator
140
- // returns the flattend but reversed Block (with the dummy tail `End("for flatten only")`) and the actual tail of the Block
141
- def getReversedFlattenAndTrueTail (b : Block , acc : Block ): (Block , BlockTail ) = b match
142
- case Match (scrut, arms, dflt, rest) => getReversedFlattenAndTrueTail(rest, Match (scrut, arms, dflt, acc))
143
- case Label (label, body, rest) => getReversedFlattenAndTrueTail(rest, Label (label, body, acc))
144
- case Begin (sub, rest) =>
145
- val (firstBlockRev, firstTail) = getReversedFlattenAndTrueTail(sub, acc)
146
- firstTail match
147
- case _ : End => getReversedFlattenAndTrueTail(rest, firstBlockRev)
148
- // if the tail of `sub` is not `End`, ignore the `rest` of this `Begin`
149
- case _ => firstBlockRev -> firstTail
150
- case TryBlock (sub, finallyDo, rest) => getReversedFlattenAndTrueTail(rest, TryBlock (sub, finallyDo, acc))
151
- case Assign (lhs, rhs, rest) => getReversedFlattenAndTrueTail(rest, Assign (lhs, rhs, acc))
152
- case a@ AssignField (lhs, nme, rhs, rest) => getReversedFlattenAndTrueTail(rest, AssignField (lhs, nme, rhs, acc)(a.symbol))
153
- case AssignDynField (lhs, fld, arrayIdx, rhs, rest) => getReversedFlattenAndTrueTail(rest, AssignDynField (lhs, fld, arrayIdx, rhs, acc))
154
- case Define (defn, rest) => getReversedFlattenAndTrueTail(rest, Define (defn, acc))
155
- case HandleBlock (lhs, res, par, args, cls, handlers, body, rest) => getReversedFlattenAndTrueTail(rest, HandleBlock (lhs, res, par, args, cls, handlers, body, acc))
156
- case t : BlockTail => acc -> t
138
+ lazy val flattened : Block = this .flatten(identity)
139
+
140
+ private def flatten (k : End => Block ): Block = this match
141
+ case Match (scrut, arms, dflt, rest) =>
142
+ val newRest = rest.flatten(k)
143
+ val newArms = arms.mapConserve: arm =>
144
+ val newBody = arm._2.flattened
145
+ if newBody is arm._2 then arm else (arm._1, newBody)
146
+ val newDflt = dflt.map(_.flattened)
147
+ if (newRest is rest) && (newArms is arms) && (dflt is newDflt)
148
+ then this
149
+ else Match (scrut, newArms, newDflt, newRest)
150
+
151
+ case Label (label, body, rest) =>
152
+ val newBody = body.flattened
153
+ val newRest = rest.flatten(k)
154
+ if (newBody is body) && (newRest is rest)
155
+ then this
156
+ else Label (label, newBody, newRest)
157
+
158
+ case Begin (sub, rest) =>
159
+ sub.flatten(_ => rest.flatten(k))
157
160
158
- // reverse the Block returnned from the previous function,
159
- // which does not contain `Begin` (except for the nested ones),
160
- // and whose tail must be the dummy `End("for flatten only")`
161
- def rev (b : Block , t : Block ): Block = b match
162
- case Match (scrut, arms, dflt, rest) => rev(rest, Match (scrut, arms, dflt, t))
163
- case Label (label, body, rest) => rev(rest, Label (label, body, t))
164
- case TryBlock (sub, finallyDo, rest) => rev(rest, TryBlock (sub, finallyDo, t))
165
- case Assign (lhs, rhs, rest) => rev(rest, Assign (lhs, rhs, t))
166
- case a@ AssignField (lhs, nme, rhs, rest) => rev(rest, AssignField (lhs, nme, rhs, t)(a.symbol))
167
- case AssignDynField (lhs, fld, arrayIdx, rhs, rest) => rev(rest, AssignDynField (lhs, fld, arrayIdx, rhs, t))
168
- case Define (defn, rest) => rev(rest, Define (defn, t))
169
- case HandleBlock (lhs, res, par, args, cls, handlers, body, rest) => rev(rest, HandleBlock (lhs, res, par, args, cls, handlers, body, t))
170
- case End (msg) => t
171
- case _ : BlockTail => ??? // unreachable
172
- case Begin (sub, rest) => ??? // unreachable
161
+ case TryBlock (sub, finallyDo, rest) =>
162
+ val newSub = sub.flattened
163
+ val newFinallyDo = finallyDo.flattened
164
+ val newRest = rest.flatten(k)
165
+ if (newSub is sub) && (newFinallyDo is finallyDo) && (newRest is rest)
166
+ then this
167
+ else TryBlock (newSub, newFinallyDo, newRest)
168
+
169
+ case Assign (lhs, rhs, rest) =>
170
+ val newRest = rest.flatten(k)
171
+ if newRest is rest
172
+ then this
173
+ else Assign (lhs, rhs, newRest)
174
+
175
+ case a@ AssignField (lhs, nme, rhs, rest) =>
176
+ val newRest = rest.flatten(k)
177
+ if newRest is rest
178
+ then this
179
+ else AssignField (lhs, nme, rhs, newRest)(a.symbol)
180
+
181
+ case AssignDynField (lhs, fld, arrayIdx, rhs, rest) =>
182
+ val newRest = rest.flatten(k)
183
+ if newRest is rest
184
+ then this
185
+ else AssignDynField (lhs, fld, arrayIdx, rhs, newRest)
173
186
174
- val (flattenRev, actualTail) = getReversedFlattenAndTrueTail(this , End (" for flatten only" ))
175
- rev(flattenRev, actualTail)
187
+ case Define (defn, rest) =>
188
+ val newDefn = defn match
189
+ case d : FunDefn =>
190
+ val newBody = d.body.flattened
191
+ if newBody is d.body
192
+ then d
193
+ else d.copy(body = newBody)
194
+ case v : ValDefn => v
195
+ case c : ClsLikeDefn =>
196
+ val newPreCtor = c.preCtor.flattened
197
+ val newCtor = c.ctor.flattened
198
+ if (newPreCtor is c.preCtor) && (newCtor is c.ctor)
199
+ then c
200
+ else c.copy(preCtor = newPreCtor, ctor = newCtor)
201
+
202
+ val newRest = rest.flatten(k)
203
+ if (newDefn is defn) && (newRest is rest)
204
+ then this
205
+ else Define (newDefn, newRest)
206
+
207
+ case HandleBlock (lhs, res, par, args, cls, handlers, body, rest) =>
208
+ val newHandlers = handlers.mapConserve: h =>
209
+ val newBody = h.body.flattened
210
+ if newBody is h.body then h else h.copy(body = newBody)
211
+ val newBody = body.flattened
212
+ val newRest = rest.flatten(k)
213
+ if (newHandlers is handlers) && (newBody is body) && (newRest is rest)
214
+ then this
215
+ else HandleBlock (lhs, res, par, args, cls, newHandlers, newBody, newRest)
216
+
217
+ case e : End => k(e)
218
+ case t : BlockTail => this
176
219
177
220
end Block
178
221
0 commit comments