Skip to content

Commit 46ae8a7

Browse files
authored
Better Flatten (hkust-taco#276)
1 parent 566afff commit 46ae8a7

File tree

4 files changed

+100
-56
lines changed

4 files changed

+100
-56
lines changed

hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -135,44 +135,87 @@ sealed abstract class Block extends Product with AutoLocated:
135135

136136
(transformer.applyBlock(this), defns)
137137

138-
lazy val flatten: Block =
139-
// traverses a Block like a list, flatten `Begin`s using an accumulator
140-
// returns the flattend but reversed Block (with the dummy tail `End("for flatten only")`) and the actual tail of the Block
141-
def getReversedFlattenAndTrueTail(b: Block, acc: Block): (Block, BlockTail) = b match
142-
case Match(scrut, arms, dflt, rest) => getReversedFlattenAndTrueTail(rest, Match(scrut, arms, dflt, acc))
143-
case Label(label, body, rest) => getReversedFlattenAndTrueTail(rest, Label(label, body, acc))
144-
case Begin(sub, rest) =>
145-
val (firstBlockRev, firstTail) = getReversedFlattenAndTrueTail(sub, acc)
146-
firstTail match
147-
case _: End => getReversedFlattenAndTrueTail(rest, firstBlockRev)
148-
// if the tail of `sub` is not `End`, ignore the `rest` of this `Begin`
149-
case _ => firstBlockRev -> firstTail
150-
case TryBlock(sub, finallyDo, rest) => getReversedFlattenAndTrueTail(rest, TryBlock(sub, finallyDo, acc))
151-
case Assign(lhs, rhs, rest) => getReversedFlattenAndTrueTail(rest, Assign(lhs, rhs, acc))
152-
case a@AssignField(lhs, nme, rhs, rest) => getReversedFlattenAndTrueTail(rest, AssignField(lhs, nme, rhs, acc)(a.symbol))
153-
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => getReversedFlattenAndTrueTail(rest, AssignDynField(lhs, fld, arrayIdx, rhs, acc))
154-
case Define(defn, rest) => getReversedFlattenAndTrueTail(rest, Define(defn, acc))
155-
case HandleBlock(lhs, res, par, args, cls, handlers, body, rest) => getReversedFlattenAndTrueTail(rest, HandleBlock(lhs, res, par, args, cls, handlers, body, acc))
156-
case t: BlockTail => acc -> t
138+
lazy val flattened: Block = this.flatten(identity)
139+
140+
private def flatten(k: End => Block): Block = this match
141+
case Match(scrut, arms, dflt, rest) =>
142+
val newRest = rest.flatten(k)
143+
val newArms = arms.mapConserve: arm =>
144+
val newBody = arm._2.flattened
145+
if newBody is arm._2 then arm else (arm._1, newBody)
146+
val newDflt = dflt.map(_.flattened)
147+
if (newRest is rest) && (newArms is arms) && (dflt is newDflt)
148+
then this
149+
else Match(scrut, newArms, newDflt, newRest)
150+
151+
case Label(label, body, rest) =>
152+
val newBody = body.flattened
153+
val newRest = rest.flatten(k)
154+
if (newBody is body) && (newRest is rest)
155+
then this
156+
else Label(label, newBody, newRest)
157+
158+
case Begin(sub, rest) =>
159+
sub.flatten(_ => rest.flatten(k))
157160

158-
// reverse the Block returnned from the previous function,
159-
// which does not contain `Begin` (except for the nested ones),
160-
// and whose tail must be the dummy `End("for flatten only")`
161-
def rev(b: Block, t: Block): Block = b match
162-
case Match(scrut, arms, dflt, rest) => rev(rest, Match(scrut, arms, dflt, t))
163-
case Label(label, body, rest) => rev(rest, Label(label, body, t))
164-
case TryBlock(sub, finallyDo, rest) => rev(rest, TryBlock(sub, finallyDo, t))
165-
case Assign(lhs, rhs, rest) => rev(rest, Assign(lhs, rhs, t))
166-
case a@AssignField(lhs, nme, rhs, rest) => rev(rest, AssignField(lhs, nme, rhs, t)(a.symbol))
167-
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) => rev(rest, AssignDynField(lhs, fld, arrayIdx, rhs, t))
168-
case Define(defn, rest) => rev(rest, Define(defn, t))
169-
case HandleBlock(lhs, res, par, args, cls, handlers, body, rest) => rev(rest, HandleBlock(lhs, res, par, args, cls, handlers, body, t))
170-
case End(msg) => t
171-
case _: BlockTail => ??? // unreachable
172-
case Begin(sub, rest) => ??? // unreachable
161+
case TryBlock(sub, finallyDo, rest) =>
162+
val newSub = sub.flattened
163+
val newFinallyDo = finallyDo.flattened
164+
val newRest = rest.flatten(k)
165+
if (newSub is sub) && (newFinallyDo is finallyDo) && (newRest is rest)
166+
then this
167+
else TryBlock(newSub, newFinallyDo, newRest)
168+
169+
case Assign(lhs, rhs, rest) =>
170+
val newRest = rest.flatten(k)
171+
if newRest is rest
172+
then this
173+
else Assign(lhs, rhs, newRest)
174+
175+
case a@AssignField(lhs, nme, rhs, rest) =>
176+
val newRest = rest.flatten(k)
177+
if newRest is rest
178+
then this
179+
else AssignField(lhs, nme, rhs, newRest)(a.symbol)
180+
181+
case AssignDynField(lhs, fld, arrayIdx, rhs, rest) =>
182+
val newRest = rest.flatten(k)
183+
if newRest is rest
184+
then this
185+
else AssignDynField(lhs, fld, arrayIdx, rhs, newRest)
173186

174-
val (flattenRev, actualTail) = getReversedFlattenAndTrueTail(this, End("for flatten only"))
175-
rev(flattenRev, actualTail)
187+
case Define(defn, rest) =>
188+
val newDefn = defn match
189+
case d: FunDefn =>
190+
val newBody = d.body.flattened
191+
if newBody is d.body
192+
then d
193+
else d.copy(body = newBody)
194+
case v: ValDefn => v
195+
case c: ClsLikeDefn =>
196+
val newPreCtor = c.preCtor.flattened
197+
val newCtor = c.ctor.flattened
198+
if (newPreCtor is c.preCtor) && (newCtor is c.ctor)
199+
then c
200+
else c.copy(preCtor = newPreCtor, ctor = newCtor)
201+
202+
val newRest = rest.flatten(k)
203+
if (newDefn is defn) && (newRest is rest)
204+
then this
205+
else Define(newDefn, newRest)
206+
207+
case HandleBlock(lhs, res, par, args, cls, handlers, body, rest) =>
208+
val newHandlers = handlers.mapConserve: h =>
209+
val newBody = h.body.flattened
210+
if newBody is h.body then h else h.copy(body = newBody)
211+
val newBody = body.flattened
212+
val newRest = rest.flatten(k)
213+
if (newHandlers is handlers) && (newBody is body) && (newRest is rest)
214+
then this
215+
else HandleBlock(lhs, res, par, args, cls, newHandlers, newBody, newRest)
216+
217+
case e: End => k(e)
218+
case t: BlockTail => this
176219

177220
end Block
178221

hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -588,11 +588,12 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
588588
val stackSafe = config.stackSafety match
589589
case N => res
590590
case S(sts) => StackSafeTransform(sts.stackLimit).transformTopLevel(res)
591-
592-
MergeMatchArmTransformer.applyBlock(
593-
if lowerHandlers then HandlerLowering().translateTopLevel(stackSafe)
591+
val withHandlers = if lowerHandlers
592+
then HandlerLowering().translateTopLevel(stackSafe)
594593
else stackSafe
595-
)
594+
val flattened = withHandlers.flattened
595+
596+
MergeMatchArmTransformer.applyBlock(flattened)
596597

597598
def program(main: st): Program =
598599
def go(acc: Ls[Local -> Str], trm: st): Program =
@@ -762,7 +763,7 @@ object TrivialStatementsAndMatch:
762763
object MergeMatchArmTransformer extends BlockTransformer(new SymbolSubst()):
763764
override def applyBlock(b: Block): Block = super.applyBlock(b) match
764765
case m@Match(scrut, arms, Some(dflt), rest) =>
765-
dflt.flatten match
766+
dflt match
766767
case TrivialStatementsAndMatch(k, Match(scrutRewritten, armsRewritten, dfltRewritten, restRewritten))
767768
if (scrutRewritten === scrut) && (restRewritten.size * armsRewritten.length) < 10 =>
768769
val newArms = restRewritten match

hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ object Printer:
3131
.map{ case (c, b) => doc"${case_doc(c)} => #{ # ${mkDocument(b)} #} " }
3232
.mkDocument(sep = doc" # ")
3333
val docDefault = dflt.map(mkDocument).getOrElse(doc"")
34-
doc"match ${mkDocument(scrut)} #{ # ${docCases} # else #{ # ${docDefault} #} #} "
34+
doc"match ${mkDocument(scrut)} #{ # ${docCases} # else #{ # ${docDefault} #} #} # in # ${mkDocument(rest)} "
3535
case Return(res, implct) => doc"return ${mkDocument(res)}"
3636
case Throw(exc) => doc"throw ${mkDocument(exc)}"
3737
case Label(label, body, rest) =>

hkmc2/shared/src/test/mlscript/codegen/BlockPrinter.mls

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,18 @@ let x = x + 1
2121
//│ Pretty Lowered:
2222
//│
2323
//│ set x2 = 1 in
24-
//│ begin
25-
//│ set scrut = ==(x2, 0) in
26-
//│ match scrut
27-
//│ true =>
28-
//│ set tmp = 1 in
29-
//│ end
30-
//│ else
31-
//│ set tmp = 0 in
32-
//│ end;
33-
//│ set x3 = tmp in
34-
//│ set tmp1 = +(x3, 1) in
35-
//│ set x4 = tmp1 in
36-
//│ set block$res3 = undefined in
37-
//│ end
24+
//│ set scrut = ==(x2, 0) in
25+
//│ match scrut
26+
//│ true =>
27+
//│ set tmp = 1 in
28+
//│ end
29+
//│ else
30+
//│ set tmp = 0 in
31+
//│ end
32+
//│ in
33+
//│ set x3 = tmp in
34+
//│ set tmp1 = +(x3, 1) in
35+
//│ set x4 = tmp1 in
36+
//│ set block$res3 = undefined in
37+
//│ end
3838
//│ x = 1

0 commit comments

Comments
 (0)